Updated workgroup reduce optimization using SLM

Submitted by Grigore Lupescu on Feb. 14, 2016, 7:55 a.m.

Details

Message ID 1455436536-7976-1-git-send-email-grigore.lupescu@intel.com
State New
Headers show
Series "Series without cover letter" ( rev: 5 ) in Beignet

Not browsing as part of any series.

Commit Message

Grigore Lupescu Feb. 14, 2016, 7:55 a.m.
Signed-off-by: Grigore Lupescu <grigore.lupescu@intel.com>
---
 backend/src/backend/gen_context.cpp        | 259 ++++++++++-------------------
 backend/src/backend/gen_insn_selection.cpp | 227 ++++++-------------------
 2 files changed, 140 insertions(+), 346 deletions(-)

Patch hide | download patch | download mbox

diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index 193494d..601a4ff 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2848,26 +2848,22 @@  namespace gbe
     } p->pop();
   }
 
-  static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
-      uint32_t simd, uint32_t wg_op, GenEncoder *p) {
-    p->push();
-    p->curr.predicate = GEN_PREDICATE_NONE;
-    p->curr.noMask = 1;
-    p->curr.execWidth = 1;
-
-    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+  static void workgroupOpBetweenThread(GenRegister partialRes,
+                                       GenRegister value,
+                                       uint32_t wg_op,
+                                       GenEncoder *p)
+  {
+    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN
+        || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
       uint32_t cond;
       if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN)
         cond = GEN_CONDITIONAL_LE;
       else
         cond = GEN_CONDITIONAL_GE;
-
-      p->SEL_CMP(cond, msgData, threadData, msgData);
+      p->SEL_CMP(cond, partialRes, partialRes, value);
     } else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
-      p->ADD(msgData, threadData, msgData);
+      p->ADD(partialRes, partialRes, value);
     }
-
-    p->pop();
   }
 
   static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) {
@@ -2947,24 +2943,25 @@  namespace gbe
     else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD){
       tmp.hstride = GEN_HORIZONTAL_STRIDE_1;
       tmp.vstride = GEN_VERTICAL_STRIDE_4;
-      tmp.width = GEN_WIDTH_4;
+      tmp.width = 2;
 
       GBE_ASSERT(tmp.type == theVal.type);
       GenRegister partialSum = tmp;
 
       /* adjust offset, compute add with ADD4/ADD */
       for (uint32_t i = 1; i < simd/4; i++){
-        tmp = tmp.suboffset(tmp, 4);
-        p->push();
-        p->curr.execWidth = GEN_WIDTH_16;
-        p->ADD(partialSum, partialSum, tmp);
-        p->pop();
+        p->push(); {
+          tmp = tmp.suboffset(tmp, 4);
+          p->curr.execWidth = 4;
+          p->ADD(partialSum, partialSum, tmp);
+        } p->pop();
       }
 
       for (uint32_t i = 0; i < 4; i++){
-        partialSum.width = GEN_WIDTH_1;
+        p->push(); {
         p->ADD(threadData, threadData, partialSum);
-        partialSum = GenRegister::suboffset(partialSum, 1);
+        partialSum = partialSum.suboffset(partialSum, 1);
+        } p->pop();
       }
     }
     p->pop();
@@ -2997,171 +2994,93 @@  do { \
   } p->pop(); \
 } while(0)
 
-
-  /* The basic idea is like this:
-     1. All the threads firstly calculate the max/min/add value within their own thread, that is finding
-        the max/min/add value within their 16 work items when SIMD == 16.
-     2. The logical thread ID 0 begins to send the MSG to thread 1, and that message contains the calculated
-        result of the first step. Except the thread 0, all other threads wait on the n0.2 for message forwarding.
-     3. Each thread is waken up because of getting the forwarding message from the thread_id - 1. Then it
-        compares the result in the message and the result within its thread, then forward the correct result to
-        the next thread by sending a message again. If it is the last thread, send it to thread 0.
-     4. Thread 0 finally get the message from the last one and broadcast the final result. */
-  void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) {
+/**
+ * Basic idea:
+ * 1. All the threads firstly calculate the max/min/add value for the
+ * allocated work-items. SIMD16=> 16 work-items allocated for each thread
+ * 2. Each thread will write the computed reduce OP result in SLM memory
+ * based on the threadId
+ * 3. After a memory fence for sync each thread will read the SLM memory based
+ * on a loop using the thread number value (threadN)
+ * 4. At the end each thread has the final value computed individually
+ */
+  void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn){
     const GenRegister dst = ra->genReg(insn.dst(0));
-    const GenRegister tmp = ra->genReg(insn.dst(2));
-    GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag);
-    GenRegister nextThreadID = ra->genReg(insn.src(1));
+    const GenRegister tmp = ra->genReg(insn.dst(1));
+
     const GenRegister theVal = ra->genReg(insn.src(0));
