[06/10] Backend: Workgroup reduce optimization using SLM

Submitted by Grigore Lupescu on March 31, 2016, 3:28 p.m.

Details

Message ID 1459438083-31957-1-git-send-email-grigore.lupescu@intel.com
State New
Headers show
Series "Series without cover letter" ( rev: 1 ) in Beignet

Not browsing as part of any series.

Commit Message

Grigore Lupescu March 31, 2016, 3:28 p.m.
From: Grigore Lupescu <grigore.lupescu@intel.com>

Signed-off-by: Grigore Lupescu <grigore.lupescu@intel.com>
---
 backend/src/backend/gen_context.cpp        | 313 ++++++++++++-----------------
 backend/src/backend/gen_insn_selection.cpp | 199 ++++++------------
 backend/src/llvm/llvm_gen_backend.cpp      |  13 +-
 3 files changed, 196 insertions(+), 329 deletions(-)

Patch hide | download patch | download mbox

diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index e55c25e..84df912 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2346,26 +2346,22 @@  namespace gbe
     p->TYPED_WRITE(header, true, bti);
   }
 
-  static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
-      uint32_t simd, uint32_t wg_op, GenEncoder *p) {
-    p->push();
-    p->curr.predicate = GEN_PREDICATE_NONE;
-    p->curr.noMask = 1;
-    p->curr.execWidth = 1;
-
-    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+  static void workgroupOpBetweenThread(GenRegister partialRes,
+                                       GenRegister value,
+                                       uint32_t wg_op,
+                                       GenEncoder *p)
+  {
+    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN
+        || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
       uint32_t cond;
       if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN)
         cond = GEN_CONDITIONAL_LE;
       else
         cond = GEN_CONDITIONAL_GE;
-
-      p->SEL_CMP(cond, msgData, threadData, msgData);
+      p->SEL_CMP(cond, partialRes, partialRes, value);
     } else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
-      p->ADD(msgData, threadData, msgData);
+      p->ADD(partialRes, partialRes, value);
     }
-
-    p->pop();
   }
 
   static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) {
@@ -2441,209 +2437,146 @@  namespace gbe
           }
         }
       }
-    } else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
+    }
+    else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD){
+      tmp.hstride = GEN_HORIZONTAL_STRIDE_1;
+      tmp.vstride = GEN_VERTICAL_STRIDE_4;
+      tmp.width = 2;
+
       GBE_ASSERT(tmp.type == theVal.type);
-      GenRegister v = GenRegister::toUniform(tmp, theVal.type);
-      for (uint32_t i = 0; i < simd; i++) {
-        p->ADD(threadData, threadData, v);
-        v.subnr += typeSize(theVal.type);
-        if (v.subnr == 32) {
-          v.subnr = 0;
-          v.nr++;
-        }
+      GenRegister partialSum = tmp;
+
+      /* adjust offset, compute add with ADD4/ADD */
+      for (uint32_t i = 1; i < simd/4; i++){
+        p->push(); {
+          tmp = tmp.suboffset(tmp, 4);
+          p->curr.execWidth = 4;
+          p->ADD(partialSum, partialSum, tmp);
+        } p->pop();
       }
-    }
 
+      for (uint32_t i = 0; i < 4; i++){
+        p->push(); {
+        p->ADD(threadData, threadData, partialSum);
+        partialSum = partialSum.suboffset(partialSum, 1);
+        } p->pop();
+      }
+    }
     p->pop();
-  }
+}
 
