Преглед изворни кода

shaderpipeline: add support for specialization constants (#1078)

Brian Lach пре 5 година
родитељ
комит
be63f71e80

+ 5 - 1
panda/src/glstuff/glShaderContext_src.cxx

@@ -3425,7 +3425,11 @@ attach_shader(const ShaderModule *module) {
                                 (const char *)spv->get_data(),
                                 spv->get_data_size() * sizeof(uint32_t));
       }
-      _glgsg->_glSpecializeShader(handle, "main", 0, nullptr, nullptr);
+
+      const Shader::ModuleSpecConstants &consts = _shader->_module_spec_consts[module];
+      _glgsg->_glSpecializeShader(handle, "main", consts._indices.size(),
+                                  (GLuint *)consts._indices.data(),
+                                  (GLuint *)consts._values.data());
     }
     else
 #endif  // !OPENGLES

+ 67 - 0
panda/src/gobj/shader.I

@@ -221,6 +221,49 @@ set_cache_compiled_shader(bool flag) {
   _cache_compiled_shader = flag;
 }
 
+/**
+ * Sets a boolean value for the specialization constant with the indicated
+ * name.  All modules containing a specialization constant with this name
+ * will be given this value.
+ *
+ * Returns true if there was a specialization constant with this name on any of
+ * the modules, false otherwise.
+ */
+INLINE bool Shader::
+set_constant(CPT_InternalName name, bool value) {
+  return set_constant(name, (uint32_t)value);
+}
+
+/**
+ * Sets an integer value for the specialization constant with the indicated
+ * name.  All modules containing a specialization constant with this name
+ * will be given this value.
+ *
+ * Returns true if there was a specialization constant with this name on any of
+ * the modules, false otherwise.
+ */
+INLINE bool Shader::
+set_constant(CPT_InternalName name, int value) {
+  uint32_t val;
+  *((int *)&val) = value;
+  return set_constant(name, val);
+}
+
+/**
+ * Sets a float value for the specialization constant with the indicated
+ * name.  All modules containing a specialization constant with this name
+ * will be given this value.
+ *
+ * Returns true if there was a specialization constant with this name on any of
+ * the modules, false otherwise.
+ */
+INLINE bool Shader::
+set_constant(CPT_InternalName name, float value) {
+  uint32_t val;
+  *((float *)&val) = value;
+  return set_constant(name, val);
+}
+
 /**
  *
  */
@@ -812,6 +855,30 @@ operator < (const Shader::ShaderFile &other) const {
   return false;
 }
 
+/**
+ * Sets an external value for the specialization constant with the given ID.
+ *
+ * Returns true if the value is different from what was already in there, false
+ * otherwise.
+ */
+INLINE bool Shader::ModuleSpecConstants::
+set_constant(uint32_t id, uint32_t value) {
+  auto it = std::find(_indices.begin(), _indices.end(), id);
+  if (it == _indices.end()) {
+    _indices.push_back(id);
+    _values.push_back(value);
+    return true;
+
+  } else {
+    size_t loc = it - _indices.begin();
+    if (_values[loc] != value) {
+      _values[loc] = value;
+      return true;
+    }
+    return false;
+  }
+}
+
 /**
  * Returns a PStatCollector for timing the preparation of just this shader.
  */

+ 63 - 0
panda/src/gobj/shader.cxx

@@ -890,6 +890,7 @@ do_read_source(ShaderModule::Stage stage, std::istream &in,
   }
 
   int used_caps = module->get_used_capabilities();
+  _module_spec_consts.insert({module, ModuleSpecConstants()});
   _modules.push_back(std::move(module));
   _module_mask |= (1u << (uint32_t)stage);
   _used_caps |= used_caps;
@@ -920,6 +921,8 @@ link() {
   pvector<Parameter *> parameters;
   BitArray used_locations;
 
+  pmap<CPT_InternalName, const ::ShaderType *> spec_const_types;
+
   for (COWPT(ShaderModule) &cow_module : _modules) {
     const ShaderModule *module = cow_module.get_read_pointer();
     pmap<int, int> remap;
@@ -991,6 +994,23 @@ link() {
       }
       module->remap_parameter_locations(remap);
     }
+
+    for (const ShaderModule::SpecializationConstant &spec_const : module->_spec_constants) {
+      auto result = spec_const_types.insert({spec_const.name, spec_const.type});
+      auto &it = result.first;
+
+      if (!result.second) {
+        // Another module has already defined a spec constant with this name.
+        // Make sure they have the same type.
+        const ::ShaderType *other_type = it->second;
+        if (spec_const.type != other_type) {
+          shader_cat.error()
+            << "Specialization constant " << *spec_const.name << " in module "
+            << *module << " is declared in another stage with a mismatching type!\n";
+          return false;
+        }
+      }
+    }
   }
 
   // Now bind all of the parameters.