-    GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid));
-    GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn));
-    GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // The data forward.
-    msgData.vstride = GEN_VERTICAL_STRIDE_0;
-    msgData.width = GEN_WIDTH_1;
-    msgData.hstride = GEN_HORIZONTAL_STRIDE_0;
-    GenRegister threadData =
-      GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // Res within thread.
-    threadData.vstride = GEN_VERTICAL_STRIDE_0;
-    threadData.width = GEN_WIDTH_1;
-    threadData.hstride = GEN_HORIZONTAL_STRIDE_0;
+    GenRegister threadData = ra->genReg(insn.src(1));
+    GenRegister partialData = GenRegister::toUniform(threadData, dst.type);
+    GenRegister threadId = ra->genReg(insn.src(2));
+    GenRegister threadNum = ra->genReg(insn.src(3));
+
+    threadId = GenRegister::toUniform(threadId, GEN_TYPE_UD);
+    threadNum = GenRegister::toUniform(threadNum, GEN_TYPE_UD);
+
     uint32_t wg_op = insn.extra.workgroupOp;
     uint32_t simd = p->curr.execWidth;
-    GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW);
-    flag_save.vstride = GEN_VERTICAL_STRIDE_0;
-    flag_save.width = GEN_WIDTH_1;
-    flag_save.hstride = GEN_HORIZONTAL_STRIDE_0;
-    int32_t jip;
-    int32_t oneThreadJip = -1;
-
-    p->push(); { /* First, so something within thread. */
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      /* Do some calculation within each thread. */
-      workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
-    } p->pop();
+    int32_t jip0, jip1;
 
-    if(wg_op == ir::WORKGROUP_OP_REDUCE_ADD){
-        p->push(); {
-          p->MOV(dst, threadData);
-        } p->pop();
-    }
-    else {
-      /* If we are the only one thread, no need to send msg, just broadcast the result.*/
-      p->push(); {
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->curr.noMask = 1;
-        p->curr.execWidth = 1;
-        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-        p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
-
-        /* Broadcast result. */
-        if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
-          p->curr.predicate = GEN_PREDICATE_NORMAL;
-          p->curr.inversePredicate = 1;
-          p->MOV(flag_save, GenRegister::immuw(0x0));
-          p->curr.inversePredicate = 0;
-          p->MOV(flag_save, GenRegister::immuw(0xffff));
-          p->curr.predicate = GEN_PREDICATE_NONE;
-          p->MOV(flagReg, flag_save);
-          p->curr.predicate = GEN_PREDICATE_NORMAL;
-          p->curr.execWidth = simd;
-          p->MOV(dst, threadData);
-        }
+    GenRegister result = GenRegister::offset(dst, 0, 16);
+    result = GenRegister::toUniform(result, dst.type);
 
-        /* Bail out. */
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.inversePredicate = 0;
-        p->curr.execWidth = 1;
-        oneThreadJip = p->n_instruction();
-        p->JMPI(GenRegister::immud(0));
-      } p->pop();
+    /* FIXME need continuous GRF allocation */
+    GenRegister msg = GenRegister::retype(ra->genReg(insn.src(4)), dst.type);
+    GenRegister msgSlmOff = GenRegister::retype(ra->genReg(insn.src(5)), GEN_TYPE_UD);
+    GenRegister msgAddr = GenRegister::retype(GenRegister::offset(msg, 0), GEN_TYPE_UD);
+    GenRegister msgData = GenRegister::retype(GenRegister::offset(msg, 1), dst.type);
 
-      p->push(); {
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->curr.noMask = 1;
-        p->curr.execWidth = 1;
-        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-        p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
+    GenRegister threadN = GenRegister::retype(tmp, GEN_TYPE_UD);
 
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.inversePredicate = 1;
-        p->MOV(flag_save, GenRegister::immuw(0x0));
-        p->curr.inversePredicate = 0;
-        p->MOV(flag_save, GenRegister::immuw(0xffff));
+    /* Do some calculation within each thread. */
+    workgroupOpInThread(msg, theVal, threadData, tmp, simd, wg_op, p);
 
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->MOV(flagReg, flag_save);
-      } p->pop();
+    p->curr.execWidth = 16;
+    p->MOV(threadN, ra->genReg(GenRegister::ud1grf(ir::ocl::threadn)));
 