-#define SEND_RESULT_MSG() \
-do { \
-  p->push(); { /* then send msg. */ \
-    p->curr.noMask = 1; \
-    p->curr.predicate = GEN_PREDICATE_NONE; \
-    p->curr.execWidth = 1; \
-    GenRegister offLen = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 20), GEN_TYPE_UD); \
-    offLen.vstride = GEN_VERTICAL_STRIDE_0; \
-    offLen.width = GEN_WIDTH_1; \
-    offLen.hstride = GEN_HORIZONTAL_STRIDE_0; \
-    uint32_t szEnc = typeSize(theVal.type) >> 1; \
-    if (szEnc == 4) { \
-      szEnc = 3; \
-    } \
-    p->MOV(offLen, GenRegister::immud((szEnc << 8) | (nextThreadID.nr << 21))); \
-    \
-    GenRegister tidEuid = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 16), GEN_TYPE_UD); \
-    tidEuid.vstride = GEN_VERTICAL_STRIDE_0; \
-    tidEuid.width = GEN_WIDTH_1; \
-    tidEuid.hstride = GEN_HORIZONTAL_STRIDE_0; \
-    p->SHL(tidEuid, tidEuid, GenRegister::immud(16)); \
-    \
-    p->curr.execWidth = 8; \
-    p->FWD_GATEWAY_MSG(nextThreadID, 2); \
-  } p->pop(); \
-} while(0)
-
-
-  /* The basic idea is like this:
-     1. All the threads firstly calculate the max/min/add value within their own thread, that is finding
-        the max/min/add value within their 16 work items when SIMD == 16.
-     2. The logical thread ID 0 begins to send the MSG to thread 1, and that message contains the calculated
-        result of the first step. Except the thread 0, all other threads wait on the n0.2 for message forwarding.
-     3. Each thread is waken up because of getting the forwarding message from the thread_id - 1. Then it
-        compares the result in the message and the result within its thread, then forward the correct result to
-        the next thread by sending a message again. If it is the last thread, send it to thread 0.
-     4. Thread 0 finally get the message from the last one and broadcast the final result. */
-  void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) {
+/**
+ * Basic idea:
+ * 1. All the threads firstly calculate the max/min/add value for the
+ * allocated work-items. SIMD16=> 16 work-items allocated for each thread
+ * 2. Each thread will write the computed reduce OP result in SLM memory
+ * based on the threadId
+ * 3. After a memory fence, each thread will read in chunks of 4 elements,
+ * the SLM region, using a loop based on the thread count value (threadN)
+ * 4. At the end each thread has the final value computed individually
+ */
+  void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn){
     const GenRegister dst = ra->genReg(insn.dst(0));
-    const GenRegister tmp = ra->genReg(insn.dst(2));
-    GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag);
-    GenRegister nextThreadID = ra->genReg(insn.src(1));
+    const GenRegister tmp = ra->genReg(insn.dst(1));
+
     const GenRegister theVal = ra->genReg(insn.src(0));
-    GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid));
-    GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn));
-    GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // The data forward.
-    msgData.vstride = GEN_VERTICAL_STRIDE_0;
-    msgData.width = GEN_WIDTH_1;
-    msgData.hstride = GEN_HORIZONTAL_STRIDE_0;
-    GenRegister threadData =
-      GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // Res within thread.
-    threadData.vstride = GEN_VERTICAL_STRIDE_0;
-    threadData.width = GEN_WIDTH_1;
-    threadData.hstride = GEN_HORIZONTAL_STRIDE_0;
+    GenRegister threadData = ra->genReg(insn.src(1));
+    GenRegister partialData = GenRegister::toUniform(threadData, dst.type);
+    GenRegister threadId = ra->genReg(insn.src(2));
+    GenRegister threadNum = ra->genReg(insn.src(3));
+
+    threadId = GenRegister::toUniform(threadId, GEN_TYPE_UD);
+    threadNum = GenRegister::toUniform(threadNum, GEN_TYPE_UD);
+
     uint32_t wg_op = insn.extra.workgroupOp;
     uint32_t simd = p->curr.execWidth;
-    GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW);
-    flag_save.vstride = GEN_VERTICAL_STRIDE_0;
-    flag_save.width = GEN_WIDTH_1;
-    flag_save.hstride = GEN_HORIZONTAL_STRIDE_0;
-    int32_t jip;
-    int32_t oneThreadJip = -1;
-
-    p->push(); { /* First, so something within thread. */
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      /* Do some calculation within each thread. */
-      workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
-    } p->pop();
+    int32_t jip0, jip1;
 