@@ -2973,6 +2993,49 @@ make_compute(ShaderLanguage lang, string body) {
   return shader;
 }
 
+/**
+ * Sets an unsigned integer value for the specialization constant with the
+ * indicated name.  All modules containing a specialization constant with
+ * this name will be given this value.
+ *
+ * Returns true if there was a specialization constant with this name on any of
+ * the modules, false otherwise.
+ */
+bool Shader::
+set_constant(CPT_InternalName name, unsigned int value) {
+  bool any_changed = false;
+  bool any_found = false;
+
+  // Set the value on all modules containing a spec constant with this name.
+  for (COWPT(ShaderModule) &cow_module : _modules) {
+    const ShaderModule *module = cow_module.get_read_pointer();
+
+    for (const ShaderModule::SpecializationConstant &spec_const : module->_spec_constants) {
+      if (spec_const.name == name) {
+        // Found one.
+        ModuleSpecConstants &constants = _module_spec_consts[module];
+        if (constants.set_constant(spec_const.id, value)) {
+          any_changed = true;
+        }
+        any_found = true;
+        break;
+      }
+    }
+  }
+
+  if (any_changed) {
+    if (shader_cat.is_debug()) {
+      shader_cat.debug()
+        << "Specialization constant value changed, forcing shader to "
+        << "re-prepare.\n";
+    }
+    // Force the shader to be re-prepared so the value change is picked up.
+    release_all();
+  }
+
+  return any_found;
+}
+
 /**
  * Indicates that the shader should be enqueued to be prepared in the
  * indicated prepared_objects at the beginning of the next frame.  This will

+ 21 - 0
panda/src/gobj/shader.h

@@ -121,6 +121,11 @@ PUBLISHED:
   INLINE bool get_cache_compiled_shader() const;
   INLINE void set_cache_compiled_shader(bool flag);
 
+  INLINE bool set_constant(CPT_InternalName name, bool value);
+  INLINE bool set_constant(CPT_InternalName name, int value);
+  INLINE bool set_constant(CPT_InternalName name, float value);
+  bool set_constant(CPT_InternalName name, unsigned int value);
+
   PT(AsyncFuture) prepare(PreparedGraphicsObjects *prepared_objects);
   bool is_prepared(PreparedGraphicsObjects *prepared_objects) const;
   bool release(PreparedGraphicsObjects *prepared_objects);
@@ -439,6 +444,20 @@ public:
     std::string _compute;
   };
 
+  /**
+   * Contains external values given to the specialization constants of a single
+   * ShaderModule.
+   */
+  class ModuleSpecConstants {
+  public:
+    INLINE ModuleSpecConstants() {};
+
+    INLINE bool set_constant(uint32_t id, uint32_t value);
+  public:
+    pvector<uint32_t> _values;
+    pvector<uint32_t> _indices;
+  };
+
 protected:
   bool report_parameter_error(const InternalName *name, const ::ShaderType *type, const char *msg);
   bool expect_num_words(const InternalName *name, const ::ShaderType *type, size_t len);
@@ -486,6 +505,8 @@ public:
 
   typedef pvector<COWPT(ShaderModule)> Modules;
   Modules _modules;
+  typedef pmap<const ShaderModule *, ModuleSpecConstants> ModuleSpecConsts;
+  ModuleSpecConsts _module_spec_consts;
   uint32_t _module_mask = 0;
   int _used_caps = 0;
 

+ 16 - 0
panda/src/gobj/shaderModule.I

@@ -45,6 +45,22 @@ set_source_filename(const Filename &filename) {
   _source_filename = filename;
 }
 
+/**
+ * Returns the SpecializationConstant at the indicated index.
+ */
+INLINE const ShaderModule::SpecializationConstant &ShaderModule::
+get_spec_constant(size_t i) const {
+  return _spec_constants[i];
+}
+
+/**
+ * Returns the number of SpecializationConstants in the module.
+ */
+INLINE size_t ShaderModule::
+get_num_spec_constants() const {
+  return _spec_constants.size();
+}
+
 /**
  * Returns the number of input variables in this shader stage.
  */

+ 17 - 0
panda/src/gobj/shaderModule.h

@@ -57,6 +57,16 @@ PUBLISHED:
     int _location;
   };
 
+  /**
+   * Defines a specialization constant.
+   */
+  struct SpecializationConstant {
+  PUBLISHED:
+    const ShaderType *type;
+    CPT(InternalName) name;
+    uint32_t id;
+  };
+
 public:
   ShaderModule(Stage stage);
   virtual ~ShaderModule();
