[Mesa-dev,1/3] nir/search: add support for opcode variables

Submitted by Connor Abbott on Feb. 8, 2015, 1:05 a.m.

Details

Message ID 1423357546-30702-2-git-send-email-cwabbott0@gmail.com
State New
Headers show

Not browsing as part of any series.

Commit Message

Connor Abbott Feb. 8, 2015, 1:05 a.m.
These are similar to normal variables. They let us match trees with
arbitrary opcodes, as long as the opcodes match and possibly only if
they have certain algebraic properties, like being associative or
commutative. The syntax is '!name' (so that it uses a different
character from # and @ which we already use, suggestions welcome...) or
'!name@property1,property2'.

Signed-off-by: Connor Abbott <cwabbott0@gmail.com>
---
 src/glsl/nir/nir_algebraic.py | 68 +++++++++++++++++++++++++++++++++++++------
 src/glsl/nir/nir_search.c     | 40 +++++++++++++++++++++----
 src/glsl/nir/nir_search.h     | 21 ++++++++++++-
 3 files changed, 113 insertions(+), 16 deletions(-)

Patch hide | download patch | download mbox

diff --git a/src/glsl/nir/nir_algebraic.py b/src/glsl/nir/nir_algebraic.py
index 4929745..6252969 100644
--- a/src/glsl/nir/nir_algebraic.py
+++ b/src/glsl/nir/nir_algebraic.py
@@ -49,9 +49,9 @@  class VarSet(object):
 
 class Value(object):
    @staticmethod
-   def create(val, name_base, varset):
+   def create(val, name_base, varset, op_varset):
       if isinstance(val, tuple):
-         return Expression(val, name_base, varset)
+         return Expression(val, name_base, varset, op_varset)
       elif isinstance(val, Expression):
          return val
       elif isinstance(val, (str, unicode)):
@@ -69,7 +69,7 @@  static const ${val.c_type} ${val.name} = {
    ${'true' if val.is_constant else 'false'},
    nir_type_${ val.required_type or 'invalid' },
 % elif isinstance(val, Expression):
-   nir_op_${val.opcode},
+   ${val.opcode_string},
    { ${', '.join(src.c_ptr for src in val.sources)} },
 % endif
 };""")
@@ -133,12 +133,45 @@  class Variable(Value):
       self.index = varset[self.var_name]
 
 class Expression(Value):
-   def __init__(self, expr, name_base, varset):
+   def __init__(self, expr, name_base, varset, op_varset):
       Value.__init__(self, name_base, "expression")
       assert isinstance(expr, tuple)
 
-      self.opcode = expr[0]
-      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
+      self.is_variable = expr[0].startswith("!")
+      if self.is_variable:
+         at_index = expr[0].find("@")
+         if at_index == -1:
+            name = expr[0][1:]
+            props_string = "0"
+         else:
+            name = expr[0][1:at_index]
+            props = expr[0][at_index+1:].split(",")
+            props_string = " | ".join("NIR_OP_IS_" + prop.upper()
+                                         for prop in props)
+
+         index = op_varset[name]
+         self.opcode_string = """{{
+            true,
+            {{
+               .var = {{
+                  {index},
+                  {props}
+               }}
+            }}
+         }}
+         """.format(index=index, props=props_string)
+         self.opcode = "variable"
+      else:
+         self.opcode = expr[0];
+         self.opcode_string = """{{
+            false,
+            {{
+               .op = nir_op_{opcode}
+            }}
+         }}""".format(opcode=self.opcode)
+
+      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i),
+                                    varset, op_varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
    def render(self):
@@ -152,17 +185,21 @@  class SearchAndReplace(object):
       self.id = _optimization_ids.next()
 
       varset = VarSet()
+      op_varset = VarSet()
       if isinstance(search, Expression):
          self.search = search
       else:
-         self.search = Expression(search, "search{0}".format(self.id), varset)
+         self.search = Expression(search, "search{0}".format(self.id), varset,
+                                  op_varset)
 
       varset.lock()
+      op_varset.lock()
 
       if isinstance(replace, Value):
          self.replace = replace
       else:
-         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
+         self.replace = Value.create(replace, "replace{0}".format(self.id),
+                                     varset, op_varset)
 
 _algebraic_pass_template = mako.template.Template("""
 #include "nir.h"
@@ -206,6 +243,7 @@  ${pass_name}_block(nir_block *block, void *void_state)
 
       switch (alu->op) {
       % for opcode in xform_dict.keys():
+         % if opcode != "variable":
       case nir_op_${opcode}:
          for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
             const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
@@ -215,9 +253,21 @@  ${pass_name}_block(nir_block *block, void *void_state)
                break;
             }
          }