-    /* If we are the only one thread, no need to send msg, just broadcast the result.*/
-    p->push(); {
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->curr.noMask = 1;
-      p->curr.execWidth = 1;
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+    GenRegister result = GenRegister::offset(dst, 0, 16);
+    result = GenRegister::toUniform(result, dst.type);
 
-      /* Broadcast result. */
-      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.inversePredicate = 1;
-        p->MOV(flag_save, GenRegister::immuw(0x0));
-        p->curr.inversePredicate = 0;
-        p->MOV(flag_save, GenRegister::immuw(0xffff));
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->MOV(flagReg, flag_save);
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.execWidth = simd;
-        p->MOV(dst, threadData);
-      }
+    /* Use of continuous GRF allocation from insn selection */
+    GenRegister msg = GenRegister::retype(ra->genReg(insn.dst(2)), dst.type);
+    GenRegister msgSlmOff = GenRegister::retype(ra->genReg(insn.src(4)), GEN_TYPE_UD);
+    GenRegister msgAddr = GenRegister::retype(GenRegister::offset(msg, 0), GEN_TYPE_UD);
+    GenRegister msgData = GenRegister::retype(GenRegister::offset(msg, 1), dst.type);
 
-      /* Bail out. */
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      p->curr.inversePredicate = 0;
-      p->curr.execWidth = 1;
-      oneThreadJip = p->n_instruction();
-      p->JMPI(GenRegister::immud(0));
-    } p->pop();
+    /* Do some calculation within each thread */
+    workgroupOpInThread(msg, theVal, threadData, tmp, simd, wg_op, p);
+    threadData = GenRegister::toUniform(threadData, dst.type);
 
-    p->push(); {
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->curr.noMask = 1;
-      p->curr.execWidth = 1;
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
+    /* Store thread count for future use on read/write to SLM */
+    GenRegister threadN = GenRegister::retype(tmp, GEN_TYPE_UD);
+    p->MOV(threadN, ra->genReg(GenRegister::ud1grf(ir::ocl::threadn)));
 
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      p->curr.inversePredicate = 1;
-      p->MOV(flag_save, GenRegister::immuw(0x0));
-      p->curr.inversePredicate = 0;
-      p->MOV(flag_save, GenRegister::immuw(0xffff));
+    /* All threads write the partial results to SLM memory */
+    p->curr.execWidth = 8;
+    p->MOV(msgData, threadData);
+    p->MUL(msgAddr, threadId, GenRegister::immd(0x4));
+    p->ADD(msgAddr, msgAddr, msgSlmOff);
+    p->UNTYPED_WRITE(msg, GenRegister::immw(0xFE), 1);
 
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->MOV(flagReg, flag_save);
-    } p->pop();
+    /* Init partialData register, it will hold the final result */
+    initValue(p, partialData, wg_op);
 
-    p->push(); {
+    /* Thread 0 will write extra elements for future reads in chunks of 4 */
+    p->push();{
       p->curr.noMask = 1;
-      p->curr.execWidth = 1;
-
-      /* threadid 0, send the msg and wait */
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      p->curr.inversePredicate = 1;
+      p->curr.flag = 0;
+      p->curr.subFlag = 1;
+      p->CMP(GEN_CONDITIONAL_EQ, threadId, GenRegister::immd(0x0));
       p->curr.predicate = GEN_PREDICATE_NORMAL;
-      jip = p->n_instruction();
-      p->JMPI(GenRegister::immud(0));
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->MOV(msgData, threadData);
-      SEND_RESULT_MSG();
-      p->WAIT(2);
-      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+      p->curr.execWidth = 8;
+      p->MOV(msgData.offset(msgData, 0), partialData);
+      p->MOV(msgData.offset(msgData, 1), partialData);
+      p->MOV(msgData.offset(msgData, 2), partialData);
+      p->MUL(msgAddr, threadN, GenRegister::immd(0x4));
+      p->ADD(msgAddr, msgAddr, msgSlmOff);
+      p->UNTYPED_WRITE(msg, GenRegister::immw(0xFE), 3);
+    } p->pop();
 
-      /* Others wait and send msg, and do something when we get the msg. */
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      p->curr.inversePredicate = 0;
-      jip = p->n_instruction();
-      p->JMPI(GenRegister::immud(0));
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->WAIT(2);
-      workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
-      SEND_RESULT_MSG();
-      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+    /* Round threadN to nearest upper number divisible with 4 required for
+     * reading in chunks of 4 elements from SLM */
+    p->ADD(threadN, threadN, GenRegister::immd(0x3));
+    p->SHR(threadN, threadN, GenRegister::immd(0x2));
+    p->SHL(threadN, threadN, GenRegister::immd(0x2));
 
-      /* Restore the flag. */
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->MOV(flagReg, flag_save);
-    } p->pop();
+    /* Wait for all writes to complete in work-group */
+    p->FENCE(tmp);
 