@@ -67,6 +77,9 @@ public:
   INLINE const Filename &get_source_filename() const;
   INLINE void set_source_filename(const Filename &);
 
+  INLINE const SpecializationConstant &get_spec_constant(size_t i) const;
+  INLINE size_t get_num_spec_constants() const;
+
   size_t get_num_inputs() const;
   const Variable &get_input(size_t i) const;
   int find_input(CPT_InternalName name) const;
@@ -88,6 +101,7 @@ PUBLISHED:
   MAKE_PROPERTY(stage, get_stage);
   MAKE_SEQ_PROPERTY(inputs, get_num_inputs, get_input);
   MAKE_SEQ_PROPERTY(outputs, get_num_outputs, get_output);
+  MAKE_SEQ_PROPERTY(spec_constants, get_num_spec_constants, get_spec_constant);
 
   virtual std::string get_ir() const=0;
 
@@ -172,6 +186,9 @@ protected:
   Variables _outputs;
   Variables _parameters;
 
+  typedef pvector<SpecializationConstant> SpecializationConstants;
+  SpecializationConstants _spec_constants;
+
   friend class Shader;
 
 public:

+ 37 - 0
panda/src/shaderpipeline/shaderModuleSpirV.cxx

@@ -193,6 +193,18 @@ ShaderModuleSpirV(Stage stage, std::vector<uint32_t> words) :
         }
       }
     }
+    else if (def._dtype == DT_spec_constant && def._type != nullptr) {
+      SpecializationConstant spec_constant;
+      spec_constant.id = def._spec_id;
+      spec_constant.name = InternalName::make(def._name);
+      spec_constant.type = def._type;
+      if (shader_cat.is_debug()) {
+        shader_cat.debug()
+          << "Found specialization constant " << def._name << " with type "
+          << *def._type << " and ID " << def._spec_id << "\n";
+      }
+      _spec_constants.push_back(spec_constant);
+    }
   }
 
 #ifndef NDEBUG
@@ -2154,6 +2166,13 @@ parse_instruction(const Instruction &op, uint32_t &current_function_id) {
     record_constant(op.args[1], op.args[0], op.args + 2, op.nargs - 2);
     break;
 
+  case spv::OpSpecConstantTrue:
+  case spv::OpSpecConstantFalse:
+  case spv::OpSpecConstant:
+    // A specialization constant.
+    record_spec_constant(op.args[1], op.args[0]);
+    break;
+
   case spv::OpFunction:
     if (current_function_id != 0) {
       shader_cat.error()
@@ -2302,6 +2321,10 @@ parse_instruction(const Instruction &op, uint32_t &current_function_id) {
       _defs[op.args[0]]._array_stride = op.args[2];
       break;
 
+    case spv::DecorationSpecId:
+      _defs[op.args[0]]._spec_id = op.args[2];
+      break;
+
     default:
       break;
     }
@@ -2564,6 +2587,20 @@ record_local(uint32_t id, uint32_t type_id, uint32_t from_id, uint32_t function_
   nassertv(function_id != 0);
 }
 
+/**
+ * Records that the given specialization constant has been defined.
+ */
+void ShaderModuleSpirV::InstructionWriter::
+record_spec_constant(uint32_t id, uint32_t type_id) {
+  const Definition &type_def = get_definition(type_id);
+  nassertv(type_def._dtype == DT_type);
+
+  Definition &def = modify_definition(id);
+  def._dtype = DT_spec_constant;
+  def._type_id = type_id;
+  def._type = type_def._type;
+}
+
 /**
  * Called for a variable, or any id whose value (indirectly) originates from a
  * variable, to mark the variable and any types used thereby as "used".

+ 3 - 0
panda/src/shaderpipeline/shaderModuleSpirV.h

@@ -124,6 +124,7 @@ public:
     DT_function_parameter,
     DT_function,
     DT_local,
+    DT_spec_constant,
   };
 
   enum DefinitionFlags {
@@ -165,6 +166,7 @@ public:
     uint32_t _array_stride = 0;
     uint32_t _origin_id = 0; // set for loads, tracks original variable ID
     uint32_t _function_id = 0;
+    uint32_t _spec_id = 0;
     MemberDefinitions _members;
     int _flags = 0;
 
@@ -228,6 +230,7 @@ public:
     void record_ext_inst_import(uint32_t id, const char *import);
     void record_function(uint32_t id, uint32_t type_id);
     void record_local(uint32_t id, uint32_t type_id, uint32_t from_id, uint32_t function_id);
+    void record_spec_constant(uint32_t id, uint32_t type_id);
 
     void mark_used(uint32_t id);