-         break;
+         goto variable_opcode_match;
+         % endif
       % endfor
       default:
+variable_opcode_match:
+      % if "variable" in xform_dict:
+         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_variable_xforms); i++) {
+            if (nir_replace_instr(alu, ${pass_name}_variable_xforms[i].search,
+                                  ${pass_name}_variable_xforms[i].replace,
+                                  state->mem_ctx)) {
+               state->progress = true;
+               break;
+            }
+         }
+      % endif
          break;
       }
    }
diff --git a/src/glsl/nir/nir_search.c b/src/glsl/nir/nir_search.c
index 4671931..79cfdba 100644
--- a/src/glsl/nir/nir_search.c
+++ b/src/glsl/nir/nir_search.c
@@ -30,6 +30,9 @@ 
 struct match_state {
    unsigned variables_seen;
    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
+
+   unsigned op_variables_seen;
+   nir_op op_variables[NIR_SEARCH_MAX_OPCODE_VARIABLES];
 };
 
 static bool
@@ -151,8 +154,25 @@  match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state)
 {
-   if (instr->op != expr->opcode)
-      return false;
+   if (expr->opcode.is_variable) {
+      nir_op_algebraic_property instr_props =
+         nir_op_infos[instr->op].algebraic_properties;
+      nir_op_algebraic_property expr_props =
+         expr->opcode.op.var.props;
+      if ((instr_props & expr_props) != expr_props)
+         return false;
+
+      if (state->op_variables_seen & (1 << expr->opcode.op.var.index)) {
+         if (instr->op != state->op_variables[expr->opcode.op.var.index])
+            return false;
+      } else {
+         state->op_variables[expr->opcode.op.var.index] = instr->op;
+         state->op_variables_seen |= 1 << expr->opcode.op.var.index;
+      }
+   } else {
+      if (instr->op != expr->opcode.op.op)
+         return false;
+   }
 
    assert(!instr->dest.saturate);
    assert(nir_op_infos[instr->op].num_inputs > 0);
@@ -212,15 +232,23 @@  construct_value(const nir_search_value *value, nir_alu_type type,
    case nir_search_value_expression: {
       const nir_search_expression *expr = nir_search_value_as_expression(value);
 
-      if (nir_op_infos[expr->opcode].output_size != 0)
-         num_components = nir_op_infos[expr->opcode].output_size;
+      nir_op op;
+      if (expr->opcode.is_variable) {
+         assert(state->op_variables_seen & (1 << expr->opcode.op.var.index));
+         op = state->op_variables[expr->opcode.op.var.index];
+      } else {
+         op = expr->opcode.op.op;
+      }
+
+      if (nir_op_infos[op].output_size != 0)
+         num_components = nir_op_infos[op].output_size;
 
-      nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
+      nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, op);
       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, NULL);
       alu->dest.write_mask = (1 << num_components) - 1;
       alu->dest.saturate = false;
 
-      for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
          /* If the source is an explicitly sized source, then we need to reset
           * the number of components to match.
           */
diff --git a/src/glsl/nir/nir_search.h b/src/glsl/nir/nir_search.h
index 7d47792..2cc9ccd 100644
--- a/src/glsl/nir/nir_search.h
+++ b/src/glsl/nir/nir_search.h
@@ -31,6 +31,7 @@ 
 #include "nir.h"
 
 #define NIR_SEARCH_MAX_VARIABLES 16
+#define NIR_SEARCH_MAX_OPCODE_VARIABLES 4
 
 typedef enum {
    nir_search_value_expression,
@@ -78,10 +79,28 @@  typedef struct {
    } data;
 } nir_search_constant;
 
+/** An opcode to match against.
+ *
+ * This can either be a regular opcode, or an "opcode variable" which matches
+ * any opcode that has certain algebraic properties.
+ */
+
+typedef struct {
+   bool is_variable;
+
+   union {
+      nir_op op;
+      struct {
+         unsigned index;
+         nir_op_algebraic_property props;
+      } var;
+   } op;
+} nir_search_op;
+
 typedef struct {
    nir_search_value value;
 
-   nir_op opcode;
+   nir_search_op opcode;
    const nir_search_value *srcs[4];
 } nir_search_expression;