-    /* Broadcast the result. */
-    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX
-        || wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
-      p->push(); {
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.noMask = 1;
-        p->curr.execWidth = 1;
-        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-        p->curr.inversePredicate = 0;
+    /* Perform a loop, based on thread count (which is now multiple of 4) */
+    p->push();{
+      jip0 = p->n_instruction();
 
-        /* Not the first thread, wait for msg first. */
-        jip = p->n_instruction();
-        p->JMPI(GenRegister::immud(0));
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->WAIT(2);
-        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-    
-        /* Do something when get the msg. */
-        p->curr.execWidth = simd;
-        p->MOV(dst, msgData);
+      p->curr.execWidth = 8;
+      p->curr.predicate = GEN_PREDICATE_NONE;
 
-        p->curr.execWidth = 8;
-        p->FWD_GATEWAY_MSG(nextThreadID, 2);
+      /* Read in chunks of 4 to optimize SLM reads and reduce SEND messages */
+      p->ADD(threadN, threadN, GenRegister::immd(-4));
+      p->MUL(msgAddr, threadN, GenRegister::immd(0x4));
+      p->ADD(msgAddr, msgAddr, msgSlmOff);
+      p->UNTYPED_READ(msgData, msgAddr, GenRegister::immw(0xFE), 4);
 
-        p->curr.execWidth = 1;
-        p->curr.inversePredicate = 1;
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
+      /* Perform operation, process 4 elements, partialData will hold result */
+      workgroupOpBetweenThread(partialData, msgData.offset(msgData, 0), wg_op, p);
+      workgroupOpBetweenThread(partialData, msgData.offset(msgData, 1), wg_op, p);
+      workgroupOpBetweenThread(partialData, msgData.offset(msgData, 2), wg_op, p);
+      workgroupOpBetweenThread(partialData, msgData.offset(msgData, 3), wg_op, p);
 
-        /* The first thread, the last one will notify us. */
-        jip = p->n_instruction();
-        p->JMPI(GenRegister::immud(0));
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->WAIT(2);
-        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-      } p->pop();
-    }
+      /* While threadN is not 0, cycle read SLM / update value */
+      p->curr.noMask = 1;
+      p->curr.flag = 0;
+      p->curr.subFlag = 1;
+      p->CMP(GEN_CONDITIONAL_G, threadN, GenRegister::immd(0x0));
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      jip1 = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+      p->patchJMPI(jip1, jip0 - jip1, 0);
+    } p->pop();
 
-    if (oneThreadJip >=0)
-      p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0);
+    /* Save result to final register location dst */
+    p->curr.execWidth = 16;
+    p->MOV(dst, partialData);
   }
 
   void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) {
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 77614b6..024c3e7 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -699,8 +699,10 @@  namespace gbe
     /*! double division */
     void F64DIV(Reg dst, Reg src0, Reg src1, GenRegister* tmp, int tmpNum);
     /*! Work Group Operations */
-    void WORKGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, GenRegister nextThreadID,
-                     GenRegister threadID, GenRegister threadn, GenRegister tmp);
+    void WORKGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, GenRegister data,
+                      GenRegister threadId, GenRegister threadN,
+                      GenRegister tmp, GenRegister slmOff, vector<GenRegister> msg,
+                      uint32_t msgSizeReq);
     /* common functions for both binary instruction and sel_cmp and compare instruction.
        It will handle the IMM or normal register assignment, and will try to avoid LOADI
        as much as possible. */
