فهرست منبع

shaderpipeline: Assorted fixes and improvements

rdb 1 سال پیش
والد
کامیت
ecd76c9800
25فایلهای تغییر یافته به همراه714 افزوده شده و 343 حذف شده
  1. 7 0
      dtool/src/parser-inc/spirv-tools/libspirv.h
  2. 0 1
      panda/src/glstuff/glShaderContext_src.h
  3. 9 9
      panda/src/gobj/shaderModule.h
  4. 84 0
      panda/src/gobj/shaderType.cxx
  5. 7 0
      panda/src/gobj/shaderType.h
  6. 1 0
      panda/src/shaderpipeline/p3shaderpipeline_composite2.cxx
  7. 10 7
      panda/src/shaderpipeline/shaderModuleSpirV.cxx
  8. 3 1
      panda/src/shaderpipeline/shaderModuleSpirV.h
  9. 3 3
      panda/src/shaderpipeline/spirVFlattenStructPass.cxx
  10. 1 1
      panda/src/shaderpipeline/spirVFlattenStructPass.h
  11. 9 6
      panda/src/shaderpipeline/spirVHoistStructResourcesPass.cxx
  12. 1 1
      panda/src/shaderpipeline/spirVHoistStructResourcesPass.h
  13. 162 0
      panda/src/shaderpipeline/spirVMakeBlockPass.cxx
  14. 46 0
      panda/src/shaderpipeline/spirVMakeBlockPass.h
  15. 2 2
      panda/src/shaderpipeline/spirVReplaceVariableTypePass.cxx
  16. 1 1
      panda/src/shaderpipeline/spirVReplaceVariableTypePass.h
  17. 8 0
      panda/src/shaderpipeline/spirVResultDatabase.I
  18. 33 2
      panda/src/shaderpipeline/spirVResultDatabase.cxx
  19. 8 2
      panda/src/shaderpipeline/spirVResultDatabase.h
  20. 23 0
      panda/src/shaderpipeline/spirVTransformPass.I
  21. 257 42
      panda/src/shaderpipeline/spirVTransformPass.cxx
  22. 14 3
      panda/src/shaderpipeline/spirVTransformPass.h
  23. 19 254
      panda/src/shaderpipeline/spirVTransformer.cxx
  24. 2 5
      panda/src/shaderpipeline/spirVTransformer.h
  25. 4 3
      tests/display/test_glsl_shader.py

+ 7 - 0
dtool/src/parser-inc/spirv-tools/libspirv.h

@@ -0,0 +1,7 @@
+#pragma once
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+typedef enum {} spv_target_env;

+ 0 - 1
panda/src/glstuff/glShaderContext_src.h

@@ -130,7 +130,6 @@ private:
     PT(ShaderInputBinding) _binding;
     ShaderInputBinding::ResourceId _resource_id;
     GLenum _target;
-    int _index;
   };
   typedef pvector<TextureUnit> TextureUnits;
   TextureUnits _texture_units;

+ 9 - 9
panda/src/gobj/shaderModule.h

@@ -73,17 +73,17 @@ public:
   INLINE const SpecializationConstant &get_spec_constant(size_t i) const;
   INLINE size_t get_num_spec_constants() const;
 
-  size_t get_num_inputs() const;
-  const Variable &get_input(size_t i) const;
-  int find_input(CPT_InternalName name) const;
+  INLINE size_t get_num_inputs() const;
+  INLINE const Variable &get_input(size_t i) const;
+  INLINE int find_input(CPT_InternalName name) const;
 
-  size_t get_num_outputs() const;
-  const Variable &get_output(size_t i) const;
-  int find_output(CPT_InternalName name) const;
+  INLINE size_t get_num_outputs() const;
+  INLINE const Variable &get_output(size_t i) const;
+  INLINE int find_output(CPT_InternalName name) const;
 
-  size_t get_num_parameters() const;
-  const Variable &get_parameter(size_t i) const;
-  int find_parameter(CPT_InternalName name) const;
+  INLINE size_t get_num_parameters() const;
+  INLINE const Variable &get_parameter(size_t i) const;
+  INLINE int find_parameter(CPT_InternalName name) const;
 
   typedef pmap<CPT_InternalName, Variable *> VariablesByName;
 

+ 84 - 0
panda/src/gobj/shaderType.cxx

@@ -223,6 +223,18 @@ as_scalar_type(ScalarType &type, uint32_t &num_elements,
   return true;
 }
 
+/**
+ * Replaces any occurrence of the given scalar type with the given other one.
+ */
+const ShaderType *ShaderType::Scalar::
+replace_scalar_type(ScalarType a, ScalarType b) const {
+  if (_scalar_type == a) {
+    return ShaderType::register_type(ShaderType::Scalar(b));
+  } else {
+    return this;
+  }
+}
+
 /**
  *
  */
@@ -297,6 +309,18 @@ as_scalar_type(ScalarType &type, uint32_t &num_elements,
   return true;
 }
 
+/**
+ * Replaces any occurrence of the given scalar type with the given other one.
+ */
+const ShaderType *ShaderType::Vector::
+replace_scalar_type(ScalarType a, ScalarType b) const {
+  if (_scalar_type == a) {
+    return ShaderType::register_type(ShaderType::Vector(b, _num_components));
+  } else {
+    return this;
+  }
+}
+
 /**
  * Returns the number of in/out locations taken up by in/out variables having
  * this type.
@@ -390,6 +414,18 @@ as_scalar_type(ScalarType &type, uint32_t &num_elements,
   return true;
 }
 
+/**
+ * Replaces any occurrence of the given scalar type with the given other one.
+ */
+const ShaderType *ShaderType::Matrix::
+replace_scalar_type(ScalarType a, ScalarType b) const {
+  if (_scalar_type == a) {
+    return ShaderType::register_type(ShaderType::Matrix(b, _num_rows, _num_columns));
+  } else {
+    return this;
+  }
+}
+
 /**
  *
  */
@@ -521,6 +557,28 @@ contains_scalar_type(ScalarType type) const {
   return false;
 }
 
+/**
+ * Replaces any occurrence of the given scalar type with the given other one.
+ */
+const ShaderType *ShaderType::Struct::
+replace_scalar_type(ScalarType a, ScalarType b) const {
+  if (contains_scalar_type(a)) {
+    ShaderType::Struct copy;
+    for (const Member &member : _members) {
+      const ShaderType *type = member.type->replace_scalar_type(a, b);
+      if ((a == ST_double) != (b == ST_double)) {
+        // Recompute offsets.
+        copy.add_member(type, member.name);
+      } else {
+        copy.add_member(type, member.name, member.offset);
+      }
+    }
+    return ShaderType::register_type(std::move(copy));
+  } else {
+    return this;
+  }
+}
+
 /**
  *
  */
@@ -730,6 +788,19 @@ as_scalar_type(ScalarType &type, uint32_t &num_elements,
   return false;
 }
 
+/**
+ * Replaces any occurrence of the given scalar type with the given other one.
+ */
+const ShaderType *ShaderType::Array::
+replace_scalar_type(ScalarType a, ScalarType b) const {
+  const ShaderType *element_type = _element_type->replace_scalar_type(a, b);
+  if (_element_type != element_type) {
+    return ShaderType::register_type(ShaderType::Array(element_type, _num_elements));
+  } else {
+    return this;
+  }
+}
+
 /**
  *
  */
@@ -1077,6 +1148,19 @@ contains_scalar_type(ScalarType type) const {
   return _contained_type != nullptr && _contained_type->contains_scalar_type(type);
 }
 