-      p->push(); {
-        p->curr.noMask = 1;
-        p->curr.execWidth = 1;
+    /* All threads write the partial results to SLM memory */
+    p->curr.execWidth = 8;
+    p->MOV(msgData, threadData);
+    p->MUL(msgAddr, threadId, GenRegister::immd(0x4));
+    p->ADD(msgAddr, msgAddr, msgSlmOff);
+    p->UNTYPED_WRITE(msg, GenRegister::immw(0xFE), 1);
 
-        /* threadid 0, send the msg and wait */
-        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-        p->curr.inversePredicate = 1;
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        jip = p->n_instruction();
-        p->JMPI(GenRegister::immud(0));
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->MOV(msgData, threadData);
-        SEND_RESULT_MSG();
-        p->WAIT(2);
-        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+    /* Wait for write to complete in workgroup */
+    p->FENCE(tmp);
 
-        /* Others wait and send msg, and do something when we get the msg. */
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.inversePredicate = 0;
-        jip = p->n_instruction();
-        p->JMPI(GenRegister::immud(0));
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->WAIT(2);
-        workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
-        SEND_RESULT_MSG();
-        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+    /* Initialiaze partialData register, it will hold the final result */
+    initValue(p, partialData, wg_op);
 
-        /* Restore the flag. */
-        p->curr.predicate = GEN_PREDICATE_NONE;
-        p->MOV(flagReg, flag_save);
-      } p->pop();
+    /* Perform a loop based on thread count */
+    p->push();{
+      jip0 = p->n_instruction();
 
-      /* Broadcast the result. */
-      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN
-          || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
-        p->push(); {
-          p->curr.predicate = GEN_PREDICATE_NORMAL;
-          p->curr.noMask = 1;
-          p->curr.execWidth = 1;
-          p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-          p->curr.inversePredicate = 0;
+      p->curr.execWidth = 8;
+      p->curr.predicate = GEN_PREDICATE_NONE;
 
-          /* Not the first thread, wait for msg first. */
-          jip = p->n_instruction();
-          p->JMPI(GenRegister::immud(0));
-          p->curr.predicate = GEN_PREDICATE_NONE;
-          p->WAIT(2);
-          p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-
-          /* Do something when get the msg. */
-          p->curr.execWidth = simd;
-          p->MOV(dst, msgData);
-          p->curr.execWidth = 8;
-          p->FWD_GATEWAY_MSG(nextThreadID, 2);
-          p->curr.execWidth = 1;
-          p->curr.inversePredicate = 1;
-          p->curr.predicate = GEN_PREDICATE_NORMAL;
-
-          /* The first thread, the last one will notify us. */
-          jip = p->n_instruction();
-          p->JMPI(GenRegister::immud(0));
-          p->curr.predicate = GEN_PREDICATE_NONE;
-          p->WAIT(2);
-          p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-        } p->pop();
-      }
-    }
+      /* TODO can be optimized further with larger SLM reads (4 elem) */
+      p->ADD(threadN, threadN, GenRegister::immd(-1));
+      p->MUL(msgAddr, threadN, GenRegister::immd(0x4));
+      p->ADD(msgAddr, msgAddr, msgSlmOff);
+      p->UNTYPED_READ(msgData, msgAddr, GenRegister::immw(0xFE), 1);
+
+      /* Perform OP after read, partialData will hold the result */
+      workgroupOpBetweenThread(partialData, msgData, wg_op, p);
+
+      /* While threadN is not 0, redo cycle of read SLM / update value */
+      p->curr.noMask = 1;
+      p->curr.flag = 0;
+      p->curr.subFlag = 1;
+      p->CMP(GEN_CONDITIONAL_G, threadN, GenRegister::immd(0x0));
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      jip1 = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+      p->patchJMPI(jip1, jip0 - jip1, 0);
+    } p->pop();
 