@@ -2040,18 +2042,38 @@  namespace gbe
       insn->dst(i + 1) = tmp[i];
   }
 
-  void Selection::Opaque::WORKGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, GenRegister nextThreadID,
-                      GenRegister threadID, GenRegister threadn, GenRegister tmp) {
-    SelectionInstruction *insn = this->appendInsn(SEL_OP_WORKGROUP_OP, 3, 4);
+  void Selection::Opaque::WORKGROUP_OP(uint32_t wg_op,
+                                       Reg dst,
+                                       GenRegister src,
+                                       GenRegister data,
+                                       GenRegister threadId,
+                                       GenRegister threadN,
+                                       GenRegister tmp,
+                                       GenRegister slmOff,
+                                       vector<GenRegister> msg,
+                                       uint32_t msgSizeReq = 6)
+  {
+    SelectionInstruction *insn = this->appendInsn(SEL_OP_WORKGROUP_OP, 2 + msg.size(), 5);
+    SelectionVector *vector = this->appendVector();
+
+    /* allocate continuous GRF registers for READ/WRITE to SLM */
+    GBE_ASSERT(msg.size() >= msgSizeReq);
+    vector->regNum = msg.size();
+    vector->offsetID = 0;
+    vector->reg = &insn->dst(2);
+    vector->isSrc = 0;
     insn->extra.workgroupOp = wg_op;
+
     insn->dst(0) = dst;
-    insn->dst(1) = nextThreadID;
-    insn->dst(2) = tmp;
+    insn->dst(1) = tmp;
+    for(uint32_t i = 0; i < msg.size(); i++)
+      insn->dst(2 + i) = msg[i];
 
     insn->src(0) = src;
-    insn->src(1) = nextThreadID;
-    insn->src(2) = threadID;
-    insn->src(3) = threadn;
+    insn->src(1) = data;
+    insn->src(2) = threadId;
+    insn->src(3) = threadN;
+    insn->src(4) = slmOff;
   }
 
   // Boiler plate to initialize the selection library at c++ pre-main
@@ -6424,106 +6446,37 @@  namespace gbe
   /*! WorkGroup instruction pattern */
   DECL_PATTERN(WorkGroupInstruction)
   {
-    INLINE bool storeThreadID(Selection::Opaque &sel, uint32_t slmAddr) const
+    /* SLM bassed communication between threads, most of the logic bellow */
+    INLINE bool emitWGReduce(Selection::Opaque &sel, const ir::WorkGroupInstruction &insn) const
     {
       using namespace ir;
-      GenRegister sr0_0 = GenRegister::retype(GenRegister::sr(0), GEN_TYPE_UW);
-      const uint32_t simdWidth = sel.ctx.getSimdWidth();
-      GenRegister tmp;
-      GenRegister addr;
-      vector<GenRegister> fakeTemps;
-
-      if (simdWidth == 16) {
-        tmp = GenRegister::retype(sel.selReg(sel.reg(FAMILY_WORD), ir::TYPE_U16), GEN_TYPE_UD);
-        addr = GenRegister::retype(sel.selReg(sel.reg(FAMILY_WORD), ir::TYPE_U16), GEN_TYPE_UD);
-      } else {
-        tmp = GenRegister::retype(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32), GEN_TYPE_UD);
-        addr = GenRegister::retype(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32), GEN_TYPE_UD);
-      }
-
-      sr0_0 = GenRegister::vec1(sr0_0);
-      sel.push(); {
-        sel.curr.predicate = GEN_PREDICATE_NONE;
-        sel.curr.noMask = 1;
-        sel.curr.execWidth = 8;
-
-        sel.MOV(tmp, sr0_0);
+      const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
 
-        sel.MUL(addr, sel.selReg(ocl::threadid, ir::TYPE_U32), GenRegister::immud(2));
-        sel.ADD(addr, addr, GenRegister::immud(slmAddr));
+      const Type type = insn.getType();
+      GenRegister dst = sel.selReg(insn.getDst(0), type);
+      GenRegister src = sel.selReg(insn.getSrc(2), type);
+      GenRegister threadId = sel.selReg(ocl::threadid, ir::TYPE_U32);
+      GenRegister threadN = sel.selReg(ocl::threadn, ir::TYPE_U32);
+      const uint32_t srcNum = insn.getSrcNum();
 
-        sel.push(); {
-          sel.curr.predicate = GEN_PREDICATE_NONE;
-          sel.curr.noMask = 1;
-          sel.push(); {
-            sel.curr.execWidth = 1;
-            sel.MOV(GenRegister::flag(0, 1), GenRegister::immuw(0x01));
-          } sel.pop();
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.curr.predicate = GEN_PREDICATE_NORMAL;
-          sel.BYTE_SCATTER(addr, tmp, 1, GenRegister::immw(0xfe), fakeTemps);
-        } sel.pop();
-      } sel.pop();
+      GBE_ASSERT(srcNum == 3);
+      GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
+      GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
+      GenRegister tmp = GenRegister::retype(sel.selReg(sel.reg(FAMILY_DWORD)), type);
+      GenRegister data = sel.selReg(sel.reg(FAMILY_WORD), type);
+      GenRegister slmOff = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+
+      vector<GenRegister> msg;
+      for(uint32_t i = 0; i < 6; i++)
+        msg.push_back(sel.selReg(sel.reg(FAMILY_DWORD), type));
+
+      /* compute individual slice of workitems, (e.g. 0->16 workitems) */
+      sel.MOV(slmOff, GenRegister::immud(insn.getSlmAddr()));
+      sel.WORKGROUP_OP(workGroupOp, dst, src, data, threadId,
+                       threadN, tmp, slmOff, msg);
       return true;
     }
 