+/**
+ * Replaces any occurrence of the given scalar type with the given other one.
+ */
+const ShaderType *ShaderType::StorageBuffer::
+replace_scalar_type(ScalarType a, ScalarType b) const {
+  const ShaderType *contained_type = _contained_type->replace_scalar_type(a, b);
+  if (_contained_type != contained_type) {
+    return ShaderType::register_type(ShaderType::StorageBuffer(contained_type, _access));
+  } else {
+    return this;
+  }
+}
+
 /**
  * Writes the contents of this object to the datagram for shipping out to a
  * Bam file.

+ 7 - 0
panda/src/gobj/shaderType.h

@@ -98,6 +98,7 @@ public:
                               uint32_t &num_elements,
                               uint32_t &num_rows,
                               uint32_t &num_columns) const { return false; }
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const { return this; }
 
   virtual const Scalar *as_scalar() const { return nullptr; }
   virtual const Vector *as_vector() const { return nullptr; }
@@ -182,6 +183,7 @@ public:
 
   INLINE ScalarType get_scalar_type() const;
   virtual bool contains_scalar_type(ScalarType type) const override;
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const override;
   virtual bool as_scalar_type(ScalarType &type, uint32_t &num_elements,
                               uint32_t &num_rows, uint32_t &num_columns) const override;
 
@@ -228,6 +230,7 @@ public:
   virtual bool contains_scalar_type(ScalarType type) const override;
   virtual bool as_scalar_type(ScalarType &type, uint32_t &num_elements,
                               uint32_t &num_rows, uint32_t &num_columns) const override;
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const override;
 
   virtual int get_num_interface_locations() const override;
 
@@ -275,6 +278,7 @@ public:
   virtual bool contains_scalar_type(ScalarType type) const override;
   virtual bool as_scalar_type(ScalarType &type, uint32_t &num_elements,
                               uint32_t &num_rows, uint32_t &num_columns) const override;
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const override;
 
   virtual int get_num_interface_locations() const override;
 
@@ -333,6 +337,7 @@ public:
   bool is_aggregate_type() const override { return true; }
   virtual bool contains_opaque_type() const override;
   virtual bool contains_scalar_type(ScalarType type) const override;
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const override;
   const Struct *as_struct() const override { return this; }
 
 PUBLISHED:
@@ -382,6 +387,7 @@ public:
   virtual bool contains_scalar_type(ScalarType type) const override;
   virtual bool as_scalar_type(ScalarType &type, uint32_t &num_elements,
                               uint32_t &num_rows, uint32_t &num_columns) const override;
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const override;
 
   virtual void output(std::ostream &out) const override;
   virtual int compare_to_impl(const ShaderType &other) const override;
@@ -570,6 +576,7 @@ public:
   virtual int compare_to_impl(const ShaderType &other) const override;
 
   virtual bool contains_scalar_type(ScalarType type) const override;
+  virtual const ShaderType *replace_scalar_type(ScalarType a, ScalarType b) const override;
 
   const StorageBuffer *as_storage_buffer() const override { return this; }
 

+ 1 - 0
panda/src/shaderpipeline/p3shaderpipeline_composite2.cxx

@@ -1,6 +1,7 @@
 #ifndef CPPPARSER
 #include "spirVFlattenStructPass.cxx"
 #include "spirVHoistStructResourcesPass.cxx"
+#include "spirVMakeBlockPass.cxx"
 #include "spirVRemoveUnusedVariablesPass.cxx"
 #include "spirVReplaceVariableTypePass.cxx"
 #include "spirVResultDatabase.cxx"

+ 10 - 7
panda/src/shaderpipeline/shaderModuleSpirV.cxx

@@ -22,8 +22,6 @@
 
 #include "GLSL.std.450.h"
 
-#include <spirv-tools/libspirv.h>
-
 #ifndef NDEBUG
 #include <glslang/SPIRV/disassemble.h>
 #endif
@@ -660,16 +658,21 @@ validate_header() const {
  * Checks whether this is valid SPIR-V.
  */
 bool ShaderModuleSpirV::InstructionStream::