-    if (oneThreadJip >=0)
-      p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0);
+    /* Save result to final register location dst */
+    p->curr.execWidth = 16;
+    p->MOV(dst, partialData);
   }
 
   void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) {
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 0b9fe45..0ad3f33 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -685,8 +685,9 @@  namespace gbe
     /*! double division */
     void F64DIV(Reg dst, Reg src0, Reg src1, GenRegister* tmp, int tmpNum);
     /*! Work Group Operations */
-    void WORKGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, GenRegister nextThreadID,
-                     GenRegister threadID, GenRegister threadn, GenRegister tmp);
+    void WORKGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, GenRegister data,
+                      GenRegister threadId, GenRegister threadN,
+                      GenRegister tmp, GenRegister msg, GenRegister slmOff);
     /* common functions for both binary instruction and sel_cmp and compare instruction.
        It will handle the IMM or normal register assignment, and will try to avoid LOADI
        as much as possible. */
@@ -1905,18 +1906,26 @@  namespace gbe
     }
   }
 
-  void Selection::Opaque::WORKGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, GenRegister nextThreadID,
-                      GenRegister threadID, GenRegister threadn, GenRegister tmp) {
-    SelectionInstruction *insn = this->appendInsn(SEL_OP_WORKGROUP_OP, 3, 4);
+  void Selection::Opaque::WORKGROUP_OP(uint32_t wg_op,
+                                       Reg dst,
+                                       GenRegister src,
+                                       GenRegister data,
+                                       GenRegister threadId,
+                                       GenRegister threadN,
+                                       GenRegister tmp,
+                                       GenRegister msg,
+                                       GenRegister slmOff) {
+    SelectionInstruction *insn = this->appendInsn(SEL_OP_WORKGROUP_OP, 2, 6);
     insn->extra.workgroupOp = wg_op;
     insn->dst(0) = dst;
-    insn->dst(1) = nextThreadID;
-    insn->dst(2) = tmp;
+    insn->dst(1) = tmp;
 
     insn->src(0) = src;
-    insn->src(1) = nextThreadID;
-    insn->src(2) = threadID;
-    insn->src(3) = threadn;
+    insn->src(1) = data;
+    insn->src(2) = threadId;
+    insn->src(3) = threadN;
+    insn->src(4) = msg;
+    insn->src(5) = slmOff;
   }
 
   // Boiler plate to initialize the selection library at c++ pre-main
@@ -6119,6 +6128,34 @@  extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
       return tid;
     }
 
+    /* SLM bassed communication between threads, most of the logic bellow */
+    INLINE bool emitWGReduce(Selection::Opaque &sel, const ir::WorkGroupInstruction &insn) const
+    {
+      using namespace ir;
+      const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
+
+      const Type type = insn.getType();
+      GenRegister dst = sel.selReg(insn.getDst(0), type);
+      GenRegister src = sel.selReg(insn.getSrc(2), type);
+      GenRegister threadId = sel.selReg(ocl::threadid, ir::TYPE_U32);
+      GenRegister threadN = sel.selReg(ocl::threadn, ir::TYPE_U32);
+      const uint32_t srcNum = insn.getSrcNum();
+
+      GBE_ASSERT(srcNum == 3);
+      GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
+      GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
+      GenRegister tmp = GenRegister::retype(sel.selReg(sel.reg(FAMILY_DWORD)), type);
+      GenRegister data = sel.selReg(sel.reg(FAMILY_WORD), type);
+      GenRegister msg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+      GenRegister slmOff = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+
+      /* compute individual slice of workitems, (e.g. 0->16 workitems) */
+      sel.MOV(slmOff, GenRegister::immud(insn.getSlmAddr()));
+      sel.WORKGROUP_OP(workGroupOp, dst, src, data, threadId,
+                       threadN, tmp, msg, slmOff);
+      return true;
+    }
+
     INLINE bool emitWGBroadcast(Selection::Opaque &sel, const ir::WorkGroupInstruction &insn) const {
       /*  1. BARRIER    Ensure all the threads have set the correct value for the var which will be broadcasted.
           2. CMP IDs    Compare the local IDs with the specified ones in the function call.
@@ -6190,176 +6227,14 @@  extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
       using namespace ir;
       const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
 
-      if (workGroupOp == WORKGROUP_OP_BROADCAST) {
+      if (workGroupOp == WORKGROUP_OP_BROADCAST){
         return emitWGBroadcast(sel, insn);
-      } else if (workGroupOp >= WORKGROUP_OP_REDUCE_MIN && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
-        const uint32_t slmAddr = insn.getSlmAddr();
-        /* First, we create the TheadID/localID map, in order to get which thread hold the next 16 workitems. */
-
-        if (!sel.storeThreadMap) {
-          this->storeThreadID(sel, slmAddr);
-          sel.storeThreadMap = true;
-        }
-
-        /* Then we insert a barrier to make sure all the var we are interested in
-           have been assigned the final value. */
-        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)), sel.selReg(sel.reg(FAMILY_DWORD)), syncLocalBarrier);
-
-        /* Third, get the next thread ID which we will Forward MSG to. */
-        GenRegister nextThreadID = getNextThreadID(sel, slmAddr);
-        GenRegister threadID = sel.selReg(ocl::threadid, ir::TYPE_U32);
-        GenRegister threadNum = sel.selReg(ocl::threadn, ir::TYPE_U32);
-        GenRegister tmp = GenRegister::retype(sel.selReg(sel.reg(FAMILY_DWORD)), GEN_TYPE_UD);
-
-        const Type type = insn.getType();
-        const GenRegister dst = sel.selReg(insn.getDst(0), type);
-        const uint32_t srcNum = insn.getSrcNum();
-        GBE_ASSERT(srcNum == 3);
-        GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
-        GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
-        GenRegister src = sel.selReg(insn.getSrc(2), type);
-        sel.push(); {
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.WORKGROUP_OP(workGroupOp, dst, src, nextThreadID, threadID, threadNum, tmp);
-        } sel.pop();
       }