-    INLINE GenRegister getNextThreadID(Selection::Opaque &sel, uint32_t slmAddr) const
-    {
-      using namespace ir;
-      const uint32_t simdWidth = sel.ctx.getSimdWidth();
-      GenRegister addr;
-      GenRegister nextThread;
-      GenRegister tid;
-      vector<GenRegister> fakeTemps;
-
-      if (simdWidth == 16) {
-        addr = GenRegister::retype(sel.selReg(sel.reg(FAMILY_WORD), ir::TYPE_U16), GEN_TYPE_UD);
-        nextThread = GenRegister::retype(sel.selReg(sel.reg(FAMILY_WORD), ir::TYPE_U16), GEN_TYPE_UD);
-        tid = GenRegister::retype(sel.selReg(sel.reg(FAMILY_WORD), ir::TYPE_U16), GEN_TYPE_UD);
-      } else {
-        addr = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        nextThread = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        tid = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-      }
-
-      sel.push(); {
-        sel.curr.execWidth = 8;
-        sel.curr.predicate = GEN_PREDICATE_NONE;
-        sel.curr.noMask = 1;
-        sel.ADD(nextThread, sel.selReg(ocl::threadid, ir::TYPE_U32), GenRegister::immud(1));
-
-        /* Wrap the next thread id. */
-        sel.push(); {
-          sel.curr.predicate = GEN_PREDICATE_NONE;
-          sel.curr.noMask = 1;
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.CMP(GEN_CONDITIONAL_EQ, nextThread, sel.selReg(ocl::threadn, ir::TYPE_U32), GenRegister::null());
-          sel.curr.predicate = GEN_PREDICATE_NORMAL;
-          sel.MOV(nextThread, GenRegister::immud(0));
-        } sel.pop();
-
-        sel.MUL(addr, nextThread, GenRegister::immud(2));
-        sel.ADD(addr, addr, GenRegister::immud(slmAddr));
-
-        sel.push(); {
-          sel.curr.predicate = GEN_PREDICATE_NONE;
-          sel.curr.noMask = 1;
-          sel.push(); {
-            sel.curr.execWidth = 1;
-            sel.MOV(GenRegister::flag(0, 1), GenRegister::immuw(0x010));
-          } sel.pop();
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.curr.predicate = GEN_PREDICATE_NORMAL;
-          sel.BYTE_GATHER(tid, addr, 1, GenRegister::immw(0xfe), fakeTemps);
-        } sel.pop();
-
-      } sel.pop();
-      return tid;
-    }
-
     INLINE bool emitWGBroadcast(Selection::Opaque &sel, const ir::WorkGroupInstruction &insn) const {
       /*  1. BARRIER    Ensure all the threads have set the correct value for the var which will be broadcasted.
           2. CMP IDs    Compare the local IDs with the specified ones in the function call.
@@ -6595,42 +6548,14 @@  namespace gbe
       using namespace ir;
       const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
 
-      if (workGroupOp == WORKGROUP_OP_BROADCAST) {
+      if (workGroupOp == WORKGROUP_OP_BROADCAST){
         return emitWGBroadcast(sel, insn);
-      } else if (workGroupOp >= WORKGROUP_OP_REDUCE_ADD && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
-        const uint32_t slmAddr = insn.getSlmAddr();
-        /* First, we create the TheadID/localID map, in order to get which thread hold the next 16 workitems. */
-
-        if (!sel.storeThreadMap) {
-          this->storeThreadID(sel, slmAddr);
-          sel.storeThreadMap = true;
-        }
-
-        /* Then we insert a barrier to make sure all the var we are interested in
-           have been assigned the final value. */
-        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)), sel.selReg(sel.reg(FAMILY_DWORD)), syncLocalBarrier);
-
-        /* Third, get the next thread ID which we will Forward MSG to. */
-        GenRegister nextThreadID = getNextThreadID(sel, slmAddr);
-        GenRegister threadID = sel.selReg(ocl::threadid, ir::TYPE_U32);
-        GenRegister threadNum = sel.selReg(ocl::threadn, ir::TYPE_U32);
-        GenRegister tmp = GenRegister::retype(sel.selReg(sel.reg(FAMILY_DWORD)), GEN_TYPE_UD);
-
-        const Type type = insn.getType();
-        const GenRegister dst = sel.selReg(insn.getDst(0), type);
-        const uint32_t srcNum = insn.getSrcNum();
-        GBE_ASSERT(srcNum == 3);
-        GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
-        GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
-        GenRegister src = sel.selReg(insn.getSrc(2), type);
-        sel.push(); {
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.WORKGROUP_OP(workGroupOp, dst, src, nextThreadID, threadID, threadNum, tmp);
-        } sel.pop();
-      } else {
-        GBE_ASSERT(0);
       }