-validate() const {
-  spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0);
+validate(spv_target_env env) const {
+  spv_context context = spvContextCreate(env);
   spv_const_binary_t binary = {_words.data(), _words.size()};
   spv_diagnostic diagnostic = nullptr;
 
   spv_result_t result = spvValidate(context, &binary, &diagnostic);
 
-  if (diagnostic != nullptr) {
-    shader_cat.error()
-      << "SPIR-V validation failed:\n" << diagnostic->error << "\n";
+  if (result != SPV_SUCCESS) {
+    if (diagnostic != nullptr) {
+      shader_cat.error()
+        << "SPIR-V validation failed:\n" << diagnostic->error << "\n";
+    } else {
+      shader_cat.error()
+        << "SPIR-V validation failed.\n";
+    }
 
     disassemble(shader_cat.error() << "Disassembly follows:\n");
   }

+ 3 - 1
panda/src/shaderpipeline/shaderModuleSpirV.h

@@ -22,6 +22,8 @@
 #endif
 #include "spirv.hpp"
 
+#include <spirv-tools/libspirv.h>
+
 class ShaderType;
 
 /**
@@ -91,7 +93,7 @@ public:
     INLINE InstructionStream(std::vector<uint32_t> words);
 
     bool validate_header() const;
-    bool validate() const;
+    bool validate(spv_target_env env = SPV_ENV_UNIVERSAL_1_0) const;
     bool disassemble(std::ostream &out) const;
 
     INLINE operator std::vector<uint32_t> & ();

+ 3 - 3
panda/src/shaderpipeline/spirVFlattenStructPass.cxx

@@ -64,7 +64,7 @@ transform_definition_op(Instruction op) {
         // Insert a new variable for this struct member.
         uint32_t variable_id = define_variable(member.type, spv::StorageClassUniformConstant);
         if (!member.name.empty()) {
-          add_name(variable_id, member.name);
+          set_name(variable_id, member.name);
         }
 
         Definition &variable_def = _db.modify_definition(variable_id);
@@ -99,7 +99,7 @@ transform_definition_op(Instruction op) {
  *
  */
 bool SpirVFlattenStructPass::
-transform_function_op(Instruction op, uint32_t function_id) {
+transform_function_op(Instruction op) {
   switch (op.opcode) {
   case spv::OpAccessChain:
   case spv::OpInBoundsAccessChain:
@@ -286,7 +286,7 @@ transform_function_op(Instruction op, uint32_t function_id) {
     break;
 
   default:
-    return SpirVTransformPass::transform_function_op(op, function_id);
+    return SpirVTransformPass::transform_function_op(op);
   }
 
   return true;

+ 1 - 1
panda/src/shaderpipeline/spirVFlattenStructPass.h

@@ -27,7 +27,7 @@ public:
   virtual void preprocess();
 
   virtual bool transform_definition_op(Instruction op);
-  virtual bool transform_function_op(Instruction op, uint32_t function_id);
+  virtual bool transform_function_op(Instruction op);
 
 private:
   const uint32_t _type_id;

+ 9 - 6
panda/src/shaderpipeline/spirVHoistStructResourcesPass.cxx

@@ -213,7 +213,7 @@ begin_function(Instruction op) {
  *
  */
 bool SpirVHoistStructResourcesPass::
-transform_function_op(Instruction op, uint32_t function_id) {
+transform_function_op(Instruction op) {
   switch (op.opcode) {
   case spv::OpFunctionParameter:
     // Erase deleted types in function parameter list.
@@ -223,7 +223,7 @@ transform_function_op(Instruction op, uint32_t function_id) {
         delete_id(param_id);
       } else {
         add_instruction(op.opcode, op.args, op.nargs);
-        _db.modify_definition(function_id)._parameters.push_back(op.args[1]);
+        _db.modify_definition(_current_function_id)._parameters.push_back(op.args[1]);
       }
 
       // Structs with non-opaque types must be passed through pointers.
@@ -244,7 +244,7 @@ transform_function_op(Instruction op, uint32_t function_id) {
           access_chain._var_id = param_id;
           _hoisted_vars[std::move(access_chain)] = id;
 
-          _db.record_function_parameter(id, type_ptr_id, function_id);
+          _db.record_function_parameter(id, type_ptr_id, _current_function_id);
         }
       }
 
@@ -374,7 +374,7 @@ transform_function_op(Instruction op, uint32_t function_id) {
             hoisted_new_args[1] = id;
             hoisted_new_args[2] = hoisted_var_id;
             add_instruction(spv::OpAccessChain, hoisted_new_args.data(), hoisted_new_args.size());
-            _db.record_temporary(id, hoisted_type_ptr_id, hoisted_var_id, function_id);
+            _db.record_temporary(id, hoisted_type_ptr_id, hoisted_var_id, _current_function_id);
 
             AccessChain new_access_chain(pair.second);
             new_access_chain._var_id = orig_chain_id;
@@ -482,12 +482,15 @@ transform_function_op(Instruction op, uint32_t function_id) {
     break;
 
   default:
-    return SpirVTransformPass::transform_function_op(op, function_id);
+    return SpirVTransformPass::transform_function_op(op);
   }
 
   return true;
 }
 
+/**
+ *
+ */
 void SpirVHoistStructResourcesPass::
 postprocess() {
   for (auto vit = _hoisted_vars.begin(); vit != _hoisted_vars.end(); ++vit) {
@@ -499,7 +502,7 @@ postprocess() {
       for (size_t i = 0; i < access_chain.size(); ++i) {
         name += "_m" + format_string(access_chain[i]);
       }
-      add_name(var_id, name);
+      set_name(var_id, name);
     }
   }
 }

+ 1 - 1
panda/src/shaderpipeline/spirVHoistStructResourcesPass.h

@@ -27,7 +27,7 @@ public:
 
   virtual bool transform_definition_op(Instruction op);
   virtual bool begin_function(Instruction op);
-  virtual bool transform_function_op(Instruction op, uint32_t function_id);
+  virtual bool transform_function_op(Instruction op);
 
   virtual void postprocess();
 

+ 162 - 0
panda/src/shaderpipeline/spirVMakeBlockPass.cxx

@@ -0,0 +1,162 @@
+/**
+ * PANDA 3D SOFTWARE
+ * Copyright (c) Carnegie Mellon University.  All rights reserved.
+ *
+ * All use of this software is subject to the terms of the revised BSD
+ * license.  You should have received a copy of this license along
+ * with this source code in a file named "LICENSE."
+ *
+ * @file spirVMakeBlockPass.cxx
+ * @author rdb
+ * @date 2024-10-11
+ */
+
+#include "spirVMakeBlockPass.h"
+
+/**
+ *
+ */
+SpirVMakeBlockPass::
+SpirVMakeBlockPass(const ShaderType::Struct *block_type, const pvector<uint32_t> &member_ids,
+                   spv::StorageClass storage_class, uint32_t binding, uint32_t set) :
+  _block_type(block_type),
+  _storage_class(storage_class),
+  _binding(binding),
+  _set(set) {
+  nassertv(block_type->get_num_members() == member_ids.size());
+
+  for (uint32_t i = 0; i < (uint32_t)member_ids.size(); ++i) {
+    uint32_t member_id = member_ids[i];
+    if (member_id > 0) {
+      _member_indices[member_id] = i;
+      delete_id(member_id);
+    }
+  }
+}
+
+/**
+ *
+ */
+bool SpirVMakeBlockPass::
+begin_function(Instruction op) {
+  // Define the block type at the first variable definition.
+  if (_block_var_id == 0) {
+    uint32_t block_type_id = define_type(_block_type);
+    _block_var_id = define_variable(_block_type, _storage_class);
+
+    decorate(block_type_id, spv::DecorationBlock);
+
+    if (_storage_class != spv::StorageClassPushConstant) {
+      decorate(_block_var_id, spv::DecorationBinding, _binding);
+      decorate(_block_var_id, spv::DecorationDescriptorSet, _set);
+    }
+
+    for (auto &item : _member_indices) {
+      const std::string &name = _db.get_definition(item.first)._name;
+      if (!name.empty()) {
+        set_member_name(block_type_id, item.second, name);
+      }
+    }
+  }
+  return true;
+}
+
+/**
+ *
+ */
+bool SpirVMakeBlockPass::
+transform_function_op(Instruction op) {
+  switch (op.opcode) {
+  case spv::OpAccessChain:
+  case spv::OpInBoundsAccessChain:
+    if (_member_indices.count(op.args[2])) {
+      uint32_t result_id = op.args[1];
+      uint32_t member_index = _member_indices[op.args[2]];
+      uint32_t constant_id = define_int_constant(member_index);
+
+      // Get a type pointer with the correct storage class.
+      uint32_t pointer_type_id = define_pointer_type(resolve_pointer_type(op.args[0]), _storage_class);
+
+      // Prepend our new block variable to the existing access chain.
+      pvector<uint32_t> new_args({pointer_type_id, result_id, _block_var_id, constant_id});
+      new_args.insert(new_args.end(), op.args + 3, op.args + op.nargs);
+
+      add_instruction(op.opcode, new_args.data(), new_args.size());
+      return false;
+    }
+    break;
+
+  case spv::OpFunctionCall:
+    // Add access chains when passing a load of a member id to a function.
+    if (op.nargs >= 3) {
+      for (size_t i = 3; i < op.nargs; ++i) {
+        maybe_replace_with_access_chain(op.args[i]);
+      }
+    }
+    break;
+
+  case spv::OpPtrEqual:
+  case spv::OpPtrNotEqual:
+  case spv::OpPtrDiff:
+    maybe_replace_with_access_chain(op.args[3]);
+    // fall through
+  case spv::OpLoad:
+  case spv::OpAtomicLoad:
+  case spv::OpAtomicExchange:
+  case spv::OpAtomicCompareExchange:
+  case spv::OpAtomicCompareExchangeWeak:
+  case spv::OpAtomicIIncrement:
+  case spv::OpAtomicIDecrement:
+  case spv::OpAtomicIAdd:
+  case spv::OpAtomicISub:
+  case spv::OpAtomicSMin:
+  case spv::OpAtomicUMin:
+  case spv::OpAtomicSMax:
+  case spv::OpAtomicUMax:
+  case spv::OpAtomicAnd:
+  case spv::OpAtomicOr:
+  case spv::OpAtomicXor:
+  case spv::OpAtomicFlagTestAndSet:
+  case spv::OpAtomicFMinEXT:
+  case spv::OpAtomicFMaxEXT:
+  case spv::OpAtomicFAddEXT:
+  case spv::OpImageTexelPointer:
+  case spv::OpCopyObject:
+  case spv::OpExpectKHR:
+  case spv::OpBitcast:
+  case spv::OpCopyLogical:
+    // Add access chains before all loads to access the right block member.
+    maybe_replace_with_access_chain(op.args[2]);
+    break;
+
+  case spv::OpCopyMemory:
+  case spv::OpCopyMemorySized:
+    maybe_replace_with_access_chain(op.args[1]);
+    // fall through
+  case spv::OpStore:
+  case spv::OpAtomicStore:
+  case spv::OpAtomicFlagClear:
+  case spv::OpReturnValue:
+    maybe_replace_with_access_chain(op.args[0]);
+    break;
+
+  default:
+    return SpirVTransformPass::transform_function_op(op);
+  }
+
+  return true;
+}
+
+/**
+ * Replaces the given id with the new access chain if the given id was one of
+ * the ids that was added to the block.
+ */
+bool SpirVMakeBlockPass::
+maybe_replace_with_access_chain(uint32_t &id) {
+  auto it = _member_indices.find(id);
+  if (it != _member_indices.end()) {
+    id = op_access_chain(_block_var_id, {define_int_constant(it->second)});
+    return true;
+  }
+  return false;
+}

+ 46 - 0
panda/src/shaderpipeline/spirVMakeBlockPass.h

@@ -0,0 +1,46 @@
+/**
+ * PANDA 3D SOFTWARE
+ * Copyright (c) Carnegie Mellon University.  All rights reserved.
+ *
+ * All use of this software is subject to the terms of the revised BSD
+ * license.  You should have received a copy of this license along
+ * with this source code in a file named "LICENSE."
+ *
+ * @file spirVMakeBlockPass.h
+ * @author rdb
+ * @date 2024-10-11
+ */
+
+#ifndef SPIRVMAKEBLOCKPASS_H
+#define SPIRVMAKEBLOCKPASS_H
+
+#include "spirVTransformPass.h"
+
+/**
+ * Creates a new uniform (or push constant) block using the parameters specified
+ * by the given ids and types.  This is the opposite of SpirVFlattenStructPass.
+ */
+class EXPCL_PANDA_SHADERPIPELINE SpirVMakeBlockPass final : public SpirVTransformPass {
+public:
+  SpirVMakeBlockPass(const ShaderType::Struct *block_type, const pvector<uint32_t> &member_ids,
+                     spv::StorageClass storage_class, uint32_t binding=0, uint32_t set=0);
+
+  virtual bool begin_function(Instruction op);
+  virtual bool transform_function_op(Instruction op);
+
+  bool maybe_replace_with_access_chain(uint32_t &id);
+
+private:
+  const ShaderType::Struct *_block_type;
+  const spv::StorageClass _storage_class;
+  const uint32_t _binding;
+  const uint32_t _set;
+
+  // Map from id to index of the member in the struct.
+  pmap<uint32_t, uint32_t> _member_indices;
+
+public:
+  uint32_t _block_var_id = 0;
+};
+
+#endif

+ 2 - 2
panda/src/shaderpipeline/spirVReplaceVariableTypePass.cxx

@@ -63,7 +63,7 @@ transform_definition_op(Instruction op) {
  *
  */
 bool SpirVReplaceVariableTypePass::
-transform_function_op(Instruction op, uint32_t function_id) {
+transform_function_op(Instruction op) {
   switch (op.opcode) {
   case spv::OpLoad:
   case spv::OpAtomicLoad:
@@ -135,7 +135,7 @@ transform_function_op(Instruction op, uint32_t function_id) {
     break;
 
   default:
-    return SpirVTransformPass::transform_function_op(op, function_id);
+    return SpirVTransformPass::transform_function_op(op);
   }
 
   return true;

+ 1 - 1
panda/src/shaderpipeline/spirVReplaceVariableTypePass.h

@@ -27,7 +27,7 @@ public:
                                spv::StorageClass storage_class);
 
   virtual bool transform_definition_op(Instruction op);
-  virtual bool transform_function_op(Instruction op, uint32_t function_id);
+  virtual bool transform_function_op(Instruction op);
 
 private:
   const uint32_t _variable_id;

+ 8 - 0
panda/src/shaderpipeline/spirVResultDatabase.I

@@ -51,6 +51,14 @@ is_constant() const {
   return _dtype == DT_constant;
 }
 
+/**
+ * Returns true if this is specifically a null constant.
+ */
+INLINE bool SpirVResultDatabase::Definition::
+is_null_constant() const {
+  return _dtype == DT_constant && (_flags & DF_null_constant) != 0;
+}
+
 /**
  * Returns true if this is a specialization constant.
  */

+ 33 - 2
panda/src/shaderpipeline/spirVResultDatabase.cxx

@@ -120,7 +120,7 @@ modify_definition(uint32_t id) {
  * encountered definitions are recorded in the definitions vector.
  */
 void SpirVResultDatabase::
-parse_instruction(spv::Op opcode, uint32_t *args, uint32_t nargs, uint32_t &current_function_id) {
+parse_instruction(spv::Op opcode, const uint32_t *args, uint32_t nargs, uint32_t &current_function_id) {
   switch (opcode) {
   case spv::OpExtInstImport:
     record_ext_inst_import(args[0], (const char*)&args[1]);
@@ -195,6 +195,7 @@ parse_instruction(spv::Op opcode, uint32_t *args, uint32_t nargs, uint32_t &curr
       uint32_t component_count = args[2];
       record_type(args[0], ShaderType::register_type(
         ShaderType::Vector(element_type->get_scalar_type(), component_count)));
+      _defs[args[0]]._type_id = args[1];
     }
     break;
 
@@ -205,6 +206,7 @@ parse_instruction(spv::Op opcode, uint32_t *args, uint32_t nargs, uint32_t &curr
       uint32_t num_rows = args[2];
       record_type(args[0], ShaderType::register_type(
         ShaderType::Matrix(column_type->get_scalar_type(), num_rows, column_type->get_num_components())));
+      _defs[args[0]]._type_id = args[1];
     }
     break;
 
@@ -349,6 +351,7 @@ parse_instruction(spv::Op opcode, uint32_t *args, uint32_t nargs, uint32_t &curr
       record_type(args[0], ShaderType::register_type(
         ShaderType::Array(_defs[args[1]]._type, 0)));
     }
+    _defs[args[0]]._type_id = args[1];
     break;
 
   case spv::OpTypeStruct:
@@ -868,8 +871,15 @@ find_pointer_type(const ShaderType *type, spv::StorageClass storage_class) {
   if (tit == _type_map.end()) {
     return 0;
   }
-  uint32_t type_id = tit->second;
+  return find_pointer_type(tit->second, storage_class);
+}
 
+/**
+ * Searches for an already-defined type pointer of the given storage class.
+ * Returns its id, or 0 if it was not found.
+ */
+uint32_t SpirVResultDatabase::
+find_pointer_type(uint32_t type_id, spv::StorageClass storage_class) {
   for (uint32_t id = 0; id < _defs.size(); ++id) {
     Definition &def = _defs[id];
     if (def._dtype == DT_pointer_type &&
@@ -881,6 +891,23 @@ find_pointer_type(const ShaderType *type, spv::StorageClass storage_class) {
   return 0;
 }
 
+/**
+ * Searches for an already-defined null constant of the given type.
+ * Returns its id, or 0 if it was not found.
+ */
+uint32_t SpirVResultDatabase::
+find_null_constant(uint32_t type_id) {
+  for (uint32_t id = 0; id < _defs.size(); ++id) {
+    Definition &def = _defs[id];
+    if (def._dtype == DT_constant &&
+        (def._flags & DF_null_constant) != 0 &&
+        def._type_id == type_id) {
+      return id;
+    }
+  }
+  return 0;
+}
+
 /**
  * Records that the given type has been defined.
  */
@@ -1062,6 +1089,10 @@ record_constant(uint32_t id, uint32_t type_id, const uint32_t *words, uint32_t n
   def._type = (type_def._dtype == DT_type) ? type_def._type : nullptr;
   def._constant = (nwords > 0) ? words[0] : 0;
   def._flags |= DF_constant_expression;
+
+  if (words == nullptr) {
+    def._flags |= DF_null_constant;
+  }
 }
 
 /**

+ 8 - 2
panda/src/shaderpipeline/spirVResultDatabase.h

@@ -46,6 +46,8 @@ private:
     DF_non_readable = 256, // writeonly
 
     DF_relaxed_precision = 512,
+
+    DF_null_constant = 1024,
   };
 
 public:
@@ -91,6 +93,7 @@ public:
     INLINE bool is_variable() const;
     INLINE bool is_function_parameter() const;
     INLINE bool is_constant() const;
+    INLINE bool is_null_constant() const;
     INLINE bool is_spec_constant() const;
     INLINE bool is_function() const;
     INLINE bool is_ext_inst() const;
@@ -105,16 +108,19 @@ public:
     MemberDefinition &modify_member(uint32_t i);
     void clear();
   };
-  typedef pvector<Definition> Definitions;
+  typedef pdeque<Definition> Definitions;
 
   uint32_t find_definition(const std::string &name) const;
   const Definition &get_definition(uint32_t id) const;
   Definition &modify_definition(uint32_t id);
 
-  void parse_instruction(spv::Op opcode, uint32_t *args, uint32_t nargs, uint32_t &current_function_id);
+  void parse_instruction(spv::Op opcode, const uint32_t *args, uint32_t nargs,
+                         uint32_t &current_function_id);
 
   uint32_t find_type(const ShaderType *type);
   uint32_t find_pointer_type(const ShaderType *type, spv::StorageClass storage_class);
+  uint32_t find_pointer_type(uint32_t type_id, spv::StorageClass storage_class);
+  uint32_t find_null_constant(uint32_t type_id);
 
   void record_type(uint32_t id, const ShaderType *type);
   void record_pointer_type(uint32_t id, spv::StorageClass storage_class, uint32_t type_id);

+ 23 - 0
panda/src/shaderpipeline/spirVTransformPass.I

@@ -109,6 +109,29 @@ is_member_deleted(uint32_t type_id, uint32_t member) const {
   return false;
 }
 
+/**
+ * Inserts a decoration instruction taking no arguments.
+ */
+INLINE void SpirVTransformPass::
+decorate(uint32_t id, spv::Decoration decoration) {
+  nassertv(decoration != spv::DecorationLocation);
+  _new_annotations.insert(_new_annotations.end(),
+    {spv::OpDecorate | (3 << spv::WordCountShift), id, decoration});
+}
+
+/**
+ * Inserts a decoration instruction taking one argument.
+ */
+INLINE void SpirVTransformPass::
+decorate(uint32_t id, spv::Decoration decoration, uint32_t value) {
+  _new_annotations.insert(_new_annotations.end(),
+    {spv::OpDecorate | (4 << spv::WordCountShift), id, decoration, value});
+
+  if (decoration == spv::DecorationLocation) {
+    _db.modify_definition(id)._location = (int)value;
+  }
+}
+
 /**
  * Returns true if the given id has already been defined during this pass.
  */

+ 257 - 42
panda/src/shaderpipeline/spirVTransformPass.cxx

@@ -21,7 +21,8 @@ SpirVTransformPass() {
 }
 
 /**
- *
+ * Processes the header and all instructions, including the debug instructions,
+ * up to the first annotation instruction.
  */
 void SpirVTransformPass::
 process_preamble(std::vector<uint32_t> &stream) {
@@ -69,7 +70,8 @@ process_preamble(std::vector<uint32_t> &stream) {
 }
 
 /**
- *
+ * Processes the instructions of the annotations section, which contains the
+ * decorations.
  */
 void SpirVTransformPass::
 process_annotations(std::vector<uint32_t> &stream) {
@@ -92,7 +94,9 @@ process_annotations(std::vector<uint32_t> &stream) {
 }
 
 /**
- *
+ * Processes the instructions of the definitions section, which starts with the
+ * first non-annotation instruction and contains all of the types, constants,
+ * and global variables.
  */
 void SpirVTransformPass::
 process_definitions(std::vector<uint32_t> &stream) {
@@ -125,7 +129,8 @@ process_definitions(std::vector<uint32_t> &stream) {
 }
 
 /**
- *
+ * Processes the instructions of the function section, which starts with the
+ * first OpFunction and contains the remainder of the module.
  */
 void SpirVTransformPass::
 process_functions(std::vector<uint32_t> &stream) {
@@ -137,6 +142,7 @@ process_functions(std::vector<uint32_t> &stream) {
     if (op.opcode == spv::OpFunction) {
       if (begin_function(op)) {
         uint32_t function_id = op.args[1];
+        _current_function_id = function_id;
         _new_functions.insert(_new_functions.end(), it._words, it.next()._words);
 
         ++it;
@@ -148,7 +154,7 @@ process_functions(std::vector<uint32_t> &stream) {
           bool has_result, has_type;
           HasResultAndType(op.opcode, &has_result, &has_type);
           if (!has_result || (!is_defined(op.args[has_type]) && !is_deleted(op.args[has_type])))  {
-            if (transform_function_op(op, function_id)) {
+            if (transform_function_op(op)) {
               if (has_result) {
                 mark_defined(op.args[has_type]);
               }
@@ -160,6 +166,7 @@ process_functions(std::vector<uint32_t> &stream) {
 
         if (it != end) {
           end_function(function_id);
+          _current_function_id = 0;
           _new_functions.insert(_new_functions.end(), {spv::OpFunctionEnd | (1 << spv::WordCountShift)});
         } else {
           shader_cat.error()
@@ -178,26 +185,17 @@ process_functions(std::vector<uint32_t> &stream) {
 }
 
 /**
- *
+ * Called before any of the instructions are read.  Perform any pre-processing
+ * based on the result database and the input arguments here.
  */
 void SpirVTransformPass::
 preprocess() {
 }
 
 /**
- *
- */
-ShaderModuleSpirV::InstructionStream SpirVTransformPass::
-get_result() const {
-  InstructionStream stream(_new_preamble);
-  stream._words.insert(stream._words.end(), _new_annotations.begin(), _new_annotations.end());
-  stream._words.insert(stream._words.end(), _new_definitions.begin(), _new_definitions.end());
-  stream._words.insert(stream._words.end(), _new_functions.begin(), _new_functions.end());
-  return stream;
-}
-
-/**
- *
+ * Transforms a debug instruction (OpName or OpMemberName).
+ * Return true to preserve the instruction, false to omit it (in which case you
+ * may replace it using add_debug).
  */
 bool SpirVTransformPass::
 transform_debug_op(Instruction op) {
@@ -208,7 +206,9 @@ transform_debug_op(Instruction op) {
 }
 
 /**
- *
+ * Transforms an annotation instruction.
+ * Return true to preserve the instruction, false to omit it (in which case you
+ * may replace it using add_annotation).
  */
 bool SpirVTransformPass::
 transform_annotation_op(Instruction op) {
@@ -219,7 +219,9 @@ transform_annotation_op(Instruction op) {
 }
 
 /**
- *
+ * Transforms a definition instruction (a type, constant or global variable).
+ * Return true to preserve the instruction, false to omit it (in which case you
+ * may replace it using add_definition).
  */
 bool SpirVTransformPass::
 transform_definition_op(Instruction op) {
@@ -248,7 +250,6 @@ transform_definition_op(Instruction op) {
           }
         }
         add_definition(spv::OpTypeFunction, new_args.data(), new_args.size());
-        mark_defined(op.args[0]);
         return false;
       }
     }
@@ -261,7 +262,10 @@ transform_definition_op(Instruction op) {
 }
 
 /**
- *
+ * Called when an OpFunction is encountered.  Return true to preserve the
+ * function, false to skip all instructions up to the next OpFunctionEnd (in
+ * which case end_function() will not be called either).
+ * It is permitted to modify the arguments of the given op.
  */
 bool SpirVTransformPass::
 begin_function(Instruction op) {
@@ -269,10 +273,12 @@ begin_function(Instruction op) {
 }
 
 /**
- *
+ * Transforms an instruction encountered inside a function.  This will always
+ * be called between begin_function() and end_function() and will be passed the
+ * result identifier of the previous OpFunction.
  */
 bool SpirVTransformPass::
-transform_function_op(Instruction op, uint32_t function_id) {
+transform_function_op(Instruction op) {
   switch (op.opcode) {
   case spv::OpLoad:
   case spv::OpAtomicLoad:
@@ -346,7 +352,6 @@ transform_function_op(Instruction op, uint32_t function_id) {
           }
         }
         add_instruction(spv::OpFunctionCall, new_args.data(), new_args.size());
-        mark_defined(new_args[1]);
         return false;
       }
     }
@@ -364,14 +369,17 @@ transform_function_op(Instruction op, uint32_t function_id) {
 }
 
 /**
- *
+ * Called when an OpFunctionEnd instruction is encountered, belonging to an
+ * OpFunction with the given identifier.
  */
 void SpirVTransformPass::
 end_function(uint32_t function_id) {
 }
 
 /**
- *
+ * Called after all instructions have been read, this does any post-processing
+ * needed (such as updating the result database to reflect the transformations,
+ * adding names/decorations, etc.)
  */
 void SpirVTransformPass::
 postprocess() {
@@ -381,7 +389,26 @@ postprocess() {
  * Writes a name for the given id.
  */
 void SpirVTransformPass::
-add_name(uint32_t id, const std::string &name) {
+set_name(uint32_t id, const std::string &name) {
+  Definition &def = _db.modify_definition(id);
+  if (!def._name.empty()) {
+    // Remove the existing name.
+    auto it = _new_preamble.begin() + 5;
+    while (it != _new_preamble.end()) {
+      spv::Op opcode = (spv::Op)(*it & spv::OpCodeMask);
+      uint32_t wcount = *it >> spv::WordCountShift;
+      nassertd(wcount > 0) break;
+
+      if (wcount >= 2 && opcode == spv::OpName && *(it + 1) == id) {
+        it = _new_preamble.erase(it, it + wcount);
+        continue;
+      }
+
+      std::advance(it, wcount);
+    }
+  }
+  def._name = name;
+
   uint32_t nargs = 2 + name.size() / 4;
   uint32_t *args = (uint32_t *)alloca(nargs * 4);
   memset(args, 0, nargs * 4);
@@ -390,6 +417,39 @@ add_name(uint32_t id, const std::string &name) {
   add_debug(spv::OpName, args, nargs);
 }
 
+/**
+ * Writes a name for the given struct member.
+ */
+void SpirVTransformPass::
+set_member_name(uint32_t type_id, uint32_t member_index, const std::string &name) {
+  MemberDefinition &mdef = _db.modify_definition(type_id).modify_member(member_index);
+  if (!mdef._name.empty()) {
+    // Remove the existing name.
+    auto it = _new_preamble.begin();
+    while (it != _new_preamble.end()) {
+      spv::Op opcode = (spv::Op)(*it & spv::OpCodeMask);
+      uint32_t wcount = *it >> spv::WordCountShift;
+      nassertd(wcount > 0) break;
+
+      if (wcount >= 3 && opcode == spv::OpMemberName && *(it + 1) == type_id && *(it + 2) == member_index) {
+        it = _new_preamble.erase(it, it + wcount);
+        continue;
+      }
+
+      std::advance(it, wcount);
+    }
+  }
+  mdef._name = name;
+
+  uint32_t nargs = 3 + name.size() / 4;
+  uint32_t *args = (uint32_t *)alloca(nargs * 4);
+  memset(args, 0, nargs * 4);
+  args[0] = type_id;
+  args[1] = member_index;
+  memcpy((char *)(args + 2), name.data(), name.size());
+  add_debug(spv::OpMemberName, args, nargs);
+}
+
 /**
  * Deletes the given identifier, and any annotations for it.
  */
@@ -455,7 +515,8 @@ delete_id(uint32_t id) {
 
 /**
  * Deletes the annotations for the given struct member (using the pre-transform
- * struct index numbering).
+ * struct index numbering).  Does not update the actual OpTypeStruct args, any
+ * access chains, etc.
  */
 void SpirVTransformPass::
 delete_struct_member(uint32_t id, uint32_t member_index) {
@@ -518,7 +579,9 @@ delete_struct_member(uint32_t id, uint32_t member_index) {
 }
 
 /**
- * Deletes the given parameter of the given function type.
+ * Deletes the given parameter of the given function type.  Should be called
+ * before the OpTypeFunction is encountered, or the OpTypeFunction should have
+ * already been modified to remove this parameter.
  */
 void SpirVTransformPass::
 delete_function_parameter(uint32_t type_id, uint32_t param_index) {
@@ -537,7 +600,8 @@ delete_function_parameter(uint32_t type_id, uint32_t param_index) {
 }
 
 /**
- *
+ * Adds a new variable definition to the definitions section.  Inserts any type
+ * declarations and annotations that may be necessary.
  */
 uint32_t SpirVTransformPass::
 define_variable(const ShaderType *type, spv::StorageClass storage_class) {
@@ -586,9 +650,8 @@ define_pointer_type(const ShaderType *type, spv::StorageClass storage_class) {
 }
 
 /**
- * Helper for define_type.  Inserts the given type (after any requisite
- * dependent types, as found through the given type map) at the given iterator,
- * and advances the iterator.
+ * Ensures that the given type is defined by adding instructions to the
+ * definitions section as necessary.
  */
 uint32_t SpirVTransformPass::
 define_type(const ShaderType *type) {
@@ -804,7 +867,8 @@ define_type(const ShaderType *type) {
 
 /**
  * Defines a new integral constant, either of type uint or int, reusing an
- * existing one one is already defined.
+ * existing one if one is already defined (except for OpConstantNull, which
+ * can't be used to index structure members).
  */
 uint32_t SpirVTransformPass::
 define_int_constant(int32_t constant) {
@@ -813,7 +877,7 @@ define_int_constant(int32_t constant) {
 
   for (uint32_t id = 0; id < get_id_bound(); ++id) {
     const Definition &def = _db.get_definition(id);
-    if (def.is_constant() &&
+    if (def.is_constant() && !def.is_null_constant() &&
         def._constant == (uint32_t)constant &&
         (def._type == ShaderType::int_type || (constant >= 0 && def._type == ShaderType::uint_type))) {
       if (is_defined(id)) {
@@ -836,7 +900,36 @@ define_int_constant(int32_t constant) {
 }
 
 /**
- * Defines a new constant.
+ * Defines a new null constant of the given type, reusing an existing one if
+ * one is already defined.
+ */
+uint32_t SpirVTransformPass::
+define_null_constant(const ShaderType *type) {
+  uint32_t constant_id = 0;
+  uint32_t type_id = define_type(type);
+
+  for (uint32_t id = 0; id < get_id_bound(); ++id) {
+    const Definition &def = _db.get_definition(id);
+    if (def.is_null_constant() && def._type_id == type_id) {
+      if (is_defined(id)) {
+        return id;
+      }
+      constant_id = id;
+    }
+  }
+
+  if (constant_id == 0) {
+    constant_id = allocate_id();
+  }
+
+  add_definition(spv::OpConstantNull, {type_id, constant_id});
+
+  _db.record_constant(constant_id, type_id, nullptr, 0);
+  return constant_id;
+}
+
+/**
+ * Defines a new constant.  Does not attempt to reuse constants.
  */
 uint32_t SpirVTransformPass::
 define_constant(const ShaderType *type, uint32_t constant) {
@@ -898,10 +991,11 @@ r_annotate_struct_layout(uint32_t type_id) {
       // Also make sure there's an ArrayStride decoration for this array.
       uint32_t array_type_id = _db.find_type(array_type);
 
-      if (def._array_stride == 0) {
-        def._array_stride = array_type->get_stride_bytes();
+      Definition &array_def = _db.modify_definition(array_type_id);
+      if (array_def._array_stride == 0) {
+        array_def._array_stride = array_type->get_stride_bytes();
         add_annotation(spv::OpDecorate,
-          {array_type_id, spv::DecorationArrayStride, def._array_stride});
+          {array_type_id, spv::DecorationArrayStride, array_def._array_stride});
       }
     }
 
@@ -933,7 +1027,8 @@ add_definition(spv::Op opcode, const uint32_t *args, uint16_t nargs) {
 }
 
 /**
- * Adds an instruction to the current function.
+ * Adds an instruction to the current function.  May only be called from
+ * transform_function_op.
  */
 void SpirVTransformPass::
 add_instruction(spv::Op opcode, const uint32_t *args, uint16_t nargs) {
@@ -946,3 +1041,123 @@ add_instruction(spv::Op opcode, const uint32_t *args, uint16_t nargs) {
   _new_functions.push_back(((nargs + 1) << spv::WordCountShift) | opcode);
   _new_functions.insert(_new_functions.end(), args, args + nargs);
 }
+
+/**
+ * Inserts an OpLoad from the given pointer id.
+ */
+uint32_t SpirVTransformPass::
+op_load(uint32_t var_id, spv::MemoryAccessMask access) {
+  const Definition &var_def = _db.get_definition(var_id);
+  uint32_t type_id = unwrap_pointer_type(var_def._type_id);
+
+  uint32_t id = allocate_id();
+  if (access != spv::MemoryAccessMaskNone) {
+    _new_functions.insert(_new_functions.end(),
+      {(5 << spv::WordCountShift) | spv::OpLoad, type_id, id, var_id, (uint32_t)access});
+  } else {
+    _new_functions.insert(_new_functions.end(),
+      {(4 << spv::WordCountShift) | spv::OpLoad, type_id, id, var_id});
+  }
+
+  _db.record_temporary(id, type_id, var_id, _current_function_id);
+
+  // A load from the pointer is enough for us to consider it "used", for now.
+  mark_used(id);
+  mark_defined(id);
+  return id;
+}
+
+/**
+ * Inserts an OpSelect.
+ */
+uint32_t SpirVTransformPass::
+op_select(uint32_t cond, uint32_t obj1, uint32_t obj2) {
+  const Definition &obj1_def = _db.get_definition(obj1);
+  const Definition &obj2_def = _db.get_definition(obj2);
+  nassertr(obj1_def._type_id == obj2_def._type_id, 0);
+
+  uint32_t id = allocate_id();
+  _new_functions.insert(_new_functions.end(), {(6 << spv::WordCountShift) | spv::OpSelect, obj1_def._type_id, id, cond, obj1, obj2});
+
+  mark_used(obj1);
+  mark_used(obj2);
+  mark_defined(id);
+  return id;
+}
+
+/**
+ * Inserts an OpAccessChain with the given base id (which must be a pointer)
+ * and constant ids containing the various member/array indices.
+ */
+uint32_t SpirVTransformPass::
+op_access_chain(uint32_t var_id, std::initializer_list<uint32_t> chain) {
+  const Definition &var_def = _db.get_definition(var_id);
+  const Definition &var_type_def = _db.get_definition(var_def._type_id);
+  nassertr(var_type_def.is_pointer_type(), 0);
+  spv::StorageClass storage_class = var_type_def._storage_class;
+
+  uint32_t type_id = var_type_def._type_id;
+  for (auto index_id : chain) {
+    const Definition &type_def = _db.get_definition(type_id);
+    nassertr(type_def.is_type(), 0);
+
+    if (!type_def._members.empty()) {
+      uint32_t member_index = resolve_constant(index_id);
+      nassertr((size_t)member_index < type_def._members.size(), 0);
+      type_id = type_def._members[member_index]._type_id;
+    } else {
+      // Array, matrix, or vector
+      type_id = type_def._type_id;
+    }
+    nassertr(type_id != 0, 0);
+  }
+
+  uint32_t pointer_type_id = _db.find_pointer_type(type_id, storage_class);
+  if (pointer_type_id == 0) {
+    pointer_type_id = allocate_id();
+    _db.record_pointer_type(pointer_type_id, storage_class, type_id);
+
+    add_definition(spv::OpTypePointer,
+      {pointer_type_id, (uint32_t)storage_class, type_id});
+  }
+
+  uint32_t id = allocate_id();
+  _new_functions.insert(_new_functions.end(), {((4 + (uint32_t)chain.size()) << spv::WordCountShift) | spv::OpAccessChain, pointer_type_id, id, var_id});
+  _new_functions.insert(_new_functions.end(), chain);
+
+  _db.record_temporary(id, pointer_type_id, var_id, _current_function_id);
+  return id;
+}
+
+/**
+ * Inserts an OpCompositeExtract.
+ */
+uint32_t SpirVTransformPass::
+op_composite_extract(uint32_t obj_id, std::initializer_list<uint32_t> chain) {
+  const Definition &obj_def = _db.get_definition(obj_id);
+
+  uint32_t type_id = obj_def._type_id;
+  for (auto index : chain) {
+    const Definition &type_def = _db.get_definition(type_id);
+    nassertr(type_def.is_type() && !type_def.is_pointer_type(), 0);
+
+    if (!type_def._members.empty()) {
+      nassertr((size_t)index < type_def._members.size(), 0);
+      type_id = type_def._members[index]._type_id;
+    } else {
+      // Array, matrix, or vector
+      type_id = type_def._type_id;
+    }
+    nassertr(type_id != 0, 0);
+  }
+
+  uint32_t id = allocate_id();
+  _new_functions.insert(_new_functions.end(), {((4 + (uint32_t)chain.size()) << spv::WordCountShift) | spv::OpCompositeExtract, type_id, id, obj_id});
+  _new_functions.insert(_new_functions.end(), chain);
+
+  Definition &def = _db.modify_definition(id);
+  def._type_id = type_id;
+
+  mark_defined(id);
+  return id;
+}

+ 14 - 3
panda/src/shaderpipeline/spirVTransformPass.h

@@ -37,14 +37,13 @@ public:
   void process_annotations(std::vector<uint32_t> &instructions);
   void process_definitions(std::vector<uint32_t> &instructions);
   void process_functions(std::vector<uint32_t> &instructions);
-  InstructionStream get_result() const;
 
   virtual void preprocess();
   virtual bool transform_debug_op(Instruction op);
   virtual bool transform_annotation_op(Instruction op);
   virtual bool transform_definition_op(Instruction op);
   virtual bool begin_function(Instruction op);
-  virtual bool transform_function_op(Instruction op, uint32_t function_id);
+  virtual bool transform_function_op(Instruction op);
   virtual void end_function(uint32_t function_id);
   virtual void postprocess();
 
@@ -58,7 +57,8 @@ public:
   INLINE uint32_t get_id_bound() const;
   INLINE uint32_t allocate_id();
 
-  void add_name(uint32_t id, const std::string &name);
+  void set_name(uint32_t id, const std::string &name);
+  void set_member_name(uint32_t type_id, uint32_t member_index, const std::string &name);
 
   void delete_id(uint32_t id);
   void delete_struct_member(uint32_t id, uint32_t member_index);
@@ -67,10 +67,14 @@ public:
   INLINE bool is_deleted(uint32_t id) const;
   INLINE bool is_member_deleted(uint32_t id, uint32_t member) const;
 
+  INLINE void decorate(uint32_t id, spv::Decoration decoration);
+  INLINE void decorate(uint32_t id, spv::Decoration decoration, uint32_t value);
+
   uint32_t define_variable(const ShaderType *type, spv::StorageClass storage_class);
   uint32_t define_pointer_type(const ShaderType *type, spv::StorageClass storage_class);
   uint32_t define_type(const ShaderType *type);
   uint32_t define_int_constant(int32_t constant);
+  uint32_t define_null_constant(const ShaderType *type);
   uint32_t define_constant(const ShaderType *type, uint32_t constant);
 
   /**
@@ -117,12 +121,19 @@ protected:
   INLINE void add_instruction(spv::Op opcode, std::initializer_list<uint32_t> args);
   void add_instruction(spv::Op opcode, const uint32_t *args, uint16_t nargs);
 
+  // Functions for putting specific instructions in the functions block.
+  uint32_t op_load(uint32_t var_id, spv::MemoryAccessMask access = spv::MemoryAccessMaskNone);
+  uint32_t op_select(uint32_t cond, uint32_t obj1, uint32_t obj2);
+  uint32_t op_access_chain(uint32_t var_id, std::initializer_list<uint32_t>);
+  uint32_t op_composite_extract(uint32_t obj_id, std::initializer_list<uint32_t>);
+
   // The module is split into sections to make it easier to add instructions
   // to other sections while we are iterating.
   std::vector<uint32_t> _new_preamble;
   std::vector<uint32_t> _new_annotations;
   std::vector<uint32_t> _new_definitions;
   std::vector<uint32_t> _new_functions;
+  uint32_t _current_function_id = 0;
 
   // Keeps track of what has been defined and deleted during this pass.
   BitArray _defined;

+ 19 - 254
panda/src/shaderpipeline/spirVTransformer.cxx

@@ -17,14 +17,15 @@
  * Constructs an instruction writer to operate on the given instruction stream.
  */
 SpirVTransformer::
-SpirVTransformer(InstructionStream &stream) {
+SpirVTransformer(const InstructionStream &stream) {
   _db.modify_definition(stream.get_id_bound() - 1);
 
   uint32_t current_function_id = 0;
 
-  InstructionIterator begin = stream.begin();
+  InstructionIterator begin = ((InstructionStream &)stream).begin();
+  InstructionIterator end = ((InstructionStream &)stream).end();
   InstructionIterator it = begin;
-  while (it != stream.end()) {
+  while (it != end) {
     Instruction op = *it;
     if (op.opcode != spv::OpNop &&
         op.opcode != spv::OpCapability &&
@@ -46,10 +47,10 @@ SpirVTransformer(InstructionStream &stream) {
     _db.parse_instruction(op.opcode, op.args, op.nargs, current_function_id);
     ++it;
   }
-  _preamble = std::vector<uint32_t>(stream._words.data(), it._words);
+  _preamble = std::vector<uint32_t>((const uint32_t *)stream._words.data(), (const uint32_t *)it._words);
 
   begin = it;
-  while (it != stream.end()) {
+  while (it != end) {
     Instruction op = *it;
     if (!op.is_annotation()) {
       break;
@@ -61,7 +62,7 @@ SpirVTransformer(InstructionStream &stream) {
   _annotations = std::vector<uint32_t>(begin._words, it._words);
 
   begin = it;
-  while (it != stream.end()) {
+  while (it != end) {
     Instruction op = *it;
     if (op.opcode == spv::OpFunction) {
       break;
@@ -73,7 +74,7 @@ SpirVTransformer(InstructionStream &stream) {
   _definitions = std::vector<uint32_t>(begin._words, it._words);
 
   begin = it;
-  while (it != stream.end()) {
+  while (it != end) {
     Instruction op = *it;
     _db.parse_instruction(op.opcode, op.args, op.nargs, current_function_id);
     ++it;
@@ -82,6 +83,8 @@ SpirVTransformer(InstructionStream &stream) {
 }
 
 /**
+ * Runs the given transformation pass object (which can be used only once) on
+ * the module stored in the SpirVTransformer, and updates the database.
  */
 void SpirVTransformer::
 run(SpirVTransformPass &pass) {
@@ -277,25 +280,25 @@ assign_locations(pmap<uint32_t, int> remap) {
 }
 
 /**
- * Assign descriptor bindings for a descriptor set based on the given locations.
+ * Assign descriptor bindings for a descriptor set based on the given ids.
  * Assumes there are already binding and set decorations.
- * To create gaps in the descriptor set, entries in locations may be -1.
+ * To create gaps in the descriptor set, entries in ids may be 0.
  */
 void SpirVTransformer::
-bind_descriptor_set(uint32_t set, const vector_int &locations) {
+bind_descriptor_set(uint32_t set, const pvector<uint32_t> &ids) {
   InstructionIterator it(_annotations.data());
   InstructionIterator end(_annotations.data() + _annotations.size());
 
   while (it != end) {
     Instruction op = *it;
 
-    if (op.opcode == spv::OpDecorate && op.nargs >= 3) {
-      const Definition &def = _db.get_definition(op.args[0]);
-
-      auto lit = std::find(locations.begin(), locations.end(), def._location);
-      if (lit != locations.end() && def.has_location()) {
+    if (op.opcode == spv::OpDecorate && op.nargs >= 3 &&
+        (op.args[1] == spv::DecorationBinding ||
+         op.args[1] == spv::DecorationDescriptorSet)) {
+      auto iit = std::find(ids.begin(), ids.end(), op.args[0]);
+      if (iit != ids.end()) {
         if (op.args[1] == spv::DecorationBinding) {
-          op.args[2] = std::distance(locations.begin(), lit);
+          op.args[2] = std::distance(ids.begin(), iit);
         }
         else if (op.args[1] == spv::DecorationDescriptorSet) {
           op.args[2] = set;
@@ -306,241 +309,3 @@ bind_descriptor_set(uint32_t set, const vector_int &locations) {
     ++it;
   }
 }
-
-/**
- * Creates a new uniform block using the parameters specified by the given
- * locations and types.  The opposite of flatten_struct, if you will.
- */
-/*uint32_t SpirVTransformer::
-make_block(const ShaderType::Struct *block_type, const pvector<int> &member_locations,
-           spv::StorageClass storage_class, uint32_t binding, uint32_t set) {
-  nassertr(block_type->get_num_members() == member_locations.size(), false);
-
-  // Define block struct variable, which will implicitly define its type.
-  uint32_t block_var_id = define_variable(block_type, storage_class);
-  uint32_t block_type_id = _type_map[block_type];
-  nassertr(block_type_id != 0, 0);
-
-  // Collect type pointers that we have to create.
-  pvector<uint32_t> insert_pointer_types;
-
-  // Find the variables we should replace with members of this block by looking
-  // at the locations.  Collect a map of defined type pointers while we're at
-  // it, so we don't unnecessarily duplicate them.
-  pmap<uint32_t, uint32_t> member_indices;
-  pmap<uint32_t, uint32_t> pointer_type_map;
-
-  for (uint32_t id = 0; id < _defs.size(); ++id) {
-    Definition &def = _defs[id];
-    if (def.is_pointer_type()) {
-      if (!def.has_builtin() && def._storage_class == storage_class) {
-        // This is the storage class we need, store it in case we need it.
-        pointer_type_map[def._type_id] = id;
-      }
-    }
-    else if (def.is_variable() && def.has_location() &&
-             def._storage_class == spv::StorageClassUniformConstant) {
-
-      auto lit = std::find(member_locations.begin(), member_locations.end(), def._location);
-      if (lit != member_locations.end()) {
-        member_indices[id] = std::distance(member_locations.begin(), lit);
-      }
-    }
-  }
-
-  uint32_t num_members = member_locations.size();
-  uint32_t *allocation = (uint32_t *)alloca(num_members * sizeof(uint32_t) * 2);
-  memset(allocation, 0, num_members * sizeof(uint32_t) * 2);
-
-  uint32_t *member_type_ids = allocation;
-  uint32_t *member_constant_ids = allocation + num_members;
-
-  // Now add the decorations for the uniform block itself.
-  InstructionIterator it = _instructions.end_annotations();
-  it = _instructions.insert(it, spv::OpDecorate, {block_type_id, spv::DecorationBlock});
-  ++it;
-
-  if (storage_class != spv::StorageClassPushConstant) {
-    it = _instructions.insert(it, spv::OpDecorate, {block_var_id, spv::DecorationBinding, binding});
-    ++it;
-    it = _instructions.insert(it, spv::OpDecorate, {block_var_id, spv::DecorationDescriptorSet, set});
-    ++it;
-  }
-
-  it = _instructions.begin();
-  while (it != _instructions.end()) {
-    Instruction op = *it;
-
-    switch (op.opcode) {
-    case spv::OpName:
-      // Translate an OpName to an OpMemberName for vars that become struct
-      // members.  We could just strip them, but this is useful for debugging.
-      if (member_indices.count(op.args[0])) {
-        uint32_t member_index = member_indices[op.args[0]];
-
-        uint32_t nargs = op.nargs + 1;
-        uint32_t *args = (uint32_t *)alloca(nargs * sizeof(uint32_t));
-        args[0] = block_type_id;
-        args[1] = member_index;
-        memcpy(args + 2, op.args + 1, (op.nargs - 1) * sizeof(uint32_t));
-
-        it = _instructions.insert(it, spv::OpMemberName, args, nargs);
-        ++it;
-        it = _instructions.erase(it);
-        continue;
-      }
-      break;
-
-    case spv::OpMemberName:
-    case spv::OpDecorate:
-    case spv::OpDecorateId:
-    case spv::OpDecorateString:
-    case spv::OpMemberDecorate:
-    case spv::OpMemberDecorateString:
-      // Remove other annotations on the members.
-      if (op.nargs >= 1 && member_indices.count(op.args[0])) {
-        it = _instructions.erase(it);
-        continue;
-      }
-      break;
-
-    case spv::OpConstant:
-      // Store integer constants that are already defined in the file that may
-      // be useful for defining our struct indices.
-      if (op.args[2] < num_members &&
-          (_defs[op.args[0]]._type == ShaderType::int_type ||
-           _defs[op.args[0]]._type == ShaderType::uint_type)) {
-        member_constant_ids[op.args[2]] = op.args[1];
-      }
-      break;
-
-    case spv::OpVariable:
-      if (member_indices.count(op.args[1])) {
-        // Remove this variable.  We'll replace it with an access chain later.
-        uint32_t pointer_type_id = op.args[0];
-        uint32_t member_id = op.args[1];
-        uint32_t member_index = member_indices[member_id];
-
-        if (_defs[pointer_type_id]._storage_class != storage_class) {
-          // Get or create a type pointer with the correct storage class.
-          uint32_t type_id = _defs[pointer_type_id]._type_id;
-          auto tpi = pointer_type_map.find(type_id);
-          if (tpi != pointer_type_map.end()) {
-            pointer_type_id = tpi->second;
-          } else {
-            pointer_type_id = _instructions.allocate_id();
-            pointer_type_map[type_id] = pointer_type_id;
-            record_pointer_type(pointer_type_id, storage_class, type_id);
-
-            it = _instructions.insert(it, spv::OpTypePointer,
-              {pointer_type_id, (uint32_t)storage_class, type_id});
-            ++it;
-          }
-        }
-
-        member_type_ids[member_index] = pointer_type_id;
-
-        it = _instructions.erase(it);
-        continue;
-      }
-      break;
-
-    case spv::OpFunction:
-      // Before we get to the function section, make sure that all the
-      // remaining constants we need are defined.
-      for (uint32_t i =  0; i < num_members; ++i) {
-        uint32_t constant_id = member_constant_ids[i];
-        if (constant_id == 0) {
-          // Doesn't matter whether we pick uint or int, prefer whatever is
-          // already defined.
-          const ShaderType *type =
-            _type_map.count(ShaderType::uint_type)
-              ? ShaderType::uint_type
-              : ShaderType::int_type;
-          constant_id = r_define_constant(it, type, i);
-          member_constant_ids[i] = constant_id;
-        }
-      }
-      break;
-
-    case spv::OpAccessChain:
-    case spv::OpInBoundsAccessChain:
-      if (member_indices.count(op.args[2])) {
-        uint32_t member_index = member_indices[op.args[2]];
-        uint32_t constant_id = member_constant_ids[member_index];
-
-        // Get or create a type pointer with the correct storage class.
-        uint32_t type_id = _defs[op.args[0]]._type_id;
-        auto tpi = pointer_type_map.find(type_id);
-        uint32_t pointer_type_id;
-        if (tpi != pointer_type_map.end()) {
-          pointer_type_id = tpi->second;
-        } else {
-          pointer_type_id = _instructions.allocate_id();
-          pointer_type_map[type_id] = pointer_type_id;
-          record_pointer_type(pointer_type_id, storage_class, type_id);
-
-          // Can't create the type pointer immediately, since we're no longer
-          // in the type declaration block.  We'll add it at the end.
-          insert_pointer_types.push_back(pointer_type_id);
-        }
-        op.args[0] = pointer_type_id;
-
-        // Prepend our new block variable to the existing access chain.
-        op.args[2] = block_var_id;
-        it = _instructions.insert_arg(it, 3, constant_id);
-      }
-      break;
-
-    case spv::OpImageTexelPointer:
-    case spv::OpLoad:
-    case spv::OpCopyObject:
-    case spv::OpExpectKHR:
-      // Add access chains before all loads to access the right block member.
-      if (member_indices.count(op.args[2])) {
-        uint32_t member_index = member_indices[op.args[2]];
-        uint32_t type_id = member_type_ids[member_index];
-        uint32_t constant_id = member_constant_ids[member_index];
-        uint32_t chain_id = _instructions.allocate_id();
-
-        op.args[2] = chain_id;
-        it = _instructions.insert(it, spv::OpInBoundsAccessChain,
-          {type_id, chain_id, block_var_id, constant_id});
-        ++it;
-      }
-      break;
-
-    case spv::OpCopyMemory:
-    case spv::OpCopyMemorySized:
-      // Same as above, but these take the pointer in a different argument.
-      if (member_indices.count(op.args[1])) {
-        uint32_t member_index = member_indices[op.args[1]];
-        uint32_t type_id = member_type_ids[member_index];
-        uint32_t constant_id = member_constant_ids[member_index];
-        uint32_t chain_id = _instructions.allocate_id();
-
-        op.args[1] = chain_id;
-        it = _instructions.insert(it, spv::OpInBoundsAccessChain,
-          {type_id, chain_id, block_var_id, constant_id});
-        ++it;
-      }
-      break;
-
-    default:
-      break;
-    }
-
-    ++it;
-  }
-
-  it = _instructions.begin_functions();
-
-  // Insert all the type pointers for the access chains.
-  for (uint32_t id : insert_pointer_types) {
-    it = _instructions.insert(it, spv::OpTypePointer,
-      {id, (uint32_t)_defs[id]._storage_class, _defs[id]._type_id});
-    ++it;
-  }
-
-  return block_var_id;
-}*/

+ 2 - 5
panda/src/shaderpipeline/spirVTransformer.h

@@ -32,7 +32,7 @@ public:
   using InstructionStream = ShaderModuleSpirV::InstructionStream;
   using InstructionIterator = ShaderModuleSpirV::InstructionIterator;
 
-  SpirVTransformer(InstructionStream &stream);
+  SpirVTransformer(const InstructionStream &stream);
 
   void run(SpirVTransformPass &pass);
   INLINE void run(SpirVTransformPass &&pass);
@@ -44,10 +44,7 @@ public:
 
   void assign_locations(ShaderModule::Stage stage);
   void assign_locations(pmap<uint32_t, int> locations);
-  void bind_descriptor_set(uint32_t set, const vector_int &locations);
-
-  //uint32_t make_block(const ShaderType::Struct *block_type, const pvector<int> &locations,
-  //                    spv::StorageClass storage_class, uint32_t binding=0, uint32_t set=0);
+  void bind_descriptor_set(uint32_t set, const pvector<uint32_t> &ids);
 
 private:
   // Stores the module split into the different sections for easier

+ 4 - 3
tests/display/test_glsl_shader.py

@@ -99,14 +99,15 @@ def run_glsl_test(gsg, body, preamble="", inputs={}, version=420, exts=set(),
 
     use_compute = gsg.supports_compute_shaders and \
                   gsg.supports_buffer_texture and \
-                  gsg.has_extension('GL_ARB_shader_image_load_store')
-    if use_compute:
-        exts = exts | {'GL_ARB_compute_shader', 'GL_ARB_shader_image_load_store'}
+                  (gsg.supported_shader_capabilities & core.Shader.C_image_load_store) != 0
 
     missing_exts = sorted(ext for ext in exts if not gsg.has_extension(ext))
     if missing_exts:
         pytest.skip("missing extensions: " + ' '.join(missing_exts))
 
+    if use_compute:
+        exts = exts | {'GL_ARB_compute_shader', 'GL_ARB_shader_image_load_store'}
+
     extensions = ''
     for ext in exts:
         extensions += '#extension {ext} : require\n'.format(ext=ext)