-      else if (workGroupOp == WORKGROUP_OP_REDUCE_ADD) {
-        const Type type = insn.getType();
-        GenRegister dst = sel.selReg(insn.getDst(0), type);
-        GenRegister src = sel.selReg(insn.getSrc(2), type);
-        const uint32_t srcNum = insn.getSrcNum();
-        GBE_ASSERT(srcNum == 3);
-        GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
-        GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
-        GenRegister threadID = sel.selReg(ocl::threadid, ir::TYPE_U32);
-        GenRegister threadSEL = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        GenRegister threadNum = sel.selReg(ocl::threadn, ir::TYPE_U32);
-        GenRegister tmp = GenRegister::retype(
-            sel.selReg(sel.reg(FAMILY_DWORD)), type);
-        GenRegister nextThreadID = sel.selReg(sel.reg(FAMILY_WORD), type);
-        GenRegister result = sel.selReg(sel.reg(FAMILY_WORD), type);
-
-        vector<GenRegister> lstPartSum;
-        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), type));
-        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), type));
-        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), type));
-        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), type));
-        vector<GenRegister> fakeTemps;
-        fakeTemps.push_back(sel.selReg(sel.reg(FAMILY_WORD), type));
-        fakeTemps.push_back(sel.selReg(sel.reg(FAMILY_WORD), type));
-        sel.MOV(lstPartSum[0], GenRegister::immud(0));
-        sel.MOV(lstPartSum[1], GenRegister::immud(0));
-        sel.MOV(lstPartSum[2], GenRegister::immud(0));
-        sel.MOV(lstPartSum[3], GenRegister::immud(0));
-
-        /* precompute SLM address offsets */
-        GenRegister slm1Reg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        GenRegister slm2Reg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        GenRegister slm2RegOff1 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        GenRegister slm2RegOff4 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        GenRegister slm1RegOff4 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        GenRegister slm1RegOff8 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
-        sel.MOV(slm1Reg, GenRegister::immud(insn.getSlmAddr()));
-        sel.MUL(slm2Reg, threadNum, GenRegister::immud(0x10));
-        sel.ADD(slm2Reg, slm2Reg, slm1Reg);
-        sel.MUL(slm1RegOff4, threadID, GenRegister::immud(0x4));
-        sel.MUL(slm1RegOff8, threadID, GenRegister::immud(0x8));
-        sel.ADD(slm2RegOff1, threadID, slm2Reg);
-        sel.ADD(slm2RegOff4, slm1RegOff4, slm2Reg);
-        sel.ADD(slm1RegOff4, slm1RegOff4, slm1Reg);
-        sel.ADD(slm1RegOff8, slm1RegOff8, slm1Reg);
-
-        /* write the SLM with 0s on both the SLM region1 and region2 */
-        sel.UNTYPED_WRITE(slm1RegOff8,
-                          lstPartSum.data(), 2, GenRegister::immw(0xFE), fakeTemps);
-        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
-                sel.selReg(sel.reg(FAMILY_DWORD)),
-                SYNC_LOCAL_WRITE_FENCE);
-
-        /* compute individual slice of workitems, (e.g. 0->16 workitems) */
-        sel.WORKGROUP_OP(workGroupOp, result, src,
-                         nextThreadID, threadID, threadNum, tmp);
-
-        /* write result data to SLM with offset using threadID*/
-        sel.UNTYPED_WRITE(slm1RegOff4,
-                          &result, 1, GenRegister::immw(0xFE), fakeTemps);
-        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
-                    sel.selReg(sel.reg(FAMILY_DWORD)),
-                    SYNC_LOCAL_WRITE_FENCE);
-
-        /* select threads, compute/write to SLM region2 */
-        sel.push(); {
-          sel.curr.predicate = GEN_PREDICATE_NONE;
-          sel.curr.noMask = 1;
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-
-          /* select threads based on threadID%4==0 */
-          sel.MOV(threadSEL, threadID);
-          sel.SHR(threadSEL, threadSEL, GenRegister::immud(2));
-          sel.SHL(threadSEL, threadSEL, GenRegister::immud(2));
-          sel.CMP(GEN_CONDITIONAL_EQ, threadID, threadSEL, GenRegister::null());
-          sel.curr.predicate = GEN_PREDICATE_NORMAL;
-
-          /* compute sums and write to SLM region2 */
-          sel.MOV(dst, GenRegister::immud(0));
-          sel.UNTYPED_READ(slm1RegOff4, lstPartSum.data(), 4,
-                             GenRegister::immw(0xFE), fakeTemps);
-          sel.ADD(dst, dst, lstPartSum[0]);
-          sel.ADD(dst, dst, lstPartSum[1]);
-          sel.ADD(dst, dst, lstPartSum[2]);
-          sel.ADD(dst, dst, lstPartSum[3]);
-          sel.UNTYPED_WRITE(slm2RegOff1,
-                            &dst, 1, GenRegister::immw(0xFE), fakeTemps);
-        } sel.pop();
-
-        /* wait for all writes to finish */
-        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
-                    sel.selReg(sel.reg(FAMILY_DWORD)),
-                   SYNC_LOCAL_WRITE_FENCE);
-
-        /* special case if threadnum<=4, skip final sum, just read result SLM */
-        sel.push(); {
-          sel.curr.predicate = GEN_PREDICATE_NONE;
-          sel.curr.noMask = 1;
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.CMP(GEN_CONDITIONAL_LE, threadNum, GenRegister::immud(4), GenRegister::null());
-          sel.curr.predicate = GEN_PREDICATE_NORMAL;
-          sel.UNTYPED_READ(slm2Reg, &dst, 1,
-                           GenRegister::immw(0xFE), fakeTemps);
-        } sel.pop();
-
-        /* special case if threadnum>4, do final compute */
-        sel.push(); {
-          sel.curr.predicate = GEN_PREDICATE_NONE;
-          sel.curr.noMask = 1;
-          sel.curr.flag = 0;
-          sel.curr.subFlag = 1;
-          sel.CMP(GEN_CONDITIONAL_G, threadNum, GenRegister::immud(4), GenRegister::null());
-          sel.curr.predicate = GEN_PREDICATE_NORMAL;
-
-          /* each thread collects the partial sums and computes the final sum */
-          sel.MOV(dst, GenRegister::immud(0));
-          for(int i=0; i<2; i++){
-            sel.MOV(lstPartSum[0], GenRegister::immud(0));
-            sel.MOV(lstPartSum[1], GenRegister::immud(0));
-            sel.MOV(lstPartSum[2], GenRegister::immud(0));
-            sel.MOV(lstPartSum[3], GenRegister::immud(0));
-            sel.UNTYPED_READ(slm2Reg, lstPartSum.data(), 4,
-                             GenRegister::immw(0xFE), fakeTemps);
-            sel.ADD(dst, dst, lstPartSum[0]);
-            sel.ADD(dst, dst, lstPartSum[1]);
-            sel.ADD(dst, dst, lstPartSum[2]);
-            sel.ADD(dst, dst, lstPartSum[3]);
-            sel.ADD(slm2Reg, slm2Reg, GenRegister::immud(0x10));
-          }
-        } sel.pop();
+      else if (workGroupOp >= WORKGROUP_OP_REDUCE_ADD && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX){
+        return emitWGReduce(sel, insn);
       }
-      else {
+      else
         GBE_ASSERT(0);
-      }
 
       return true;
     }