+      else if (workGroupOp >= WORKGROUP_OP_REDUCE_ADD && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX){
+        return emitWGReduce(sel, insn);
+      }
+      else
+        GBE_ASSERT(0);
 
       return true;
     }
diff --git a/backend/src/llvm/llvm_gen_backend.cpp b/backend/src/llvm/llvm_gen_backend.cpp
index 06d6760..951fa0e 100644
--- a/backend/src/llvm/llvm_gen_backend.cpp
+++ b/backend/src/llvm/llvm_gen_backend.cpp
@@ -3968,13 +3968,22 @@  namespace gbe
     }
 
     if (f.gettidMapSLM() < 0 && opcode >= ir::WORKGROUP_OP_REDUCE_ADD && opcode <= ir::WORKGROUP_OP_EXCLUSIVE_MAX) {
-      /* Because we can not know the thread ID and the EUID for every physical
+      /* 1. For thread SLM based communication (default):
+       * Threads will use SLM to write partial results computed individually
+         and then read the whole set. Because the read is done in chunks of 4
+         extra padding is required.
+
+       * 2. For thread message based communication:
+       * Because we can not know the thread ID and the EUID for every physical
          thead which the work items execute on before the run time. We need to
          sync the thread execution order when using work group functions. We
          create the workitems/threadID map table in slm.
+
          When we come to here, the global thread local vars should have all been
          allocated, so it's safe for us to steal a piece of SLM for this usage. */
-      uint32_t mapSize = sizeof(uint16_t) * 64;// at most 64 thread for one subslice.
+
+      // at most 64 thread for one subslice, along with extra padding
+      uint32_t mapSize = sizeof(uint32_t) * (64 + 4);
       f.setUseSLM(true);
       uint32_t oldSlm = f.getSLMSize();
       f.setSLMSize(oldSlm + mapSize);