Forráskód Böngészése

Updated spirv-tools.

Бранимир Караџић 1 napja
szülő
commit
e17b97c0fe
48 módosított fájl, 3186 hozzáadás és 2045 törlés
  1. 1 1
      3rdparty/spirv-tools/include/generated/build-version.inc
  2. 1503 1489
      3rdparty/spirv-tools/include/generated/core_tables_body.inc
  3. 3 0
      3rdparty/spirv-tools/include/generated/core_tables_header.inc
  4. 9 2
      3rdparty/spirv-tools/source/opcode.cpp
  5. 1 0
      3rdparty/spirv-tools/source/operand.cpp
  6. 7 37
      3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp
  7. 16 1
      3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.cpp
  8. 142 0
      3rdparty/spirv-tools/source/opt/const_folding_rules.cpp
  9. 39 1
      3rdparty/spirv-tools/source/opt/convert_to_half_pass.cpp
  10. 3 0
      3rdparty/spirv-tools/source/opt/convert_to_half_pass.h
  11. 5 5
      3rdparty/spirv-tools/source/opt/eliminate_dead_members_pass.cpp
  12. 28 25
      3rdparty/spirv-tools/source/opt/folding_rules.cpp
  13. 1 0
      3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp
  14. 1 0
      3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp
  15. 1 0
      3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp
  16. 2 1
      3rdparty/spirv-tools/source/opt/reflect.h
  17. 20 2
      3rdparty/spirv-tools/source/opt/type_manager.cpp
  18. 28 0
      3rdparty/spirv-tools/source/opt/types.cpp
  19. 29 0
      3rdparty/spirv-tools/source/opt/types.h
  20. 1 1
      3rdparty/spirv-tools/source/val/validate.cpp
  21. 3 0
      3rdparty/spirv-tools/source/val/validate.h
  22. 86 7
      3rdparty/spirv-tools/source/val/validate_annotation.cpp
  23. 2 2
      3rdparty/spirv-tools/source/val/validate_atomics.cpp
  24. 1 2
      3rdparty/spirv-tools/source/val/validate_barriers.cpp
  25. 19 0
      3rdparty/spirv-tools/source/val/validate_builtins.cpp
  26. 1 0
      3rdparty/spirv-tools/source/val/validate_capability.cpp
  27. 32 0
      3rdparty/spirv-tools/source/val/validate_cfg.cpp
  28. 15 14
      3rdparty/spirv-tools/source/val/validate_composites.cpp
  29. 28 3
      3rdparty/spirv-tools/source/val/validate_constants.cpp
  30. 6 2
      3rdparty/spirv-tools/source/val/validate_conversion.cpp
  31. 131 17
      3rdparty/spirv-tools/source/val/validate_decorations.cpp
  32. 17 28
      3rdparty/spirv-tools/source/val/validate_extensions.cpp
  33. 2 2
      3rdparty/spirv-tools/source/val/validate_function.cpp
  34. 229 0
      3rdparty/spirv-tools/source/val/validate_group.cpp
  35. 3 0
      3rdparty/spirv-tools/source/val/validate_id.cpp
  36. 220 211
      3rdparty/spirv-tools/source/val/validate_image.cpp
  37. 7 4
      3rdparty/spirv-tools/source/val/validate_interfaces.cpp
  38. 15 0
      3rdparty/spirv-tools/source/val/validate_invalid_type.cpp
  39. 8 0
      3rdparty/spirv-tools/source/val/validate_logical_pointers.cpp
  40. 4 0
      3rdparty/spirv-tools/source/val/validate_logicals.cpp
  41. 187 42
      3rdparty/spirv-tools/source/val/validate_memory.cpp
  42. 10 13
      3rdparty/spirv-tools/source/val/validate_ray_query.cpp
  43. 8 8
      3rdparty/spirv-tools/source/val/validate_ray_tracing.cpp
  44. 40 50
      3rdparty/spirv-tools/source/val/validate_ray_tracing_reorder.cpp
  45. 1 2
      3rdparty/spirv-tools/source/val/validate_tensor_layout.cpp
  46. 87 51
      3rdparty/spirv-tools/source/val/validate_type.cpp
  47. 155 21
      3rdparty/spirv-tools/source/val/validation_state.cpp
  48. 29 1
      3rdparty/spirv-tools/source/val/validation_state.h

+ 1 - 1
3rdparty/spirv-tools/include/generated/build-version.inc

@@ -1 +1 @@
-"v2026.1-dev", "SPIRV-Tools v2026.1-dev v2025.5-37-g6d9a94ac"
+"v2026.2-dev", "SPIRV-Tools v2026.2-dev v2026.1-9-g994cf90e"

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 1503 - 1489
3rdparty/spirv-tools/include/generated/core_tables_body.inc


+ 3 - 0
3rdparty/spirv-tools/include/generated/core_tables_header.inc

@@ -68,6 +68,7 @@ enum Extension : uint32_t {
   kSPV_ARM_tensors,
   kSPV_EXT_arithmetic_fence,
   kSPV_EXT_demote_to_helper_invocation,
+  kSPV_EXT_descriptor_heap,
   kSPV_EXT_descriptor_indexing,
   kSPV_EXT_float8,
   kSPV_EXT_fragment_fully_covered,
@@ -87,6 +88,7 @@ enum Extension : uint32_t {
   kSPV_EXT_shader_image_int64,
   kSPV_EXT_shader_invocation_reorder,
   kSPV_EXT_shader_stencil_export,
+  kSPV_EXT_shader_subgroup_partitioned,
   kSPV_EXT_shader_tile_image,
   kSPV_EXT_shader_viewport_index_layer,
   kSPV_GOOGLE_decorate_string,
@@ -196,6 +198,7 @@ enum Extension : uint32_t {
   kSPV_NV_geometry_shader_passthrough,
   kSPV_NV_linear_swept_spheres,
   kSPV_NV_mesh_shader,
+  kSPV_NV_push_constant_bank,
   kSPV_NV_raw_access_chains,
   kSPV_NV_ray_tracing,
   kSPV_NV_ray_tracing_motion_blur,

+ 9 - 2
3rdparty/spirv-tools/source/opcode.cpp

@@ -153,6 +153,7 @@ int32_t spvOpcodeIsConstant(const spv::Op opcode) {
     case spv::Op::OpSpecConstantArchitectureINTEL:
     case spv::Op::OpSpecConstantTargetINTEL:
     case spv::Op::OpSpecConstantCapabilitiesINTEL:
+    case spv::Op::OpConstantSizeOfEXT:
       return true;
     default:
       return false;
@@ -183,7 +184,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) {
     case spv::Op::OpTypeRuntimeArray:
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
       return true;
     default:
       return false;
@@ -198,8 +199,10 @@ bool spvOpcodeReturnsLogicalVariablePointer(const spv::Op opcode) {
     case spv::Op::OpInBoundsAccessChain:
     case spv::Op::OpUntypedAccessChainKHR:
     case spv::Op::OpUntypedInBoundsAccessChainKHR:
+    case spv::Op::OpBufferPointerEXT:
     case spv::Op::OpFunctionParameter:
     case spv::Op::OpImageTexelPointer:
+    case spv::Op::OpUntypedImageTexelPointerEXT:
     case spv::Op::OpCopyObject:
     case spv::Op::OpAllocateNodePayloadsAMDX:
     case spv::Op::OpSelect:
@@ -224,8 +227,10 @@ int32_t spvOpcodeReturnsLogicalPointer(const spv::Op opcode) {
     case spv::Op::OpInBoundsAccessChain:
     case spv::Op::OpUntypedAccessChainKHR:
     case spv::Op::OpUntypedInBoundsAccessChainKHR:
+    case spv::Op::OpBufferPointerEXT:
     case spv::Op::OpFunctionParameter:
     case spv::Op::OpImageTexelPointer:
+    case spv::Op::OpUntypedImageTexelPointerEXT:
     case spv::Op::OpCopyObject:
     case spv::Op::OpRawAccessChainNV:
     case spv::Op::OpAllocateNodePayloadsAMDX:
@@ -262,7 +267,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
     case spv::Op::OpTypeAccelerationStructureNV:
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
     // case spv::Op::OpTypeAccelerationStructureKHR: covered by
     // spv::Op::OpTypeAccelerationStructureNV
     case spv::Op::OpTypeRayQueryKHR:
@@ -275,6 +280,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
     case spv::Op::OpTypeTensorARM:
     case spv::Op::OpTypeTaskSequenceINTEL:
     case spv::Op::OpTypeGraphARM:
+    case spv::Op::OpTypeBufferEXT:
       return true;
     default:
       // In particular, OpTypeForwardPointer does not generate a type,
@@ -290,6 +296,7 @@ bool spvOpcodeIsDecoration(const spv::Op opcode) {
     case spv::Op::OpDecorate:
     case spv::Op::OpDecorateId:
     case spv::Op::OpMemberDecorate:
+    case spv::Op::OpMemberDecorateIdEXT:
     case spv::Op::OpGroupDecorate:
     case spv::Op::OpGroupMemberDecorate:
     case spv::Op::OpDecorateStringGOOGLE:

+ 1 - 0
3rdparty/spirv-tools/source/operand.cpp

@@ -523,6 +523,7 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
     case spv::Op::OpSelectionMerge:
     case spv::Op::OpDecorate:
     case spv::Op::OpMemberDecorate:
+    case spv::Op::OpMemberDecorateIdEXT:
     case spv::Op::OpDecorateId:
     case spv::Op::OpDecorateStringGOOGLE:
     case spv::Op::OpMemberDecorateStringGOOGLE:

+ 7 - 37
3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp

@@ -45,9 +45,7 @@ constexpr uint32_t kExtInstOpInIdx = 1;
 constexpr uint32_t kInterpolantInIdx = 2;
 constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0;
 constexpr uint32_t kDebugDeclareVariableInIdx = 3;
-constexpr uint32_t kDebugValueLocalVariableInIdx = 2;
 constexpr uint32_t kDebugValueValueInIdx = 3;
-constexpr uint32_t kDebugValueExpressionInIdx = 4;
 
 // Sorting functor to present annotation instructions in an easy-to-process
 // order. The functor orders by opcode first and falls back on unique id
@@ -309,8 +307,8 @@ Pass::Status AggressiveDCEPass::ProcessDebugInformation(
         // DebugDeclare Variable is not live. Find the value that was being
         // stored to this variable. If it's live then create a new DebugValue
         // with this value. Otherwise let it die in peace.
-        get_def_use_mgr()->ForEachUser(var_id, [this, var_id,
-                                                inst](Instruction* user) {
+        get_def_use_mgr()->ForEachUser(var_id, [this,
+                                                var_id](Instruction* user) {
           if (user->opcode() == spv::Op::OpStore) {
             uint32_t stored_value_id = 0;
             const uint32_t kStoreValueInIdx = 1;
@@ -320,13 +318,13 @@ Pass::Status AggressiveDCEPass::ProcessDebugInformation(
             }
 
             // value being stored is still live
-            Instruction* next_inst = inst->NextNode();
+            Instruction* next_inst = user->NextNode();
             bool added =
                 context()->get_debug_info_mgr()->AddDebugValueForVariable(
-                    user, var_id, stored_value_id, inst);
+                    user, var_id, stored_value_id, user);
             if (added && next_inst) {
               auto new_debug_value = next_inst->PreviousNode();
-              live_insts_.Set(new_debug_value->unique_id());
+              AddToWorklist(new_debug_value);
             }
           }
           return true;
@@ -344,42 +342,13 @@ Pass::Status AggressiveDCEPass::ProcessDebugInformation(
 
         // Value operand of DebugValue is not live
         // Set Value to Undef of appropriate type
-        live_insts_.Set(inst->unique_id());
-
         uint32_t type_id = def->type_id();
-        auto type_def = get_def_use_mgr()->GetDef(type_id);
-        AddToWorklist(type_def);
-
         uint32_t undef_id = Type2Undef(type_id);
         if (undef_id == 0) return false;
 
-        auto undef_inst = get_def_use_mgr()->GetDef(undef_id);
-        live_insts_.Set(undef_inst->unique_id());
         inst->SetInOperand(var_operand_idx, {undef_id});
         context()->get_def_use_mgr()->AnalyzeInstUse(inst);
-
-        id = inst->GetSingleWordInOperand(kDebugValueLocalVariableInIdx);
-        auto localVar = get_def_use_mgr()->GetDef(id);
-        AddToWorklist(localVar);
-
-        uint32_t expr_idx = kDebugValueExpressionInIdx;
-        id = inst->GetSingleWordInOperand(expr_idx);
-        auto expression = get_def_use_mgr()->GetDef(id);
-        AddToWorklist(expression);
-
-        for (uint32_t i = expr_idx + 1; i < inst->NumInOperands(); ++i) {
-          id = inst->GetSingleWordInOperand(i);
-          auto index_def = get_def_use_mgr()->GetDef(id);
-          if (index_def) {
-            AddToWorklist(index_def);
-          }
-        }
-
-        for (auto& line_inst : inst->dbg_line_insts()) {
-          if (line_inst.IsDebugLineInst()) {
-            AddToWorklist(&line_inst);
-          }
-        }
+        AddToWorklist(inst);
       }
       return true;
     });
@@ -1151,6 +1120,7 @@ void AggressiveDCEPass::InitExtensions() {
       "SPV_NV_shader_subgroup_partitioned",
       "SPV_EXT_demote_to_helper_invocation",
       "SPV_EXT_descriptor_indexing",
+      "SPV_EXT_descriptor_heap",
       "SPV_NV_fragment_shader_barycentric",
       "SPV_NV_compute_shader_derivatives",
       "SPV_NV_shader_image_footprint",

+ 16 - 1
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.cpp

@@ -223,7 +223,7 @@ spv::Id CanonicalizeIdsPass::HashTypeAndConst(spv::Id const id) const {
     // remapper. Support should be added as necessary.
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
     case spv::Op::OpTypeHitObjectNV:
     case spv::Op::OpTypeUntypedPointerKHR:
     case spv::Op::OpTypeNodePayloadArrayAMDX:
@@ -428,6 +428,21 @@ bool CanonicalizeIdsPass::ApplyMap() {
             }
           }
         }
+        const auto& debug_scope = inst->GetDebugScope();
+        if (debug_scope.GetLexicalScope() != kNoDebugScope) {
+          uint32_t old_scope = debug_scope.GetLexicalScope();
+          uint32_t new_scope = GetNewId(old_scope);
+          uint32_t old_inlined_at = debug_scope.GetInlinedAt();
+          uint32_t new_inlined_at = old_inlined_at != kNoInlinedAt
+                                        ? GetNewId(old_inlined_at)
+                                        : old_inlined_at;
+          if ((new_scope != unused_ && new_scope != old_scope) ||
+              (new_inlined_at != unused_ && new_inlined_at != old_inlined_at)) {
+            DebugScope new_debug_scope(new_scope, new_inlined_at);
+            inst->SetDebugScope(new_debug_scope);
+            modified = true;
+          }
+        }
       },
       true);
 

+ 142 - 0
3rdparty/spirv-tools/source/opt/const_folding_rules.cpp

@@ -14,6 +14,8 @@
 
 #include "source/opt/const_folding_rules.h"
 
+#include <optional>
+
 #include "source/opt/ir_context.h"
 
 namespace spvtools {
@@ -988,6 +990,32 @@ ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
 
+// x - x = 0
+ConstantFoldingRule FoldRedundantSub() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&)
+             -> const analysis::Constant* {
+    assert(inst->opcode() == spv::Op::OpFSub ||
+           inst->opcode() == spv::Op::OpISub);
+
+    if (inst->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1)) {
+      bool use_float = inst->opcode() == spv::Op::OpFSub;
+      if (use_float && !inst->IsFloatingPointFoldingAllowed()) {
+        return nullptr;
+      }
+      analysis::TypeManager* type_mgr = context->get_type_mgr();
+      const analysis::Type* type = type_mgr->GetType(inst->type_id());
+      if (type->IsCooperativeMatrix()) {
+        return nullptr;
+      }
+      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+      uint32_t null_id = const_mgr->GetNullConstId(type);
+      return const_mgr->FindDeclaredConstant(null_id);
+    }
+    return nullptr;
+  };
+}
+
 // Returns the constant that results from evaluating |numerator| / 0.0.  Returns
 // |nullptr| if the result could not be evaluated.
 const analysis::Constant* FoldFPScalarDivideByZero(
@@ -1047,6 +1075,107 @@ const analysis::Constant* FoldScalarFPDivide(
 // Returns the constant folding rule to fold |OpFDiv| with two constants.
 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
 
+// Get a singular uniform value, which is repeated when the |type| is a vector.
+const analysis::Constant* GetConstantUniformValue(
+    analysis::ConstantManager* const_mgr, const analysis::Type* type,
+    std::optional<double> f = {}, std::optional<uint64_t> i = {}) {
+  const analysis::Constant* uniform = nullptr;
+  bool is_vector = false;
+  const analysis::Type* base_type = type;
+
+  if (base_type->AsVector()) {
+    is_vector = true;
+    base_type = base_type->AsVector()->element_type();
+  }
+
+  if (f && base_type->AsFloat()) {
+    if (base_type->AsFloat()->width() == 32) {
+      uniform = const_mgr->GetConstant(
+          base_type, utils::FloatProxy<float>((float)f.value()).GetWords());
+    } else if (base_type->AsFloat()->width() == 64) {
+      uniform = const_mgr->GetConstant(
+          base_type, utils::FloatProxy<double>(f.value()).GetWords());
+    }
+  } else if (i && base_type->AsInteger()) {
+    uniform =
+        const_mgr->GenerateIntegerConstant(base_type->AsInteger(), i.value());
+  }
+
+  if (!uniform) {
+    return nullptr;
+  }
+
+  if (is_vector) {
+    Instruction* uniform_inst = const_mgr->GetDefiningInstruction(uniform);
+    if (!uniform_inst) return nullptr;
+
+    uint32_t uniform_id = uniform_inst->result_id();
+    uniform =
+        const_mgr->GetConstant(type, std::vector<uint32_t>(4, uniform_id));
+  }
+
+  return uniform;
+}
+
+//  x /  x =  1
+// -x /  x = -1
+//  x / -x = -1
+ConstantFoldingRule FoldRedundantDiv() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    assert(inst->opcode() == spv::Op::OpFDiv ||
+           inst->opcode() == spv::Op::OpSDiv ||
+           inst->opcode() == spv::Op::OpUDiv);
+
+    if (constants[0] || constants[1]) {
+      return nullptr;
+    }
+
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    const analysis::Type* type = type_mgr->GetType(inst->type_id());
+
+    if (type->IsCooperativeMatrix()) {
+      return nullptr;
+    }
+
+    bool use_float = inst->opcode() == spv::Op::OpFDiv;
+    if (use_float && !inst->IsFloatingPointFoldingAllowed()) {
+      return nullptr;
+    }
+
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+
+    if (inst->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1)) {
+      return GetConstantUniformValue(const_mgr, type, 1.0, 1);
+    }
+
+    if (inst->opcode() == spv::Op::OpUDiv) {
+      return nullptr;
+    }
+
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+
+    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+    if ((lhs->opcode() == spv::Op::OpSNegate ||
+         lhs->opcode() == spv::Op::OpFNegate) &&
+        lhs->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1) &&
+        (!use_float || lhs->IsFloatingPointFoldingAllowed())) {
+      return GetConstantUniformValue(const_mgr, type, -1.0, UINT64_MAX);
+    }
+
+    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
+    if ((rhs->opcode() == spv::Op::OpSNegate ||
+         rhs->opcode() == spv::Op::OpFNegate) &&
+        rhs->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(0) &&
+        (!use_float || rhs->IsFloatingPointFoldingAllowed())) {
+      return GetConstantUniformValue(const_mgr, type, -1.0, UINT64_MAX);
+    }
+
+    return nullptr;
+  };
+}
+
 bool CompareFloatingPoint(bool op_result, bool op_unordered,
                           bool need_ordered) {
   if (need_ordered) {
@@ -1945,9 +2074,14 @@ void ConstantFoldingRules::AddFoldingRules() {
 
   rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
   rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
+
   rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
+  rules_[spv::Op::OpFDiv].push_back(FoldRedundantDiv());
+
   rules_[spv::Op::OpFMul].push_back(FoldFMul());
+
   rules_[spv::Op::OpFSub].push_back(FoldFSub());
+  rules_[spv::Op::OpFSub].push_back(FoldRedundantSub());
 
   rules_[spv::Op::OpSelect].push_back(FoldInvariantSelect());
 
@@ -2005,21 +2139,29 @@ void ConstantFoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpIAdd].push_back(
       FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
           [](uint64_t a, uint64_t b) { return a + b; })));
+
   rules_[spv::Op::OpISub].push_back(
       FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
           [](uint64_t a, uint64_t b) { return a - b; })));
+  rules_[spv::Op::OpISub].push_back(FoldRedundantSub());
+
   rules_[spv::Op::OpIMul].push_back(
       FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
           [](uint64_t a, uint64_t b) { return a * b; })));
+
   rules_[spv::Op::OpUDiv].push_back(
       FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
           [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
+  rules_[spv::Op::OpUDiv].push_back(FoldRedundantDiv());
+
   rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
       FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
         return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
                                                static_cast<int64_t>(b))
                        : 0);
       })));
+  rules_[spv::Op::OpSDiv].push_back(FoldRedundantDiv());
+
   rules_[spv::Op::OpUMod].push_back(
       FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
           [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));

+ 39 - 1
3rdparty/spirv-tools/source/opt/convert_to_half_pass.cpp

@@ -22,6 +22,7 @@ namespace spvtools {
 namespace opt {
 namespace {
 // Indices of operands in SPIR-V instructions
+constexpr int kImageSampleCoordinateIdInIdx = 1;
 constexpr int kImageSampleDrefIdInIdx = 2;
 }  // namespace
 
@@ -325,7 +326,7 @@ bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
 
 bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
   bool modified = false;
-  // If image reference, only need to convert dref args back to float32
+  // If image reference, some operands aren't allowed to be non-32 bit floats
   if (dref_image_ops_.count(inst->opcode()) != 0) {
     uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
     if (converted_ids_.count(dref_id) > 0) {
@@ -338,6 +339,19 @@ bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
       modified = true;
     }
   }
+  if (coordinate_image_ops_.count(inst->opcode()) != 0) {
+    uint32_t coordinate_id =
+        inst->GetSingleWordInOperand(kImageSampleCoordinateIdInIdx);
+    if (converted_ids_.count(coordinate_id) > 0) {
+      GenConvert(&coordinate_id, 32, inst);
+      if (status_ == Status::Failure) {
+        return false;
+      }
+      inst->SetInOperand(kImageSampleCoordinateIdInIdx, {coordinate_id});
+      get_def_use_mgr()->AnalyzeInstUse(inst);
+      modified = true;
+    }
+  }
   return modified;
 }
 
@@ -591,6 +605,30 @@ void ConvertToHalfPass::Initialize() {
       spv::Op::OpImageSparseSampleProjDrefExplicitLod,
       spv::Op::OpImageSparseDrefGather,
   };
+  coordinate_image_ops_ = {
+      spv::Op::OpImageSampleImplicitLod,
+      spv::Op::OpImageSampleExplicitLod,
+      spv::Op::OpImageSampleDrefImplicitLod,
+      spv::Op::OpImageSampleDrefExplicitLod,
+      spv::Op::OpImageSampleProjImplicitLod,
+      spv::Op::OpImageSampleProjExplicitLod,
+      spv::Op::OpImageSampleProjDrefImplicitLod,
+      spv::Op::OpImageSampleProjDrefExplicitLod,
+      spv::Op::OpImageFetch,
+      spv::Op::OpImageGather,
+      spv::Op::OpImageDrefGather,
+      spv::Op::OpImageRead,
+      spv::Op::OpImageWrite,
+      spv::Op::OpImageQueryLod,
+      spv::Op::OpImageSparseSampleImplicitLod,
+      spv::Op::OpImageSparseSampleExplicitLod,
+      spv::Op::OpImageSparseSampleDrefImplicitLod,
+      spv::Op::OpImageSparseSampleDrefExplicitLod,
+      spv::Op::OpImageSparseFetch,
+      spv::Op::OpImageSparseGather,
+      spv::Op::OpImageSparseDrefGather,
+      spv::Op::OpImageSparseRead,
+  };
   closure_ops_ = {
       spv::Op::OpVectorExtractDynamic,
       spv::Op::OpVectorInsertDynamic,

+ 3 - 0
3rdparty/spirv-tools/source/opt/convert_to_half_pass.h

@@ -145,6 +145,9 @@ class ConvertToHalfPass : public Pass {
   // Set of only dref sample operations
   std::unordered_set<spv::Op, hasher> dref_image_ops_;
 
+  // Set of only sample operations that have a Coordinate operand
+  std::unordered_set<spv::Op, hasher> coordinate_image_ops_;
+
   // Set of operations that can be marked as relaxed
   std::unordered_set<spv::Op, hasher> closure_ops_;
 

+ 5 - 5
3rdparty/spirv-tools/source/opt/eliminate_dead_members_pass.cpp

@@ -207,7 +207,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract(
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeCooperativeMatrixNV:
       case spv::Op::OpTypeCooperativeMatrixKHR:
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
         type_id = type_inst->GetSingleWordInOperand(0);
         break;
       default:
@@ -256,7 +256,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain(
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeCooperativeMatrixNV:
       case spv::Op::OpTypeCooperativeMatrixKHR:
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
         type_id = type_inst->GetSingleWordInOperand(0);
         break;
       default:
@@ -518,7 +518,7 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) {
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeCooperativeMatrixNV:
       case spv::Op::OpTypeCooperativeMatrixKHR:
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
         new_operands.emplace_back(inst->GetInOperand(i));
         type_id = type_inst->GetSingleWordInOperand(0);
         break;
@@ -594,7 +594,7 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeCooperativeMatrixNV:
       case spv::Op::OpTypeCooperativeMatrixKHR:
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
         type_id = type_inst->GetSingleWordInOperand(0);
         break;
       default:
@@ -658,7 +658,7 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) {
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeCooperativeMatrixNV:
       case spv::Op::OpTypeCooperativeMatrixKHR:
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
         type_id = type_inst->GetSingleWordInOperand(0);
         break;
       default:

+ 28 - 25
3rdparty/spirv-tools/source/opt/folding_rules.cpp

@@ -116,12 +116,6 @@ bool IsValidResult(T val) {
   }
 }
 
-// Returns true if `type` is a cooperative matrix.
-bool IsCooperativeMatrix(const analysis::Type* type) {
-  return type->kind() == analysis::Type::kCooperativeMatrixKHR ||
-         type->kind() == analysis::Type::kCooperativeMatrixNV;
-}
-
 const analysis::Constant* ConstInput(
     const std::vector<const analysis::Constant*>& constants) {
   return constants[0] ? constants[0] : constants[1];
@@ -369,7 +363,7 @@ FoldingRule ReciprocalFDiv() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -455,7 +449,7 @@ FoldingRule MergeNegateMulDivArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -522,7 +516,7 @@ FoldingRule MergeNegateAddSubArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -767,7 +761,7 @@ FoldingRule MergeMulMulArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -826,7 +820,7 @@ FoldingRule MergeMulDivArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -904,7 +898,7 @@ FoldingRule MergeMulNegateArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -941,14 +935,18 @@ static bool IsFoldableNegation(const Instruction* inst) {
            inst->IsFloatingPointFoldingAllowed()));
 }
 
-// Merges multiplies of two negations.
+// Merges multiplies / divisions of two negations.
 // Cases:
 // (-x) * (-y) = x * y
-FoldingRule MergeMulDoubleNegative() {
+// (-x) / (-y) = x / y
+FoldingRule MergeDivMulDoubleNegative() {
   return [](IRContext* context, Instruction* inst,
             const std::vector<const analysis::Constant*>&) {
     assert(inst->opcode() == spv::Op::OpFMul ||
-           inst->opcode() == spv::Op::OpIMul);
+           inst->opcode() == spv::Op::OpVectorTimesScalar ||
+           inst->opcode() == spv::Op::OpFDiv ||
+           inst->opcode() == spv::Op::OpIMul ||
+           inst->opcode() == spv::Op::OpSDiv);
 
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
@@ -985,7 +983,7 @@ FoldingRule MergeDivDivArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1063,7 +1061,7 @@ FoldingRule MergeDivMulArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1222,7 +1220,7 @@ FoldingRule MergeSubNegateArithmetic() {
       return true;
     }
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1254,7 +1252,7 @@ FoldingRule MergeAddAddArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1307,7 +1305,7 @@ FoldingRule MergeAddSubArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1372,7 +1370,7 @@ FoldingRule MergeSubAddArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1443,7 +1441,7 @@ FoldingRule MergeSubSubArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -1541,7 +1539,7 @@ FoldingRule MergeGenericAddSubArithmetic() {
     const analysis::Type* type =
         context->get_type_mgr()->GetType(inst->type_id());
 
-    if (IsCooperativeMatrix(type)) {
+    if (type->IsCooperativeMatrix()) {
       return false;
     }
 
@@ -3739,6 +3737,7 @@ void FoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
   rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
   rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
+  rules_[spv::Op::OpFDiv].push_back(MergeDivMulDoubleNegative());
 
   rules_[spv::Op::OpFMod].push_back(RedundantFMod());
 
@@ -3746,7 +3745,9 @@ void FoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
   rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
   rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
-  rules_[spv::Op::OpFMul].push_back(MergeMulDoubleNegative());
+  rules_[spv::Op::OpFMul].push_back(MergeDivMulDoubleNegative());
+
+  rules_[spv::Op::OpVectorTimesScalar].push_back(MergeDivMulDoubleNegative());
 
   rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
   rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
@@ -3764,10 +3765,12 @@ void FoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
   rules_[spv::Op::OpIAdd].push_back(FactorAddSubMuls());
 
+  rules_[spv::Op::OpSDiv].push_back(MergeDivMulDoubleNegative());
+
   rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
   rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
   rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
-  rules_[spv::Op::OpIMul].push_back(MergeMulDoubleNegative());
+  rules_[spv::Op::OpIMul].push_back(MergeDivMulDoubleNegative());
 
   rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
   rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());

+ 1 - 0
3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp

@@ -436,6 +436,7 @@ void LocalAccessChainConvertPass::InitExtensions() {
       "SPV_NV_shader_subgroup_partitioned",
       "SPV_EXT_demote_to_helper_invocation",
       "SPV_EXT_descriptor_indexing",
+      "SPV_EXT_descriptor_heap",
       "SPV_NV_fragment_shader_barycentric",
       "SPV_NV_compute_shader_derivatives",
       "SPV_NV_shader_image_footprint",

+ 1 - 0
3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp

@@ -272,6 +272,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
       "SPV_NV_shader_subgroup_partitioned",
       "SPV_EXT_demote_to_helper_invocation",
       "SPV_EXT_descriptor_indexing",
+      "SPV_EXT_descriptor_heap",
       "SPV_NV_fragment_shader_barycentric",
       "SPV_NV_compute_shader_derivatives",
       "SPV_NV_shader_image_footprint",

+ 1 - 0
3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp

@@ -122,6 +122,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
       "SPV_GOOGLE_hlsl_functionality1",
       "SPV_NV_shader_subgroup_partitioned",
       "SPV_EXT_descriptor_indexing",
+      "SPV_EXT_descriptor_heap",
       "SPV_NV_fragment_shader_barycentric",
       "SPV_NV_compute_shader_derivatives",
       "SPV_NV_shader_image_footprint",

+ 2 - 1
3rdparty/spirv-tools/source/opt/reflect.h

@@ -44,7 +44,8 @@ inline bool IsAnnotationInst(spv::Op opcode) {
           opcode <= spv::Op::OpGroupMemberDecorate) ||
          opcode == spv::Op::OpDecorateId ||
          opcode == spv::Op::OpDecorateStringGOOGLE ||
-         opcode == spv::Op::OpMemberDecorateStringGOOGLE;
+         opcode == spv::Op::OpMemberDecorateStringGOOGLE ||
+         opcode == spv::Op::OpMemberDecorateIdEXT;
 }
 inline bool IsTypeInst(spv::Op opcode) {
   return spvOpcodeGeneratesType(opcode) ||

+ 20 - 2
3rdparty/spirv-tools/source/opt/type_manager.cpp

@@ -490,7 +490,7 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
         return 0;
       }
       typeInst = MakeUnique<Instruction>(
-          context(), spv::Op::OpTypeCooperativeVectorNV, 0, id,
+          context(), spv::Op::OpTypeVectorIdEXT, 0, id,
           std::initializer_list<Operand>{
               {SPV_OPERAND_TYPE_ID, {component_type}},
               {SPV_OPERAND_TYPE_ID, {coop_vec->components()}}});
@@ -539,6 +539,14 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
                                          id, ops);
       break;
     }
+    case Type::kBufferEXT: {
+      typeInst = MakeUnique<Instruction>(
+          context(), spv::Op::OpTypeBufferEXT, 0, id,
+          std::initializer_list<Operand>{
+              {SPV_OPERAND_TYPE_STORAGE_CLASS,
+               {static_cast<uint32_t>(type->AsBufferEXT()->storage_class())}}});
+      break;
+    }
     default:
       assert(false && "Unexpected type");
       break;
@@ -816,6 +824,11 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
       rebuilt_ty = MakeUnique<GraphARM>(graph_type->num_inputs(), io_types);
       break;
     }
+    case Type::kBufferEXT: {
+      const BufferEXT* buffer_type = type.AsBufferEXT();
+      rebuilt_ty = MakeUnique<BufferEXT>(buffer_type->storage_class());
+      break;
+    }
     default:
       assert(false && "Unhandled type");
       return nullptr;
@@ -1074,7 +1087,7 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
           inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
           inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
       break;
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
       type = new CooperativeVectorNV(GetType(inst.GetSingleWordInOperand(0)),
                                      inst.GetSingleWordInOperand(1));
       break;
@@ -1126,6 +1139,11 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
       type = new GraphARM(inst.GetSingleWordInOperand(0), io_types);
       break;
     }
+    case spv::Op::OpTypeBufferEXT: {
+      type = new BufferEXT(
+          static_cast<spv::StorageClass>(inst.GetSingleWordInOperand(0)));
+      break;
+    }
     default:
       assert(false && "Type not handled by the type manager.");
       break;

+ 28 - 0
3rdparty/spirv-tools/source/opt/types.cpp

@@ -138,6 +138,7 @@ std::unique_ptr<Type> Type::Clone() const {
     DeclareKindCase(HitObjectEXT);
     DeclareKindCase(TensorARM);
     DeclareKindCase(GraphARM);
+    DeclareKindCase(BufferEXT);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -193,6 +194,7 @@ bool Type::operator==(const Type& other) const {
     DeclareKindCase(TensorViewNV);
     DeclareKindCase(TensorARM);
     DeclareKindCase(GraphARM);
+    DeclareKindCase(BufferEXT);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -256,6 +258,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
     DeclareKindCase(TensorViewNV);
     DeclareKindCase(TensorARM);
     DeclareKindCase(GraphARM);
+    DeclareKindCase(BufferEXT);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -993,6 +996,31 @@ bool GraphARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
   return true;
 }
 
+BufferEXT::BufferEXT(spv::StorageClass storage_class)
+    : Type(kBufferEXT), storage_class_(storage_class) {}
+
+std::string BufferEXT::str() const {
+  std::ostringstream oss;
+  oss << "buffer<" << static_cast<uint32_t>(storage_class_) << ">";
+  return oss.str();
+}
+
+size_t BufferEXT::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
+  hash = hash_combine(hash, static_cast<uint32_t>(storage_class_));
+  return hash;
+}
+
+bool BufferEXT::IsSameImpl(const Type* that, IsSameCache*) const {
+  const BufferEXT* og = that->AsBufferEXT();
+  if (!og) {
+    return false;
+  }
+  if (storage_class_ != og->storage_class_) {
+    return false;
+  }
+  return true;
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools

+ 29 - 0
3rdparty/spirv-tools/source/opt/types.h

@@ -72,6 +72,7 @@ class TensorLayoutNV;
 class TensorViewNV;
 class TensorARM;
 class GraphARM;
+class BufferEXT;
 
 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
 // which is used as a way to probe the actual <subclass>.
@@ -120,6 +121,7 @@ class Type {
     kTensorViewNV,
     kTensorARM,
     kGraphARM,
+    kBufferEXT,
     kLast
   };
 
@@ -142,6 +144,12 @@ class Type {
     return IsSameImpl(that, &seen);
   }
 
+  // Returns true if this is a cooperative matrix.
+  bool IsCooperativeMatrix() const {
+    return kind() == analysis::Type::kCooperativeMatrixKHR ||
+           kind() == analysis::Type::kCooperativeMatrixNV;
+  }
+
   // Returns true if this type is exactly the same as |that| type, including
   // decorations.  |seen| is the set of |Pointer*| pair that are currently being
   // compared in a parent call to |IsSameImpl|.
@@ -229,6 +237,7 @@ class Type {
   DeclareCastMethod(TensorViewNV)
   DeclareCastMethod(TensorARM)
   DeclareCastMethod(GraphARM)
+  DeclareCastMethod(BufferEXT)
 #undef DeclareCastMethod
 
 protected:
@@ -833,6 +842,26 @@ class GraphARM : public Type {
   const std::vector<const Type*> io_types_;
 };
 
+class BufferEXT : public Type {
+ public:
+  BufferEXT(spv::StorageClass storage_class_);
+  BufferEXT(const BufferEXT&) = default;
+
+  std::string str() const override;
+
+  BufferEXT* AsBufferEXT() override { return this; }
+  const BufferEXT* AsBufferEXT() const override { return this; }
+
+  spv::StorageClass storage_class() const { return storage_class_; }
+
+  size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+ private:
+  bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+  const spv::StorageClass storage_class_;
+};
+
 #define DefineParameterlessType(type, name)                                \
   class type : public Type {                                               \
    public:                                                                 \

+ 1 - 1
3rdparty/spirv-tools/source/val/validate.cpp

@@ -390,7 +390,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
     if (auto error = AtomicsPass(*vstate, &instruction)) return error;
     if (auto error = PrimitivesPass(*vstate, &instruction)) return error;
     if (auto error = BarriersPass(*vstate, &instruction)) return error;
-    // Group
+    if (auto error = GroupPass(*vstate, &instruction)) return error;
     // Device-Side Enqueue
     // Pipe
     if (auto error = NonUniformPass(*vstate, &instruction)) return error;

+ 3 - 0
3rdparty/spirv-tools/source/val/validate.h

@@ -180,6 +180,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst);
 /// Validates correctness of barrier instructions.
 spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst);
 
+/// Validates correctness of Group (Kernel) instructions.
+spv_result_t GroupPass(ValidationState_t& _, const Instruction* inst);
+
 /// Validates correctness of literal numbers.
 spv_result_t LiteralsPass(ValidationState_t& _, const Instruction* inst);
 

+ 86 - 7
3rdparty/spirv-tools/source/val/validate_annotation.cpp

@@ -37,6 +37,8 @@ bool DecorationTakesIdParameters(spv::Decoration type) {
     case spv::Decoration::PayloadNodeArraySizeAMDX:
     case spv::Decoration::PayloadNodeNameAMDX:
     case spv::Decoration::PayloadNodeBaseIndexAMDX:
+    case spv::Decoration::ArrayStrideIdEXT:
+    case spv::Decoration::OffsetIdEXT:
       return true;
     default:
       break;
@@ -65,6 +67,7 @@ bool IsNotMemberDecoration(spv::Decoration dec) {
     case spv::Decoration::Block:
     case spv::Decoration::BufferBlock:
     case spv::Decoration::ArrayStride:
+    case spv::Decoration::ArrayStrideIdEXT:
     case spv::Decoration::GLSLShared:
     case spv::Decoration::GLSLPacked:
     case spv::Decoration::CPacked:
@@ -174,7 +177,8 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
       if (target->opcode() != spv::Op::OpVariable &&
           target->opcode() != spv::Op::OpUntypedVariableKHR &&
           target->opcode() != spv::Op::OpFunctionParameter &&
-          target->opcode() != spv::Op::OpRawAccessChainNV) {
+          target->opcode() != spv::Op::OpRawAccessChainNV &&
+          target->opcode() != spv::Op::OpBufferPointerEXT) {
         return fail(0) << "must be a memory object declaration";
       }
       if (!_.IsPointerType(target->type_id())) {
@@ -349,6 +353,33 @@ spv_result_t ValidateDecorateId(ValidationState_t& _, const Instruction* inst) {
               "OpDecorateId";
   }
 
+  if (decoration == spv::Decoration::ArrayStrideIdEXT) {
+    if (target->opcode() != spv::Op::OpTypeArray &&
+        target->opcode() != spv::Op::OpTypeRuntimeArray) {
+      // ArrayStrideIdEXT is suppose to identical to ArrayStride, which would
+      // allow it to be a OpTypePointer/OpTypeUntypedPointerKHR
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "ArrayStrideIdEXT decoration must only be applied to array "
+                "types.";
+    } else {
+      const uint32_t operand_id = inst->GetOperandAs<uint32_t>(2);
+      if (!_.IsIntScalarType(_.GetTypeId(operand_id), 32)) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "ArrayStrideIdEXT extra operand must be a 32-bit int "
+                  "scalar type.";
+      }
+
+      // Strip array and should be the descriptor type
+      const uint32_t element_type =
+          _.FindDef(target_id)->GetOperandAs<uint32_t>(1);
+      if (!_.IsDescriptorType(element_type)) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "ArrayStrideIdEXT decoration must only be applied to"
+               << " array type containing a Descriptor type.";
+      }
+    }
+  }
+
   for (uint32_t i = 2; i < inst->operands().size(); ++i) {
     const auto param_id = inst->GetOperandAs<uint32_t>(i);
     const auto param = _.FindDef(param_id);
@@ -375,24 +406,70 @@ spv_result_t ValidateMemberDecorate(ValidationState_t& _,
                                     const Instruction* inst) {
   const auto struct_type_id = inst->GetOperandAs<uint32_t>(0);
   const auto struct_type = _.FindDef(struct_type_id);
+  const bool is_mem_dec_id_inst =
+      (inst->opcode() == spv::Op::OpMemberDecorateIdEXT);
   if (!struct_type || spv::Op::OpTypeStruct != struct_type->opcode()) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpMemberDecorate Structure type <id> "
-           << _.getIdName(struct_type_id) << " is not a struct type.";
+           << (is_mem_dec_id_inst ? "OpMemberDecorateIdEXT"
+                                  : "OpMemberDecorate")
+           << " Structure type <id> " << _.getIdName(struct_type_id)
+           << " is not a struct type.";
   }
   const auto member = inst->GetOperandAs<uint32_t>(1);
   const auto member_count =
       static_cast<uint32_t>(struct_type->words().size() - 2);
   if (member_count <= member) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "Index " << member
-           << " provided in OpMemberDecorate for struct <id> "
-           << _.getIdName(struct_type_id)
+           << "Index " << member << " provided in "
+           << (is_mem_dec_id_inst ? "OpMemberDecorateIdEXT"
+                                  : "OpMemberDecorate")
+           << " for struct <id> " << _.getIdName(struct_type_id)
            << " is out of bounds. The structure has " << member_count
            << " members. Largest valid index is " << member_count - 1 << ".";
   }
 
   const auto decoration = inst->GetOperandAs<spv::Decoration>(2);
+  if (is_mem_dec_id_inst) {
+    if (decoration != spv::Decoration::OffsetIdEXT) {
+      if (decoration == spv::Decoration::ArrayStrideIdEXT) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "ArrayStrideIdEXT could only be directly applied"
+               << " to array type using OpDecorateId.";
+      } else {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "Decoration operand could only be OffsetIdEXT.";
+      }
+    }
+
+    const auto is_descriptor_type = [&_](const Instruction* type_inst) {
+      return _.IsDescriptorType(type_inst->opcode());
+    };
+
+    // recursively scans the struct to find if anything has a descriptor type,
+    // must be at least 1
+    if (decoration == spv::Decoration::OffsetIdEXT) {
+      const uint32_t operand_id = inst->GetOperandAs<uint32_t>(3);
+      if (!_.IsIntScalarType(_.GetTypeId(operand_id), 32)) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "OffsetIdEXT extra operand must be a 32-bit int scalar type.";
+      }
+      if (!_.ContainsType(struct_type_id, is_descriptor_type, true)) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "OffsetIdEXT decoration in MemberDecorateIdEXT must only be "
+                  "applied to members of structs where the struct contains "
+                  "descriptor types.";
+      }
+    }
+
+    for (uint32_t elem_idx = 3; elem_idx < inst->operands().size();
+         elem_idx++) {
+      if (_.FindDef(inst->GetOperandAs<uint32_t>(elem_idx)) > struct_type) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "All <id> Extra Operands must appear before Structure Type.";
+      }
+    }
+  }
+
   if (IsNotMemberDecoration(decoration)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << _.SpvDecorationString(decoration)
@@ -497,7 +574,8 @@ spv_result_t RegisterDecorations(ValidationState_t& _,
       _.RegisterDecorationForId(target_id, Decoration(dec_type, dec_params));
       break;
     }
-    case spv::Op::OpMemberDecorate: {
+    case spv::Op::OpMemberDecorate:
+    case spv::Op::OpMemberDecorateIdEXT: {
       const uint32_t struct_id = inst->word(1);
       const uint32_t index = inst->word(2);
       const spv::Decoration dec_type =
@@ -568,6 +646,7 @@ spv_result_t AnnotationPass(ValidationState_t& _, const Instruction* inst) {
     // TODO(dneto): spv::Op::OpDecorateStringGOOGLE
     // See https://github.com/KhronosGroup/SPIRV-Tools/issues/2253
     case spv::Op::OpMemberDecorate:
+    case spv::Op::OpMemberDecorateIdEXT:
       if (auto error = ValidateMemberDecorate(_, inst)) return error;
       break;
     case spv::Op::OpDecorationGroup:

+ 2 - 2
3rdparty/spirv-tools/source/val/validate_atomics.cpp

@@ -224,7 +224,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       // Can't use result_type because OpAtomicStore doesn't have a result
-      if (_.IsIntScalarType(data_type) && _.GetBitWidth(data_type) == 64 &&
+      if (_.IsIntScalarType(data_type, 64) &&
           !_.HasCapability(spv::Capability::Int64Atomics)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode)
@@ -357,7 +357,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
       // If result and pointer type are different, need to do special check here
       if (opcode == spv::Op::OpAtomicFlagTestAndSet ||
           opcode == spv::Op::OpAtomicFlagClear) {
-        if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
+        if (!_.IsIntScalarType(data_type, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << spvOpcodeString(opcode)
                  << ": expected Pointer to point to a value of 32-bit integer "

+ 1 - 2
3rdparty/spirv-tools/source/val/validate_barriers.cpp

@@ -94,8 +94,7 @@ spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t subgroup_count_type = _.GetOperandTypeId(inst, 2);
-      if (!_.IsIntScalarType(subgroup_count_type) ||
-          _.GetBitWidth(subgroup_count_type) != 32) {
+      if (!_.IsIntScalarType(subgroup_count_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode)
                << ": expected Subgroup Count to be a 32-bit int";

+ 19 - 0
3rdparty/spirv-tools/source/val/validate_builtins.cpp

@@ -369,6 +369,9 @@ class BuiltInsValidator {
   spv_result_t ValidateShadingRateAtDefinition(const Decoration& decoration,
                                                const Instruction& inst);
 
+  spv_result_t ValidateDescriptorHeapAtDefinition(const Decoration& decoration,
+                                                  const Instruction& inst);
+
   spv_result_t ValidateRayTracingBuiltinsAtDefinition(
       const Decoration& decoration, const Instruction& inst);
 
@@ -4507,6 +4510,18 @@ spv_result_t BuiltInsValidator::ValidateShadingRateAtReference(
   return SPV_SUCCESS;
 }
 
+spv_result_t BuiltInsValidator::ValidateDescriptorHeapAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (decoration.struct_member_index() != Decoration::kInvalidMember) {
+    return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+           << "BuiltIn "
+           << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                            (uint32_t)decoration.builtin())
+           << " cannot be used as a member decoration ";
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t BuiltInsValidator::ValidateRayTracingBuiltinsAtDefinition(
     const Decoration& decoration, const Instruction& inst) {
   if (spvIsVulkanEnv(_.context()->target_env)) {
@@ -5020,6 +5035,10 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinitionVulkan(
     case spv::BuiltIn::ShadingRateKHR: {
       return ValidateShadingRateAtDefinition(decoration, inst);
     }
+    case spv::BuiltIn::SamplerHeapEXT:
+    case spv::BuiltIn::ResourceHeapEXT: {
+      return ValidateDescriptorHeapAtDefinition(decoration, inst);
+    }
     default:
       // No validation rules (for the moment).
       break;

+ 1 - 0
3rdparty/spirv-tools/source/val/validate_capability.cpp

@@ -147,6 +147,7 @@ bool IsSupportOptionalVulkan_1_0(uint32_t capability) {
     case spv::Capability::Int8:
     case spv::Capability::BFloat16TypeKHR:
     case spv::Capability::Float8EXT:
+    case spv::Capability::PushConstantBanksNV:
       return true;
     default:
       break;

+ 32 - 0
3rdparty/spirv-tools/source/val/validate_cfg.cpp

@@ -352,6 +352,34 @@ spv_result_t ValidateLoopMerge(ValidationState_t& _, const Instruction* inst) {
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateLifetime(ValidationState_t& _, const Instruction* inst) {
+  const uint32_t pointer_id = _.GetOperandTypeId(inst, 0);
+  const Instruction* pointer_inst = _.FindDef(pointer_id);
+  if (pointer_inst->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Op" << spvOpcodeString(inst->opcode())
+           << " pointer operand type must be a OpTypePointer.";
+  } else if (pointer_inst->GetOperandAs<spv::StorageClass>(1) !=
+             spv::StorageClass::Function) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Op" << spvOpcodeString(inst->opcode())
+           << " pointer operand must be in the Function storage class.";
+  }
+
+  const uint32_t size = inst->GetOperandAs<uint32_t>(1);
+  if (size != 0) {
+    if (!_.HasCapability(spv::Capability::Addresses)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Op" << spvOpcodeString(inst->opcode())
+             << " size is non-zero, but the Addresses Capability is not "
+                "declared.";
+    }
+    // TODO - "Size must be 0 if Pointer is a pointer to a non-void type"
+  }
+
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 void printDominatorList(const BasicBlock& b) {
@@ -1268,6 +1296,10 @@ spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpLoopMerge:
       if (auto error = ValidateLoopMerge(_, inst)) return error;
       break;
+    case spv::Op::OpLifetimeStart:
+    case spv::Op::OpLifetimeStop:
+      if (auto error = ValidateLifetime(_, inst)) return error;
+      break;
     default:
       break;
   }

+ 15 - 14
3rdparty/spirv-tools/source/val/validate_composites.cpp

@@ -127,7 +127,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
         *member_type = type_inst->word(component_index + 2);
         break;
       }
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
       case spv::Op::OpTypeCooperativeMatrixKHR:
       case spv::Op::OpTypeCooperativeMatrixNV: {
         *member_type = type_inst->word(2);
@@ -155,7 +155,7 @@ spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _,
   const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
   const spv::Op vector_opcode = _.GetIdOpcode(vector_type);
   if (vector_opcode != spv::Op::OpTypeVector &&
-      vector_opcode != spv::Op::OpTypeCooperativeVectorNV) {
+      vector_opcode != spv::Op::OpTypeVectorIdEXT) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Vector type to be OpTypeVector";
   }
@@ -184,7 +184,7 @@ spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _,
   const uint32_t result_type = inst->type_id();
   const spv::Op result_opcode = _.GetIdOpcode(result_type);
   if (result_opcode != spv::Op::OpTypeVector &&
-      result_opcode != spv::Op::OpTypeCooperativeVectorNV) {
+      result_opcode != spv::Op::OpTypeVectorIdEXT) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Result Type to be OpTypeVector";
   }
@@ -223,7 +223,7 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
   const spv::Op result_opcode = _.GetIdOpcode(result_type);
   switch (result_opcode) {
     case spv::Op::OpTypeVector:
-    case spv::Op::OpTypeCooperativeVectorNV: {
+    case spv::Op::OpTypeVectorIdEXT: {
       uint32_t num_result_components = _.GetDimension(result_type);
       const uint32_t result_component_type = _.GetComponentType(result_type);
       uint32_t given_component_count = 0;
@@ -231,7 +231,8 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
       bool comp_is_int32 = true, comp_is_const_int32 = true;
 
       if (result_opcode == spv::Op::OpTypeVector) {
-        if (num_operands <= 3) {
+        if (num_operands <= 3 &&
+            !_.HasCapability(spv::Capability::LongVectorEXT)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected number of constituents to be at least 2";
         }
@@ -529,18 +530,18 @@ spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
 spv_result_t ValidateVectorShuffle(ValidationState_t& _,
                                    const Instruction* inst) {
   auto resultType = _.FindDef(inst->type_id());
-  if (!resultType || resultType->opcode() != spv::Op::OpTypeVector) {
+  if (!_.IsVectorType(resultType->id())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "The Result Type of OpVectorShuffle must be"
-           << " OpTypeVector. Found Op" << spvOpcodeString(resultType->opcode())
-           << ".";
+           << " a vector type. Found Op"
+           << spvOpcodeString(resultType->opcode()) << ".";
   }
 
   // The number of components in Result Type must be the same as the number of
   // Component operands.
   auto componentCount = inst->operands().size() - 4;
-  auto resultVectorDimension = resultType->GetOperandAs<uint32_t>(2);
-  if (componentCount != resultVectorDimension) {
+  auto resultVectorDimension = _.GetDimension(resultType->id());
+  if (resultVectorDimension > 0 && componentCount != resultVectorDimension) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpVectorShuffle component literals count does not match "
               "Result Type <id> "
@@ -553,13 +554,13 @@ spv_result_t ValidateVectorShuffle(ValidationState_t& _,
   auto vector1Type = _.FindDef(vector1Object->type_id());
   auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
   auto vector2Type = _.FindDef(vector2Object->type_id());
-  if (!vector1Type || vector1Type->opcode() != spv::Op::OpTypeVector) {
+  if (!_.IsVectorType(vector1Type->id())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "The type of Vector 1 must be OpTypeVector.";
+           << "The type of Vector 1 must be a vector type.";
   }
-  if (!vector2Type || vector2Type->opcode() != spv::Op::OpTypeVector) {
+  if (!_.IsVectorType(vector2Type->id())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "The type of Vector 2 must be OpTypeVector.";
+           << "The type of Vector 2 must be a vector type.";
   }
 
   auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);

+ 28 - 3
3rdparty/spirv-tools/source/val/validate_constants.cpp

@@ -54,11 +54,11 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
   const auto constituent_count = inst->words().size() - 3;
   switch (result_type->opcode()) {
     case spv::Op::OpTypeVector:
-    case spv::Op::OpTypeCooperativeVectorNV: {
+    case spv::Op::OpTypeVectorIdEXT: {
       uint32_t num_result_components = _.GetDimension(result_type->id());
       bool comp_is_int32 = true, comp_is_const_int32 = true;
 
-      if (result_type->opcode() == spv::Op::OpTypeCooperativeVectorNV) {
+      if (result_type->opcode() == spv::Op::OpTypeVectorIdEXT) {
         uint32_t comp_count_id = result_type->GetOperandAs<uint32_t>(2);
         std::tie(comp_is_int32, comp_is_const_int32, num_result_components) =
             _.EvalInt32IfConst(comp_count_id);
@@ -463,7 +463,7 @@ bool IsTypeNullable(const std::vector<uint32_t>& instruction,
     case spv::Op::OpTypeMatrix:
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
     case spv::Op::OpTypeVector: {
       auto base_type = _.FindDef(instruction[2]);
       return base_type && IsTypeNullable(base_type->words(), _);
@@ -505,6 +505,28 @@ spv_result_t ValidateConstantNull(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateConstantSizeOfEXT(ValidationState_t& _,
+                                       const Instruction* inst) {
+  const Instruction* result_type = _.FindDef(inst->type_id());
+  const uint32_t bit_width = result_type->GetOperandAs<uint32_t>(1);
+  // VVL will validate the SPV_EXT_shader_64bit_indexing interaction
+  if (result_type->opcode() != spv::Op::OpTypeInt ||
+      (bit_width != 64 && bit_width != 32)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "For OpConstantSizeOfEXT instruction, its result type "
+           << "must be a 32-bit or 64-bit integer type scalar."
+           << " (OpCapability Int64 is required for 64-bit)";
+  }
+
+  const uint32_t type_operand = inst->GetOperandAs<uint32_t>(2);
+  if (!_.IsDescriptorType(type_operand)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "For OpConstantSizeOfEXT instruction, its Type operand <Id> "
+           << _.getIdName(type_operand) << " must be a Descriptor type.";
+  }
+  return SPV_SUCCESS;
+}
+
 // Validates that OpSpecConstant specializes to either int or float type.
 spv_result_t ValidateSpecConstant(ValidationState_t& _,
                                   const Instruction* inst) {
@@ -607,6 +629,9 @@ spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpSpecConstantOp:
       if (auto error = ValidateSpecConstantOp(_, inst)) return error;
       break;
+    case spv::Op::OpConstantSizeOfEXT:
+      if (auto error = ValidateConstantSizeOfEXT(_, inst)) return error;
+      break;
     default:
       break;
   }

+ 6 - 2
3rdparty/spirv-tools/source/val/validate_conversion.cpp

@@ -541,7 +541,10 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
                << "Expected input to be a pointer or int or float vector "
                << "or scalar: " << spvOpcodeString(opcode);
 
-      if (result_is_coopvec != input_is_coopvec)
+      // NV_cooperative_vector doesn't allow bitcasting between vec<->coopvec,
+      // but long_vector does.
+      if (result_is_coopvec != input_is_coopvec &&
+          !_.HasCapability(spv::Capability::LongVectorEXT))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Cooperative vector can only be cast to another cooperative "
                << "vector: " << spvOpcodeString(opcode);
@@ -551,7 +554,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
                << "Cooperative matrix can only be cast to another cooperative "
                << "matrix: " << spvOpcodeString(opcode);
 
-      if (result_is_coopvec) {
+      if (result_is_coopvec && input_is_coopvec &&
+          !_.HasCapability(spv::Capability::LongVectorEXT)) {
         spv_result_t ret =
             _.CooperativeVectorDimensionsMatch(inst, result_type, input_type);
         if (ret != SPV_SUCCESS) return ret;

+ 131 - 17
3rdparty/spirv-tools/source/val/validate_decorations.cpp

@@ -122,7 +122,8 @@ bool isMissingOffsetInStruct(uint32_t struct_id, ValidationState_t& vstate) {
     hasOffset.resize(struct_members.size(), false);
 
     for (auto& decoration : vstate.id_decorations(struct_id)) {
-      if (spv::Decoration::Offset == decoration.dec_type() &&
+      if ((spv::Decoration::Offset == decoration.dec_type() ||
+           spv::Decoration::OffsetIdEXT == decoration.dec_type()) &&
           Decoration::kInvalidMember != decoration.struct_member_index()) {
         // Offset 0xffffffff is not valid so ignore it for simplicity's sake.
         if (decoration.params()[0] == 0xffffffff) return true;
@@ -170,6 +171,10 @@ uint32_t getBaseAlignment(uint32_t member_id, bool roundUp,
     case spv::Op::OpTypeImage:
       if (vstate.HasCapability(spv::Capability::BindlessTextureNV))
         return vstate.samplerimage_variable_address_mode() / 8;
+      // SPV_EXT_descriptor_heap provides a way to access opaque images, we
+      // assume alignment is validated at runtime as it is determined by the
+      // client API
+      if (vstate.HasCapability(spv::Capability::DescriptorHeapEXT)) return 1;
       assert(0);
       return 0;
     case spv::Op::OpTypeInt:
@@ -182,7 +187,19 @@ uint32_t getBaseAlignment(uint32_t member_id, bool roundUp,
       const auto componentAlignment = getBaseAlignment(
           componentId, roundUp, inherited, constraints, vstate);
       baseAlignment =
-          componentAlignment * (numComponents == 3 ? 4 : numComponents);
+          componentAlignment *
+          ((numComponents == 3 || numComponents > 4) ? 4 : numComponents);
+      break;
+    }
+    case spv::Op::OpTypeVectorIdEXT: {
+      const auto componentId = words[2];
+      const auto numComponents = vstate.GetDimension(inst->id());
+      assert(numComponents != 0);
+      const auto componentAlignment = getBaseAlignment(
+          componentId, roundUp, inherited, constraints, vstate);
+      baseAlignment =
+          componentAlignment *
+          ((numComponents == 3 || numComponents > 4) ? 4 : numComponents);
       break;
     }
     case spv::Op::OpTypeMatrix: {
@@ -245,12 +262,17 @@ uint32_t getScalarAlignment(uint32_t type_id, ValidationState_t& vstate) {
     case spv::Op::OpTypeImage:
       if (vstate.HasCapability(spv::Capability::BindlessTextureNV))
         return vstate.samplerimage_variable_address_mode() / 8;
+      // SPV_EXT_descriptor_heap provides a way to access opaque images, we
+      // assume alignment is validated at runtime as it is determined by the
+      // client API
+      if (vstate.HasCapability(spv::Capability::DescriptorHeapEXT)) return 1;
       assert(0);
       return 0;
     case spv::Op::OpTypeInt:
     case spv::Op::OpTypeFloat:
       return words[2] / 8;
     case spv::Op::OpTypeVector:
+    case spv::Op::OpTypeVectorIdEXT:
     case spv::Op::OpTypeMatrix:
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeRuntimeArray: {
@@ -293,6 +315,10 @@ uint32_t getSize(uint32_t member_id, const LayoutConstraints& inherited,
     case spv::Op::OpTypeImage:
       if (vstate.HasCapability(spv::Capability::BindlessTextureNV))
         return vstate.samplerimage_variable_address_mode() / 8;
+      // SPV_EXT_descriptor_heap provides a way to access opaque images, we
+      // assume alignment is validated at runtime as it is determined by the
+      // client API
+      if (vstate.HasCapability(spv::Capability::DescriptorHeapEXT)) return 1;
       assert(0);
       return 0;
     case spv::Op::OpTypeInt:
@@ -306,6 +332,15 @@ uint32_t getSize(uint32_t member_id, const LayoutConstraints& inherited,
       const auto size = componentSize * numComponents;
       return size;
     }
+    case spv::Op::OpTypeVectorIdEXT: {
+      const auto componentId = words[2];
+      const auto numComponents = vstate.GetDimension(inst->id());
+      assert(numComponents != 0);
+      const auto componentSize =
+          getSize(componentId, inherited, constraints, vstate);
+      const auto size = componentSize * numComponents;
+      return size;
+    }
     case spv::Op::OpTypeArray: {
       const auto sizeInst = vstate.FindDef(words[3]);
       if (spvOpcodeIsSpecConstant(sizeInst->opcode())) return 0;
@@ -544,7 +579,8 @@ spv_result_t checkLayout(uint32_t struct_id, spv::StorageClass storage_class,
     }
 
     if (!scalar_block_layout && relaxed_block_layout &&
-        opcode == spv::Op::OpTypeVector) {
+        (opcode == spv::Op::OpTypeVector ||
+         opcode == spv::Op::OpTypeVectorIdEXT)) {
       // In relaxed block layout, the vector offset must be aligned to the
       // vector's scalar element type.
       const auto componentId = inst->words()[2];
@@ -568,7 +604,8 @@ spv_result_t checkLayout(uint32_t struct_id, spv::StorageClass storage_class,
                              << nextValidOffset - 1 << extra();
     if (!scalar_block_layout && relaxed_block_layout) {
       // Check improper straddle of vectors.
-      if (spv::Op::OpTypeVector == opcode &&
+      if ((spv::Op::OpTypeVector == opcode ||
+           spv::Op::OpTypeVectorIdEXT == opcode) &&
           hasImproperStraddle(id, offset, constraint, constraints, vstate))
         return fail(memberIdx)
                << "is an improperly straddling vector at offset " << offset
@@ -866,6 +903,11 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
           }
         }
 
+        // Descriptor heap's base variables have no data type in declaration.
+        if (untyped_pointers && var_instr->words().size() < 5 &&
+            vstate.IsDescriptorHeapBaseVariable(var_instr))
+          continue;
+
         // It is guaranteed (by validator ID checks) that ptr_instr is
         // OpTypePointer. Word 3 of this instruction is the type being pointed
         // to. For untyped variables, the pointee type comes from the data type
@@ -998,8 +1040,7 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
               hasDecoration(var_instr->id(), spv::Decoration::Flat, vstate);
           if (has_frag && storage_class == spv::StorageClass::Input &&
               !has_flat &&
-              ((vstate.IsFloatScalarType(type_id) &&
-                vstate.GetBitWidth(type_id) == 64) ||
+              (vstate.IsFloatScalarType(type_id, 64) ||
                vstate.IsIntScalarOrVectorType(type_id))) {
             return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
                      << vstate.VkErrorID(4744)
@@ -1183,6 +1224,9 @@ spv_result_t CheckDecorationsOfVariables(ValidationState_t& vstate) {
       // storage classes are decorated with DescriptorSet and Binding
       // (VUID-06677).
       if (uniform_constant || storage_buffer || uniform) {
+        if (vstate.IsDescriptorHeapBaseVariable(&inst)) {
+          continue;
+        }
         // Skip validation if the variable is not used and we're looking
         // at a module coming from HLSL that has not been legalized yet.
         if (vstate.options()->before_hlsl_legalization &&
@@ -1238,7 +1282,8 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
       if (spvIsVulkanEnv(vstate.context()->target_env)) {
         // Vulkan: There must be no more than one PushConstant block per entry
         // point.
-        if (push_constant) {
+        if (push_constant &&
+            !(vstate.HasCapability(spv::Capability::PushConstantBanksNV))) {
           auto entry_points = vstate.EntryPointReferences(var_id);
           for (auto ep_id : entry_points) {
             const bool already_used = !uses_push_constant.insert(ep_id).second;
@@ -1279,6 +1324,24 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
           storageClass == spv::StorageClass::Workgroup &&
           vstate.HasCapability(
               spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
+
+      if (spvIsVulkanEnv(vstate.context()->target_env) &&
+          inst.opcode() == spv::Op::OpUntypedVariableKHR &&
+          storageClass != spv::StorageClass::UniformConstant &&
+          vstate.IsDescriptorHeapBaseVariable(&inst)) {
+        if (vstate.IsBuiltin(inst.id(), spv::BuiltIn::ResourceHeapEXT)) {
+          return vstate.diag(SPV_ERROR_INVALID_DATA, &inst)
+                 << vstate.VkErrorID(11241)
+                 << "The variable decorated with ResourceHeapEXT must be "
+                 << "declared using the UniformConstant storage class.";
+        }
+        if (vstate.IsBuiltin(inst.id(), spv::BuiltIn::SamplerHeapEXT)) {
+          return vstate.diag(SPV_ERROR_INVALID_DATA, &inst)
+                 << vstate.VkErrorID(11239)
+                 << "The variable decorated with SamplerHeapEXT must be "
+                 << "declared using the UniformConstant storage class.";
+        }
+      }
       if (uniform || push_constant || storage_buffer || phys_storage_buffer ||
           workgroup) {
         const auto ptrInst = vstate.FindDef(words[1]);
@@ -1376,12 +1439,14 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
               if (!checkForRequiredDecoration(
                       id,
                       [](spv::Decoration d) {
-                        return d == spv::Decoration::ArrayStride;
+                        return d == spv::Decoration::ArrayStride ||
+                               d == spv::Decoration::ArrayStrideIdEXT;
                       },
                       spv::Op::OpTypeArray, vstate)) {
                 return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
                        << "Structure id " << id << " decorated as " << deco_str
-                       << " must be explicitly laid out with ArrayStride "
+                       << " must be explicitly laid out with ArrayStride or "
+                          "ArrayStrideIdEXT "
                           "decorations.";
               }
 
@@ -1529,10 +1594,13 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
           bufferRules
               ? (sc == spv::StorageClass::Uniform ? "BufferBlock" : "Block")
               : "Block";
-      if (auto result =
-              checkLayout(data_type_id, sc, deco_str, !bufferRules,
-                          scalar_block_layout, 0, constraints, vstate)) {
-        return result;
+
+      if (!vstate.IsDescriptorHeapBaseVariable(&inst)) {
+        if (auto result =
+                checkLayout(data_type_id, sc, deco_str, !bufferRules,
+                            scalar_block_layout, 0, constraints, vstate)) {
+          return result;
+        }
       }
     }
   }
@@ -1567,7 +1635,8 @@ spv_result_t CheckDecorationsCompatibility(ValidationState_t& vstate) {
   // An Array of pairs where the decorations in the pair cannot both be applied
   // to the same member.
   static const spv::Decoration mutually_exclusive_per_member[][2] = {
-      {spv::Decoration::RowMajor, spv::Decoration::ColMajor}};
+      {spv::Decoration::RowMajor, spv::Decoration::ColMajor},
+      {spv::Decoration::Offset, spv::Decoration::OffsetIdEXT}};
   static const auto num_mutually_exclusive_per_mem_pairs =
       sizeof(mutually_exclusive_per_member) / (2 * sizeof(spv::Decoration));
 
@@ -1609,7 +1678,8 @@ spv_result_t CheckDecorationsCompatibility(ValidationState_t& vstate) {
                  << " is not allowed.";
         }
       }
-    } else if (spv::Op::OpMemberDecorate == inst.opcode()) {
+    } else if (spv::Op::OpMemberDecorate == inst.opcode() ||
+               spv::Op::OpMemberDecorateIdEXT == inst.opcode()) {
       const auto id = words[1];
       const auto member_id = words[2];
       const auto dec_type = static_cast<spv::Decoration>(words[3]);
@@ -1772,6 +1842,7 @@ spv_result_t CheckNonReadableWritableDecorations(ValidationState_t& vstate,
     const auto type_id = inst.type_id();
     if (opcode != spv::Op::OpVariable &&
         opcode != spv::Op::OpUntypedVariableKHR &&
+        opcode != spv::Op::OpBufferPointerEXT &&
         opcode != spv::Op::OpFunctionParameter &&
         opcode != spv::Op::OpRawAccessChainNV) {
       return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
@@ -1787,6 +1858,19 @@ spv_result_t CheckNonReadableWritableDecorations(ValidationState_t& vstate,
             : opcode == spv::Op::OpUntypedVariableKHR
                   ? inst.GetOperandAs<spv::StorageClass>(3)
                   : spv::StorageClass::Max;
+
+    if (opcode == spv::Op::OpBufferPointerEXT) {
+      auto result_type = vstate.FindDef(inst.type_id());
+      auto sc = result_type->GetOperandAs<spv::StorageClass>(1);
+      if (sc == spv::StorageClass::Uniform && is_non_writable) {
+        return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+               << "Target of NonWritable decoration is invalid: "
+               << "cannot be used to OpBufferPointerEXT "
+               << "with Uniform storage class";
+      }
+      return SPV_SUCCESS;
+    }
+
     if ((var_storage_class == spv::StorageClass::Function ||
          var_storage_class == spv::StorageClass::Private) &&
         vstate.features().nonwritable_var_in_function_or_private &&
@@ -2181,7 +2265,8 @@ spv::Decoration UsesExplicitLayout(
     const auto iter = id_decs.find(type_id);
     if (iter != id_decs.end()) {
       bool allowLayoutDecorations = false;
-      if (type_inst->opcode() == spv::Op::OpTypePointer) {
+      if (type_inst->opcode() == spv::Op::OpTypePointer ||
+          type_inst->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
         const auto sc = type_inst->GetOperandAs<spv::StorageClass>(1);
         allowLayoutDecorations = AllowsLayout(vstate, sc);
       }
@@ -2245,6 +2330,7 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
     spv::StorageClass sc = spv::StorageClass::Max;
     spv::Decoration layout_dec = spv::Decoration::Max;
     uint32_t fail_id = 0;
+    uint32_t base_id = 0;
     // Variables are the main place to check for improper decorations, but some
     // untyped pointer instructions must also be checked since those types may
     // never be instantiated by a variable. Unlike verifying a valid layout,
@@ -2255,6 +2341,7 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
       case spv::Op::OpUntypedVariableKHR: {
         sc = inst.GetOperandAs<spv::StorageClass>(2);
         auto check_id = type_id;
+        base_id = inst.id();
         if (inst.opcode() == spv::Op::OpUntypedVariableKHR) {
           if (inst.operands().size() > 3) {
             check_id = inst.GetOperandAs<uint32_t>(3);
@@ -2275,6 +2362,7 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
         // Check both the base type and return type. The return type may have an
         // invalid array stride.
         sc = type_inst->GetOperandAs<spv::StorageClass>(1);
+        base_id = vstate.FindDef(inst.GetOperandAs<uint32_t>(3))->id();
         if (!AllowsLayout(vstate, sc)) {
           const auto base_type_id = inst.GetOperandAs<uint32_t>(2);
           layout_dec = UsesExplicitLayout(vstate, base_type_id, cache);
@@ -2295,6 +2383,7 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
             vstate.FindDef(inst.GetOperandAs<uint32_t>(3))->type_id();
         const auto ptr_ty = vstate.FindDef(ptr_ty_id);
         sc = ptr_ty->GetOperandAs<spv::StorageClass>(1);
+        base_id = vstate.FindDef(inst.GetOperandAs<uint32_t>(3))->id();
         if (!AllowsLayout(vstate, sc)) {
           const auto base_type_id = inst.GetOperandAs<uint32_t>(2);
           layout_dec = UsesExplicitLayout(vstate, base_type_id, cache);
@@ -2307,6 +2396,7 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
       case spv::Op::OpLoad: {
         const auto ptr_id = inst.GetOperandAs<uint32_t>(2);
         const auto ptr_type = vstate.FindDef(vstate.FindDef(ptr_id)->type_id());
+        base_id = ptr_id;
         if (ptr_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
           // For untyped pointers check the return type for an invalid layout.
           sc = ptr_type->GetOperandAs<spv::StorageClass>(1);
@@ -2322,6 +2412,7 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
       case spv::Op::OpStore: {
         const auto ptr_id = inst.GetOperandAs<uint32_t>(1);
         const auto ptr_type = vstate.FindDef(vstate.FindDef(ptr_id)->type_id());
+        base_id = inst.GetOperandAs<uint32_t>(0);
         if (ptr_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
           // For untyped pointers, check the type of the data operand for an
           // invalid layout.
@@ -2336,10 +2427,33 @@ spv_result_t CheckInvalidVulkanExplicitLayout(ValidationState_t& vstate) {
         }
         break;
       }
+      case spv::Op::OpBufferPointerEXT: {
+        const auto ptr_id = inst.GetOperandAs<uint32_t>(1);
+        const auto ptr_type = vstate.FindDef(vstate.FindDef(ptr_id)->type_id());
+        // Check the type of the data operand for an invalid layout.
+        sc = ptr_type->GetOperandAs<spv::StorageClass>(1);
+        if (!AllowsLayout(vstate, sc) &&
+            UsesExplicitLayout(vstate, type_id, cache) !=
+                spv::Decoration::Max) {
+          return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+                 << vstate.VkErrorID(11346)
+                 << "The result type operand of OpBufferPointerEXT must have "
+                 << "a Type operand that is explicitly laid out : "
+                 << vstate.getIdName(type_id);
+        } else if (sc != spv::StorageClass::StorageBuffer &&
+                   sc != spv::StorageClass::Uniform) {
+          return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+                 << "OpBufferPointerEXT's Result Type must be a pointer type "
+                 << "with a Storage Class of Uniform or StorageBuffer.";
+        }
+        break;
+      }
       default:
         break;
     }
-    if (fail_id != 0) {
+
+    if (fail_id != 0 &&
+        !vstate.IsDescriptorHeapBaseVariable(vstate.FindDef(base_id))) {
       return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
              << vstate.VkErrorID(10684)
              << "Invalid explicit layout decorations on type for operand "

+ 17 - 28
3rdparty/spirv-tools/source/val/validate_extensions.cpp

@@ -1512,7 +1512,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
 
     case GLSLstd450PackSnorm4x8:
     case GLSLstd450PackUnorm4x8: {
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+      if (!_.IsIntScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected Result Type to be 32-bit int scalar type";
@@ -1531,7 +1531,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
     case GLSLstd450PackSnorm2x16:
     case GLSLstd450PackUnorm2x16:
     case GLSLstd450PackHalf2x16: {
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+      if (!_.IsIntScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected Result Type to be 32-bit int scalar type";
@@ -1548,8 +1548,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
     }
 
     case GLSLstd450PackDouble2x32: {
-      if (!_.IsFloatScalarType(result_type) ||
-          _.GetBitWidth(result_type) != 64) {
+      if (!_.IsFloatScalarType(result_type, 64)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected Result Type to be 64-bit float scalar type";
@@ -1577,7 +1576,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
       }
 
       const uint32_t v_type = _.GetOperandTypeId(inst, 4);
-      if (!_.IsIntScalarType(v_type) || _.GetBitWidth(v_type) != 32) {
+      if (!_.IsIntScalarType(v_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand P to be a 32-bit int scalar";
@@ -1598,7 +1597,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
       }
 
       const uint32_t v_type = _.GetOperandTypeId(inst, 4);
-      if (!_.IsIntScalarType(v_type) || _.GetBitWidth(v_type) != 32) {
+      if (!_.IsIntScalarType(v_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand P to be a 32-bit int scalar";
@@ -1616,7 +1615,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
       }
 
       const uint32_t v_type = _.GetOperandTypeId(inst, 4);
-      if (!_.IsFloatScalarType(v_type) || _.GetBitWidth(v_type) != 64) {
+      if (!_.IsFloatScalarType(v_type, 64)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand V to be a 64-bit float scalar";
@@ -1802,8 +1801,7 @@ spv_result_t ValidateExtInstGlslStd450(ValidationState_t& _,
 
       if (ext_inst_key == GLSLstd450InterpolateAtSample) {
         const uint32_t sample_type = _.GetOperandTypeId(inst, 5);
-        if (!_.IsIntScalarType(sample_type) ||
-            _.GetBitWidth(sample_type) != 32) {
+        if (!_.IsIntScalarType(sample_type, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << GetExtInstName(_, inst) << ": "
                  << "expected Sample to be 32-bit integer";
@@ -2586,8 +2584,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                << " can only be used with physical addressing models";
       }
 
-      if (!_.IsIntScalarType(offset_type) ||
-          _.GetBitWidth(offset_type) != size_t_bit_width) {
+      if (!_.IsIntScalarType(offset_type, size_t_bit_width)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand Offset to be of type size_t ("
@@ -2662,8 +2659,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                << " can only be used with physical addressing models";
       }
 
-      if (!_.IsIntScalarType(offset_type) ||
-          _.GetBitWidth(offset_type) != size_t_bit_width) {
+      if (!_.IsIntScalarType(offset_type, size_t_bit_width)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand Offset to be of type size_t ("
@@ -2715,8 +2711,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                << " can only be used with physical addressing models";
       }
 
-      if (!_.IsIntScalarType(offset_type) ||
-          _.GetBitWidth(offset_type) != size_t_bit_width) {
+      if (!_.IsIntScalarType(offset_type, size_t_bit_width)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand Offset to be of type size_t ("
@@ -2743,8 +2738,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                   "Generic, CrossWorkgroup, Workgroup or Function";
       }
 
-      if ((!_.IsFloatScalarType(p_data_type) ||
-           _.GetBitWidth(p_data_type) != 16) &&
+      if ((!_.IsFloatScalarType(p_data_type, 16)) &&
           !_.ContainsUntypedPointer(p_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
@@ -2778,8 +2772,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                << " can only be used with physical addressing models";
       }
 
-      if (!_.IsIntScalarType(offset_type) ||
-          _.GetBitWidth(offset_type) != size_t_bit_width) {
+      if (!_.IsIntScalarType(offset_type, size_t_bit_width)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand Offset to be of type size_t ("
@@ -2806,8 +2799,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                   "Generic, CrossWorkgroup, Workgroup or Function";
       }
 
-      if ((!_.IsFloatScalarType(p_data_type) ||
-           _.GetBitWidth(p_data_type) != 16) &&
+      if ((!_.IsFloatScalarType(p_data_type, 16)) &&
           !_.ContainsUntypedPointer(p_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
@@ -2872,8 +2864,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                << " can only be used with physical addressing models";
       }
 
-      if (!_.IsIntScalarType(offset_type) ||
-          _.GetBitWidth(offset_type) != size_t_bit_width) {
+      if (!_.IsIntScalarType(offset_type, size_t_bit_width)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected operand Offset to be of type size_t ("
@@ -2899,8 +2890,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
                   "CrossWorkgroup, Workgroup or Function";
       }
 
-      if ((!_.IsFloatScalarType(p_data_type) ||
-           _.GetBitWidth(p_data_type) != 16) &&
+      if ((!_.IsFloatScalarType(p_data_type, 16)) &&
           !_.ContainsUntypedPointer(p_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
@@ -2990,7 +2980,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
     }
 
     case OpenCLLIB::Printf: {
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+      if (!_.IsIntScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "
                << "expected Result Type to be a 32-bit int type";
@@ -3038,8 +3028,7 @@ spv_result_t ValidateExtInstOpenClStd(ValidationState_t& _,
       if (_.IsIntArrayType(format_data_type))
         format_data_type = _.GetComponentType(format_data_type);
 
-      if ((!_.IsIntScalarType(format_data_type) ||
-           _.GetBitWidth(format_data_type) != 8) &&
+      if (!_.IsIntScalarType(format_data_type, 8) &&
           !_.ContainsUntypedPointer(format_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << GetExtInstName(_, inst) << ": "

+ 2 - 2
3rdparty/spirv-tools/source/val/validate_function.cpp

@@ -355,14 +355,14 @@ spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
   const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
   const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
   const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
-  if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) {
+  if (!_.IsIntScalarType(param0_id, 32)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpCooperativeMatrixPerElementOpNV function type first parameter "
               "type <id> "
            << _.getIdName(param0_id) << " must be a 32-bit integer.";
   }
 
-  if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) {
+  if (!_.IsIntScalarType(param1_id, 32)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpCooperativeMatrixPerElementOpNV function type second "
               "parameter type <id> "

+ 229 - 0
3rdparty/spirv-tools/source/val/validate_group.cpp

@@ -0,0 +1,229 @@
+// Copyright (c) 2026 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdint>
+
+#include "source/val/instruction.h"
+#include "source/val/validate.h"
+#include "source/val/validate_scopes.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+spv_result_t ValidateGroupAnyAll(ValidationState_t& _,
+                                 const Instruction* inst) {
+  if (!_.IsBoolScalarType(inst->type_id())) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Result must be a boolean scalar type";
+  }
+
+  if (!_.IsBoolScalarType(_.GetOperandTypeId(inst, 3))) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Predicate must be a boolean scalar type";
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGroupBroadcast(ValidationState_t& _,
+                                    const Instruction* inst) {
+  const uint32_t type_id = inst->type_id();
+  if (!_.IsFloatScalarOrVectorType(type_id) &&
+      !_.IsIntScalarOrVectorType(type_id) &&
+      !_.IsBoolScalarOrVectorType(type_id)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Result must be a scalar or vector of integer, floating-point, "
+              "or boolean type";
+  }
+
+  const uint32_t value_type_id = _.GetOperandTypeId(inst, 3);
+  if (value_type_id != type_id) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "The type of Value must match the Result type";
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGroupFloat(ValidationState_t& _, const Instruction* inst) {
+  const uint32_t type_id = inst->type_id();
+  if (!_.IsFloatScalarOrVectorType(type_id)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Result must be a scalar or vector of float type";
+  }
+
+  const uint32_t x_type_id = _.GetOperandTypeId(inst, 4);
+  if (x_type_id != type_id) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "The type of X must match the Result type";
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGroupInt(ValidationState_t& _, const Instruction* inst) {
+  const uint32_t type_id = inst->type_id();
+  if (!_.IsIntScalarOrVectorType(type_id)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Result must be a scalar or vector of integer type";
+  }
+
+  const uint32_t x_type_id = _.GetOperandTypeId(inst, 4);
+  if (x_type_id != type_id) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "The type of X must match the Result type";
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGroupAsyncCopy(ValidationState_t& _,
+                                    const Instruction* inst) {
+  if (_.FindDef(inst->type_id())->opcode() != spv::Op::OpTypeEvent) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "The result type must be OpTypeEvent.";
+  }
+
+  const uint32_t destination = _.GetOperandTypeId(inst, 3);
+  const Instruction* destination_pointer = _.FindDef(destination);
+  if (destination_pointer->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Destination to be a pointer.";
+  }
+  const auto destination_sc =
+      destination_pointer->GetOperandAs<spv::StorageClass>(1);
+  if (destination_sc != spv::StorageClass::Workgroup &&
+      destination_sc != spv::StorageClass::CrossWorkgroup) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Destination to be a pointer with storage class "
+              "Workgroup or CrossWorkgroup.";
+  }
+  const uint32_t destination_type =
+      destination_pointer->GetOperandAs<uint32_t>(2);
+  if (!_.IsIntScalarOrVectorType(destination_type) &&
+      !_.IsFloatScalarOrVectorType(destination_type)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Destination to be a pointer to scalar or vector of "
+              "floating-point type or integer type.";
+  }
+
+  const uint32_t source = _.GetOperandTypeId(inst, 4);
+  const Instruction* source_pointer = _.FindDef(source);
+  const auto source_sc = source_pointer->GetOperandAs<spv::StorageClass>(1);
+  const uint32_t source_type = source_pointer->GetOperandAs<uint32_t>(2);
+  if (destination_type != source_type) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Destination and Source to be the same type.";
+  }
+
+  if (destination_sc == spv::StorageClass::Workgroup &&
+      source_sc != spv::StorageClass::CrossWorkgroup) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "If Destination storage class is Workgroup, then the Source "
+              "storage class must be CrossWorkgroup.";
+  } else if (destination_sc == spv::StorageClass::CrossWorkgroup &&
+             source_sc != spv::StorageClass::Workgroup) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "If Destination storage class is CrossWorkgroup, then the Source "
+              "storage class must be Workgroup.";
+  }
+
+  const bool is_physical_64 =
+      _.addressing_model() == spv::AddressingModel::Physical64;
+  const uint32_t bit_width = is_physical_64 ? 64 : 32;
+
+  const uint32_t num_elements_type =
+      _.GetTypeId(inst->GetOperandAs<uint32_t>(5));
+  if (!_.IsIntScalarType(num_elements_type, bit_width)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "NumElements must be a " << bit_width
+           << "-bit int scalar when Addressing Model is "
+           << (is_physical_64 ? "Physical64" : "Physical32");
+  }
+
+  const uint32_t stride_type = _.GetTypeId(inst->GetOperandAs<uint32_t>(6));
+  if (!_.IsIntScalarType(stride_type, bit_width)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Stride must be a " << bit_width
+           << "-bit int scalar when Addressing Model is "
+           << (is_physical_64 ? "Physical64" : "Physical32");
+  }
+
+  const uint32_t event = _.GetOperandTypeId(inst, 7);
+  const Instruction* event_type = _.FindDef(event);
+  if (event_type->opcode() != spv::Op::OpTypeEvent) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Event to be type OpTypeEvent.";
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGroupWaitEvents(ValidationState_t& _,
+                                     const Instruction* inst) {
+  const uint32_t num_events_id = _.GetOperandTypeId(inst, 1);
+  if (!_.IsIntScalarType(num_events_id, 32)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Num Events to be a 32-bit int scalar.";
+  }
+
+  const uint32_t events_id = _.GetOperandTypeId(inst, 2);
+  const Instruction* var_pointer = _.FindDef(events_id);
+  if (var_pointer->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Events List to be a pointer.";
+  }
+  const Instruction* event_list_type =
+      _.FindDef(var_pointer->GetOperandAs<uint32_t>(2));
+  if (event_list_type->opcode() != spv::Op::OpTypeEvent) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Events List to be a pointer to OpTypeEvent.";
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+
+spv_result_t GroupPass(ValidationState_t& _, const Instruction* inst) {
+  const spv::Op opcode = inst->opcode();
+
+  switch (opcode) {
+    case spv::Op::OpGroupAny:
+    case spv::Op::OpGroupAll:
+      return ValidateGroupAnyAll(_, inst);
+    case spv::Op::OpGroupBroadcast:
+      return ValidateGroupBroadcast(_, inst);
+    case spv::Op::OpGroupFAdd:
+    case spv::Op::OpGroupFMax:
+    case spv::Op::OpGroupFMin:
+      return ValidateGroupFloat(_, inst);
+    case spv::Op::OpGroupIAdd:
+    case spv::Op::OpGroupUMin:
+    case spv::Op::OpGroupSMin:
+    case spv::Op::OpGroupUMax:
+    case spv::Op::OpGroupSMax:
+      return ValidateGroupInt(_, inst);
+    case spv::Op::OpGroupAsyncCopy:
+      return ValidateGroupAsyncCopy(_, inst);
+    case spv::Op::OpGroupWaitEvents:
+      return ValidateGroupWaitEvents(_, inst);
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools

+ 3 - 0
3rdparty/spirv-tools/source/val/validate_id.cpp

@@ -123,6 +123,9 @@ bool InstructionCanHaveTypeOperand(const Instruction* inst) {
       spv::Op::OpUntypedArrayLengthKHR,
       spv::Op::OpFunction,
       spv::Op::OpAsmINTEL,
+      spv::Op::OpConstantSizeOfEXT,
+      spv::Op::OpBufferPointerEXT,
+      spv::Op::OpUntypedImageTexelPointerEXT,
   };
   const auto opcode = inst->opcode();
   bool type_instruction = spvOpcodeGeneratesType(opcode);

+ 220 - 211
3rdparty/spirv-tools/source/val/validate_image.cpp

@@ -237,6 +237,23 @@ uint32_t GetMinCoordSize(spv::Op opcode, const ImageTypeInfo& info) {
     return 3;
   }
 
+  if (opcode == spv::Op::OpImageQueryLod) {
+    return GetPlaneCoordSize(info);
+  }
+
+  if (opcode == spv::Op::OpImageTexelPointer) {
+    if (info.arrayed == 0) {
+      return GetPlaneCoordSize(info);
+    } else if (info.dim == spv::Dim::Dim1D) {
+      return 2;
+    } else if (info.dim == spv::Dim::Cube || info.dim == spv::Dim::Dim2D) {
+      return 3;
+    } else {
+      assert(false);
+      return 0;  // caught elsewhere
+    }
+  }
+
   return GetPlaneCoordSize(info) + info.arrayed + (IsProj(opcode) ? 1 : 0);
 }
 
@@ -314,10 +331,10 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
              << "Image Operand Bias can only be used with ImplicitLod opcodes";
     }
 
-    const uint32_t type_id = _.GetTypeId(inst->word(word_index++));
-    if (!_.IsFloatScalarType(type_id)) {
+    const uint32_t bias_type_id = _.GetTypeId(inst->word(word_index++));
+    if (!_.IsFloatScalarType(bias_type_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Image Operand Bias to be float scalar";
+             << "Expected Image Operand Bias to be a 32-bit float scalar";
     }
 
     if (info.dim != spv::Dim::Dim1D && info.dim != spv::Dim::Dim2D &&
@@ -327,7 +344,10 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
                 "or Cube";
     }
 
-    // Multisampled is already checked.
+    // - |Sample| operand is required to have MS != 0
+    // - |Sample| is only allowed with [Fetch, Write, or Read]
+    // - |Bias| can only be used with |ImplicitLod| opcodes
+    // Multisampled is already checked in all cases
   }
 
   if (mask & uint32_t(spv::ImageOperandsMask::Lod)) {
@@ -345,17 +365,19 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
                 "time";
     }
 
-    const uint32_t type_id = _.GetTypeId(inst->word(word_index++));
+    const uint32_t lod_type_id = _.GetTypeId(inst->word(word_index++));
     if (is_explicit_lod || is_valid_gather_lod_bias_amd) {
-      if (!_.IsFloatScalarType(type_id)) {
+      if (!_.IsFloatScalarType(lod_type_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Image Operand Lod to be float scalar when used "
+               << "Expected Image Operand Lod to be a 32-bit float scalar when "
+                  "used "
                << "with ExplicitLod";
       }
     } else {
-      if (!_.IsIntScalarType(type_id)) {
+      if (!_.IsIntScalarType(lod_type_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Image Operand Lod to be int scalar when used with "
+               << "Expected Image Operand Lod to be a 32-bit int scalar when "
+                  "used with "
                << "OpImageFetch";
       }
     }
@@ -367,7 +389,10 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
                 "or Cube";
     }
 
-    // Multisampled is already checked.
+    if (info.multisampled != 0) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Image Operand Lod requires 'MS' parameter to be 0";
+    }
   }
 
   if (mask & uint32_t(spv::ImageOperandsMask::Grad)) {
@@ -379,9 +404,12 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
     const uint32_t dx_type_id = _.GetTypeId(inst->word(word_index++));
     const uint32_t dy_type_id = _.GetTypeId(inst->word(word_index++));
     if (!_.IsFloatScalarOrVectorType(dx_type_id) ||
-        !_.IsFloatScalarOrVectorType(dy_type_id)) {
+        _.GetBitWidth(dx_type_id) != 32 ||
+        !_.IsFloatScalarOrVectorType(dy_type_id) ||
+        _.GetBitWidth(dy_type_id) != 32) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected both Image Operand Grad ids to be float scalars or "
+             << "Expected both Image Operand Grad ids to be 32-bit float "
+                "scalars or "
              << "vectors";
     }
 
@@ -400,7 +428,10 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
              << " components, but given " << dy_size;
     }
 
-    // Multisampled is already checked.
+    // - |Sample| operand is required to have MS != 0
+    // - |Sample| is only allowed with [Fetch, Write, or Read]
+    // - |Grad| can only be used with |ExplicitLod| opcodes
+    // Multisampled is already checked in all cases
   }
 
   if (mask & uint32_t(spv::ImageOperandsMask::ConstOffset)) {
@@ -410,21 +441,23 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
                 "'Dim'";
     }
 
-    const uint32_t id = inst->word(word_index++);
-    const uint32_t type_id = _.GetTypeId(id);
-    if (!_.IsIntScalarOrVectorType(type_id)) {
+    const uint32_t offset_id = inst->word(word_index++);
+    const uint32_t offset_type_id = _.GetTypeId(offset_id);
+    if (!_.IsIntScalarOrVectorType(offset_type_id) ||
+        _.GetBitWidth(offset_type_id) != 32) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Image Operand ConstOffset to be int scalar or "
+             << "Expected Image Operand ConstOffset to be a 32-bit int scalar "
+                "or "
              << "vector";
     }
 
-    if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) {
+    if (!spvOpcodeIsConstant(_.GetIdOpcode(offset_id))) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Image Operand ConstOffset to be a const object";
     }
 
     const uint32_t plane_size = GetPlaneCoordSize(info);
-    const uint32_t offset_size = _.GetDimension(type_id);
+    const uint32_t offset_size = _.GetDimension(offset_type_id);
     if (plane_size != offset_size) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Image Operand ConstOffset to have " << plane_size
@@ -438,16 +471,17 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
              << "Image Operand Offset cannot be used with Cube Image 'Dim'";
     }
 
-    const uint32_t id = inst->word(word_index++);
-    const uint32_t type_id = _.GetTypeId(id);
-    if (!_.IsIntScalarOrVectorType(type_id)) {
+    const uint32_t offset_id = inst->word(word_index++);
+    const uint32_t offset_type_id = _.GetTypeId(offset_id);
+    if (!_.IsIntScalarOrVectorType(offset_type_id) ||
+        _.GetBitWidth(offset_type_id) != 32) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Image Operand Offset to be int scalar or "
+             << "Expected Image Operand Offset to be a 32-bit int scalar or "
              << "vector";
     }
 
     const uint32_t plane_size = GetPlaneCoordSize(info);
-    const uint32_t offset_size = _.GetDimension(type_id);
+    const uint32_t offset_size = _.GetDimension(offset_type_id);
     if (plane_size != offset_size) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Image Operand Offset to have " << plane_size
@@ -487,9 +521,9 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
                 "'Dim'";
     }
 
-    const uint32_t id = inst->word(word_index++);
-    const uint32_t type_id = _.GetTypeId(id);
-    const Instruction* type_inst = _.FindDef(type_id);
+    const uint32_t offset_id = inst->word(word_index++);
+    const uint32_t offset_type_id = _.GetTypeId(offset_id);
+    const Instruction* type_inst = _.FindDef(offset_type_id);
     assert(type_inst);
 
     if (type_inst->opcode() != spv::Op::OpTypeArray) {
@@ -509,13 +543,14 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
 
     const uint32_t component_type = type_inst->word(2);
     if (!_.IsIntVectorType(component_type) ||
-        _.GetDimension(component_type) != 2) {
+        _.GetDimension(component_type) != 2 ||
+        _.GetBitWidth(component_type) != 32) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Image Operand ConstOffsets array components to be "
-                "int vectors of size 2";
+             << "Expected Image Operand ConstOffsets array components to be a "
+                "32-bit int vectors of size 2";
     }
 
-    if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) {
+    if (!spvOpcodeIsConstant(_.GetIdOpcode(offset_id))) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Image Operand ConstOffsets to be a const object";
     }
@@ -537,10 +572,10 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
              << "Image Operand Sample requires non-zero 'MS' parameter";
     }
 
-    const uint32_t type_id = _.GetTypeId(inst->word(word_index++));
-    if (!_.IsIntScalarType(type_id)) {
+    const uint32_t sample_type_id = _.GetTypeId(inst->word(word_index++));
+    if (!_.IsIntScalarType(sample_type_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Image Operand Sample to be int scalar";
+             << "Expected Image Operand Sample to be a 32-bit int scalar";
     }
   }
 
@@ -551,10 +586,10 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
              << "opcodes or together with Image Operand Grad";
     }
 
-    const uint32_t type_id = _.GetTypeId(inst->word(word_index++));
-    if (!_.IsFloatScalarType(type_id)) {
+    const uint32_t minlod_type_id = _.GetTypeId(inst->word(word_index++));
+    if (!_.IsFloatScalarType(minlod_type_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Image Operand MinLod to be float scalar";
+             << "Expected Image Operand MinLod to be a 32-bit float scalar";
     }
 
     if (info.dim != spv::Dim::Dim1D && info.dim != spv::Dim::Dim2D &&
@@ -782,8 +817,7 @@ spv_result_t ValidateTypeImage(ValidationState_t& _, const Instruction* inst) {
            << "Corrupt image type definition";
   }
 
-  if (_.IsIntScalarType(info.sampled_type) &&
-      (64 == _.GetBitWidth(info.sampled_type)) &&
+  if (_.IsIntScalarType(info.sampled_type, 64) &&
       !_.HasCapability(spv::Capability::Int64ImageEXT)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Capability Int64ImageEXT is required when using Sampled Type of "
@@ -792,12 +826,9 @@ spv_result_t ValidateTypeImage(ValidationState_t& _, const Instruction* inst) {
 
   const auto target_env = _.context()->target_env;
   if (spvIsVulkanEnv(target_env)) {
-    if ((!_.IsFloatScalarType(info.sampled_type) &&
-         !_.IsIntScalarType(info.sampled_type)) ||
-        ((32 != _.GetBitWidth(info.sampled_type)) &&
-         (64 != _.GetBitWidth(info.sampled_type))) ||
-        ((64 == _.GetBitWidth(info.sampled_type)) &&
-         _.IsFloatScalarType(info.sampled_type))) {
+    if (!_.IsFloatScalarType(info.sampled_type, 32) &&
+        !_.IsIntScalarType(info.sampled_type, 32) &&
+        !_.IsIntScalarType(info.sampled_type, 64)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << _.VkErrorID(4656)
              << "Expected Sampled Type to be a 32-bit int, 64-bit int or "
@@ -1007,6 +1038,82 @@ bool IsAllowedSampledImageOperand(spv::Op opcode, ValidationState_t& _) {
   }
 }
 
+spv_result_t ValidateImageCoordinate(ValidationState_t& _,
+                                     const Instruction* inst,
+                                     const ImageTypeInfo& info,
+                                     uint32_t word_index) {
+  const spv::Op opcode = inst->opcode();
+  const uint32_t coord_type = _.GetOperandTypeId(inst, word_index);
+
+  const bool float_only =
+      opcode == spv::Op::OpImageSampleImplicitLod ||
+      opcode == spv::Op::OpImageSampleDrefImplicitLod ||
+      opcode == spv::Op::OpImageSampleDrefExplicitLod ||
+      opcode == spv::Op::OpImageSampleProjImplicitLod ||
+      opcode == spv::Op::OpImageSampleProjExplicitLod ||
+      opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
+      opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
+      opcode == spv::Op::OpImageGather ||
+      opcode == spv::Op::OpImageDrefGather ||
+      opcode == spv::Op::OpImageQueryLod ||
+      opcode == spv::Op::OpImageSparseSampleImplicitLod ||
+      opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
+      opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
+      opcode == spv::Op::OpImageSparseGather ||
+      opcode == spv::Op::OpImageSparseDrefGather;
+
+  const bool int_only = opcode == spv::Op::OpImageFetch ||
+                        opcode == spv::Op::OpImageSparseFetch ||
+                        opcode == spv::Op::OpImageTexelPointer ||
+                        opcode == spv::Op::OpUntypedImageTexelPointerEXT;
+
+  const bool int_or_float = opcode == spv::Op::OpImageSampleExplicitLod ||
+                            opcode == spv::Op::OpImageSparseSampleExplicitLod ||
+                            opcode == spv::Op::OpImageRead ||
+                            opcode == spv::Op::OpImageWrite ||
+                            opcode == spv::Op::OpImageSparseRead;
+
+  assert(float_only || int_only || int_or_float);
+
+  if (float_only && !_.IsFloatScalarOrVectorType(coord_type)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Coordinate to be a 32-bit float scalar or vector";
+  } else if (int_only && !_.IsIntScalarOrVectorType(coord_type)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Coordinate to be a 32-bit integer scalar or vector";
+  } else if (int_or_float) {
+    if (!_.IsFloatScalarOrVectorType(coord_type) &&
+        !_.IsIntScalarOrVectorType(coord_type)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Expected Coordinate to be a 32-bit integer or float scalar or "
+                "vector";
+    }
+  }
+
+  // Needs to be after we validate the scalar/vector
+  if (_.GetBitWidth(coord_type) != 32) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Coordinate to be a 32-bit scalar or vector";
+  }
+
+  const uint32_t min_coord_size = GetMinCoordSize(opcode, info);
+  const uint32_t actual_coord_size = _.GetDimension(coord_type);
+
+  if (opcode == spv::Op::OpImageTexelPointer) {
+    if (min_coord_size != actual_coord_size) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Expected Coordinate to have " << min_coord_size
+             << " components, but given " << actual_coord_size;
+    }
+  } else if (min_coord_size > actual_coord_size) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Coordinate to have at least " << min_coord_size
+           << " components, but given only " << actual_coord_size;
+  }
+
+  return SPV_SUCCESS;
+}
+
 spv_result_t ValidateSampledImage(ValidationState_t& _,
                                   const Instruction* inst) {
   auto type_inst = _.FindDef(inst->type_id());
@@ -1135,6 +1242,8 @@ spv_result_t ValidateSampledImage(ValidationState_t& _,
 
 spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
                                        const Instruction* inst) {
+  bool isUntyped = (inst->opcode() == spv::Op::OpUntypedImageTexelPointerEXT);
+
   const auto result_type = _.FindDef(inst->type_id());
   if (result_type->opcode() != spv::Op::OpTypePointer &&
       result_type->opcode() != spv::Op::OpTypeUntypedPointerKHR) {
@@ -1165,16 +1274,23 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
     }
   }
 
-  const auto image_ptr = _.FindDef(_.GetOperandTypeId(inst, 2));
-  if (!image_ptr || image_ptr->opcode() != spv::Op::OpTypePointer) {
+  const auto image_ptr =
+      _.FindDef(_.GetOperandTypeId(inst, (isUntyped ? 3 : 2)));
+  if (!image_ptr ||
+      (isUntyped && image_ptr->opcode() != spv::Op::OpTypeUntypedPointerKHR) ||
+      (!isUntyped && image_ptr->opcode() != spv::Op::OpTypePointer)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Image to be OpTypePointer";
+           << "Expected Image to be "
+           << (isUntyped ? "OpTypeUntypedPointerKHR" : "OpTypePointer");
   }
 
-  const auto image_type = image_ptr->GetOperandAs<uint32_t>(2);
+  const auto image_type = isUntyped ? inst->GetOperandAs<uint32_t>(2)
+                                    : image_ptr->GetOperandAs<uint32_t>(2);
   if (_.GetIdOpcode(image_type) != spv::Op::OpTypeImage) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Image to be OpTypePointer with Type OpTypeImage";
+           << "Expected Image to be "
+           << (isUntyped ? "OpTypeUntypedPointerKHR" : "OpTypePointer ")
+           << "with Type OpTypeImage";
   }
 
   ImageTypeInfo info;
@@ -1199,49 +1315,23 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
 
   if (info.dim == spv::Dim::SubpassData) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Image Dim SubpassData cannot be used with OpImageTexelPointer";
+           << "Image Dim SubpassData cannot be used with "
+           << (isUntyped ? "OpUntypedImageTexelPointerEXT"
+                         : "OpImageTexelPointer");
   }
 
   if (info.dim == spv::Dim::TileImageDataEXT) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Image Dim TileImageDataEXT cannot be used with "
-              "OpImageTexelPointer";
+           << (isUntyped ? "OpUntypedImageTexelPointerEXT"
+                         : "OpImageTexelPointer");
   }
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if (!coord_type || !_.IsIntScalarOrVectorType(coord_type)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to be integer scalar or vector";
-  }
-
-  uint32_t expected_coord_size = 0;
-  if (info.arrayed == 0) {
-    expected_coord_size = GetPlaneCoordSize(info);
-  } else if (info.arrayed == 1) {
-    switch (info.dim) {
-      case spv::Dim::Dim1D:
-        expected_coord_size = 2;
-        break;
-      case spv::Dim::Cube:
-      case spv::Dim::Dim2D:
-        expected_coord_size = 3;
-        break;
-      default:
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Image 'Dim' must be one of 1D, 2D, or Cube when "
-                  "Arrayed is 1";
-        break;
-    }
-  }
-
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (expected_coord_size != actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have " << expected_coord_size
-           << " components, but given " << actual_coord_size;
-  }
+  if (spv_result_t result = ValidateImageCoordinate(
+          _, inst, info, /* word_index = */ (isUntyped ? 4 : 3)))
+    return result;
 
-  const uint32_t sample_type = _.GetOperandTypeId(inst, 4);
+  const uint32_t sample_type = _.GetOperandTypeId(inst, (isUntyped ? 5 : 4));
   if (!sample_type || !_.IsIntScalarType(sample_type)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Sample to be integer scalar";
@@ -1249,7 +1339,8 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
 
   if (info.multisampled == 0) {
     uint64_t ms = 0;
-    if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
+    if (!_.EvalConstantValUint64(
+            inst->GetOperandAs<uint32_t>(isUntyped ? 5 : 4), &ms) ||
         ms != 0) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Sample for Image with MS 0 to be a valid <id> for "
@@ -1258,18 +1349,24 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
   }
 
   if (spvIsVulkanEnv(_.context()->target_env)) {
-    if ((info.format != spv::ImageFormat::R64i) &&
-        (info.format != spv::ImageFormat::R64ui) &&
-        (info.format != spv::ImageFormat::R32f) &&
-        (info.format != spv::ImageFormat::R32i) &&
-        (info.format != spv::ImageFormat::R32ui) &&
-        !((info.format == spv::ImageFormat::Rg16f ||
-           info.format == spv::ImageFormat::Rgba16f) &&
-          _.HasCapability(spv::Capability::AtomicFloat16VectorNV))) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << _.VkErrorID(4658)
+    bool valid_format = info.format == spv::ImageFormat::R64i ||
+                        info.format == spv::ImageFormat::R64ui ||
+                        info.format == spv::ImageFormat::R32f ||
+                        info.format == spv::ImageFormat::R32i ||
+                        info.format == spv::ImageFormat::R32ui;
+    if (!valid_format &&
+        _.HasCapability(spv::Capability::AtomicFloat16VectorNV)) {
+      valid_format = info.format == spv::ImageFormat::Rg16f ||
+                     info.format == spv::ImageFormat::Rgba16f;
+    }
+
+    if (!valid_format) {
+      const uint32_t vuid = isUntyped ? 11416 : 4658;
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << _.VkErrorID(vuid)
              << "Expected the Image Format in Image to be R64i, R64ui, R32f, "
-                "R32i, or R32ui for Vulkan environment";
+                "R32i, or R32ui for Vulkan environment using Op"
+             << spvOpcodeString(inst->opcode());
     }
   }
 
@@ -1330,29 +1427,9 @@ spv_result_t ValidateImageLod(ValidationState_t& _, const Instruction* inst) {
     }
   }
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if ((opcode == spv::Op::OpImageSampleExplicitLod ||
-       opcode == spv::Op::OpImageSparseSampleExplicitLod) &&
-      _.HasCapability(spv::Capability::Kernel)) {
-    if (!_.IsFloatScalarOrVectorType(coord_type) &&
-        !_.IsIntScalarOrVectorType(coord_type)) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Coordinate to be int or float scalar or vector";
-    }
-  } else {
-    if (!_.IsFloatScalarOrVectorType(coord_type)) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Coordinate to be float scalar or vector";
-    }
-  }
-
-  const uint32_t min_coord_size = GetMinCoordSize(opcode, info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 3))
+    return result;
 
   const uint32_t mask = inst->words().size() <= 5 ? 0 : inst->word(5);
 
@@ -1377,7 +1454,7 @@ spv_result_t ValidateImageLod(ValidationState_t& _, const Instruction* inst) {
 spv_result_t ValidateImageDref(ValidationState_t& _, const Instruction* inst,
                                const ImageTypeInfo& info) {
   const uint32_t dref_type = _.GetOperandTypeId(inst, 4);
-  if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) {
+  if (!_.IsFloatScalarType(dref_type, 32)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Dref to be of 32-bit float type";
   }
@@ -1439,19 +1516,9 @@ spv_result_t ValidateImageDrefLod(ValidationState_t& _,
            << GetActualResultTypeStr(opcode);
   }
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if (!_.IsFloatScalarOrVectorType(coord_type)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to be float scalar or vector";
-  }
-
-  const uint32_t min_coord_size = GetMinCoordSize(opcode, info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 3))
+    return result;
 
   if (spv_result_t result = ValidateImageDref(_, inst, info)) return result;
 
@@ -1513,19 +1580,9 @@ spv_result_t ValidateImageFetch(ValidationState_t& _, const Instruction* inst) {
            << "Expected Image 'Sampled' parameter to be 1";
   }
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if (!_.IsIntScalarOrVectorType(coord_type)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to be int scalar or vector";
-  }
-
-  const uint32_t min_coord_size = GetMinCoordSize(opcode, info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 3))
+    return result;
 
   if (spv_result_t result =
           ValidateImageOperands(_, inst, info, /* word_index = */ 6))
@@ -1593,26 +1650,15 @@ spv_result_t ValidateImageGather(ValidationState_t& _,
            << "Expected Image 'Dim' to be 2D, Cube, or Rect";
   }
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if (!_.IsFloatScalarOrVectorType(coord_type)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to be float scalar or vector";
-  }
-
-  const uint32_t min_coord_size = GetMinCoordSize(opcode, info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 3))
+    return result;
 
   if (opcode == spv::Op::OpImageGather ||
       opcode == spv::Op::OpImageSparseGather) {
     const uint32_t component = inst->GetOperandAs<uint32_t>(4);
     const uint32_t component_index_type = _.GetTypeId(component);
-    if (!_.IsIntScalarType(component_index_type) ||
-        _.GetBitWidth(component_index_type) != 32) {
+    if (!_.IsIntScalarType(component_index_type, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Component to be 32-bit int scalar";
     }
@@ -1736,19 +1782,9 @@ spv_result_t ValidateImageRead(ValidationState_t& _, const Instruction* inst) {
   if (spv_result_t result = ValidateImageReadWrite(_, inst, info))
     return result;
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if (!_.IsIntScalarOrVectorType(coord_type)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to be int scalar or vector";
-  }
-
-  const uint32_t min_coord_size = GetMinCoordSize(opcode, info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 3))
+    return result;
 
   if (spvIsVulkanEnv(_.context()->target_env)) {
     if (info.format == spv::ImageFormat::Unknown &&
@@ -1793,19 +1829,9 @@ spv_result_t ValidateImageWrite(ValidationState_t& _, const Instruction* inst) {
   if (spv_result_t result = ValidateImageReadWrite(_, inst, info))
     return result;
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 1);
-  if (!_.IsIntScalarOrVectorType(coord_type)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to be int scalar or vector";
-  }
-
-  const uint32_t min_coord_size = GetMinCoordSize(inst->opcode(), info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 1))
+    return result;
 
   // because it needs to match with 'Sampled Type' the Texel can't be a boolean
   const uint32_t texel_type = _.GetOperandTypeId(inst, 2);
@@ -1933,9 +1959,9 @@ spv_result_t ValidateImageQuerySizeLod(ValidationState_t& _,
   }
 
   const uint32_t lod_type = _.GetOperandTypeId(inst, 3);
-  if (!_.IsIntScalarType(lod_type)) {
+  if (!_.IsIntScalarType(lod_type, 32)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Level of Detail to be int scalar";
+           << "Expected Level of Detail to be a 32-bit int scalar";
   }
   return SPV_SUCCESS;
 }
@@ -2096,27 +2122,9 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
            << "Image 'Dim' must be 1D, 2D, 3D or Cube";
   }
 
-  const uint32_t coord_type = _.GetOperandTypeId(inst, 3);
-  if (_.HasCapability(spv::Capability::Kernel)) {
-    if (!_.IsFloatScalarOrVectorType(coord_type) &&
-        !_.IsIntScalarOrVectorType(coord_type)) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Coordinate to be int or float scalar or vector";
-    }
-  } else {
-    if (!_.IsFloatScalarOrVectorType(coord_type)) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Coordinate to be float scalar or vector";
-    }
-  }
-
-  const uint32_t min_coord_size = GetPlaneCoordSize(info);
-  const uint32_t actual_coord_size = _.GetDimension(coord_type);
-  if (min_coord_size > actual_coord_size) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected Coordinate to have at least " << min_coord_size
-           << " components, but given only " << actual_coord_size;
-  }
+  if (spv_result_t result =
+          ValidateImageCoordinate(_, inst, info, /* word_index = */ 3))
+    return result;
 
   // The operand is a sampled image.
   // The sampled image type is already checked to be parameterized by an image
@@ -2381,6 +2389,7 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpSampledImage:
       return ValidateSampledImage(_, inst);
     case spv::Op::OpImageTexelPointer:
+    case spv::Op::OpUntypedImageTexelPointerEXT:
       return ValidateImageTexelPointer(_, inst);
 
     case spv::Op::OpImageSampleImplicitLod:

+ 7 - 4
3rdparty/spirv-tools/source/val/validate_interfaces.cpp

@@ -155,11 +155,12 @@ spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
       *num_locations = 1;
       break;
     case spv::Op::OpTypeVector:
+    case spv::Op::OpTypeVectorIdEXT:
       // 3- and 4-component 64-bit vectors consume two locations.
       if ((_.ContainsSizedIntOrFloatType(type->id(), spv::Op::OpTypeInt, 64) ||
            _.ContainsSizedIntOrFloatType(type->id(), spv::Op::OpTypeFloat,
                                          64)) &&
-          (type->GetOperandAs<uint32_t>(2) > 2)) {
+          (_.GetDimension(type->id()) > 2)) {
         *num_locations = 2;
       } else {
         *num_locations = 1;
@@ -239,12 +240,13 @@ uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
       }
       break;
     case spv::Op::OpTypeVector:
+    case spv::Op::OpTypeVectorIdEXT:
       // Vectors consume components equal to the underlying type's consumption
       // times the number of elements in the vector. Note that 3- and 4-element
       // vectors cannot have a component decoration (i.e. assumed to be zero).
       num_components =
           NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
-      num_components *= type->GetOperandAs<uint32_t>(2);
+      num_components *= _.GetDimension(type->id());
       break;
     case spv::Op::OpTypeArray:
       // Skip the array.
@@ -615,7 +617,8 @@ spv_result_t ValidateStorageClass(ValidationState_t& _,
     auto storage_class = interface_var->GetOperandAs<spv::StorageClass>(2);
     switch (storage_class) {
       case spv::StorageClass::PushConstant: {
-        if (has_push_constant) {
+        if (has_push_constant &&
+            !(_.HasCapability(spv::Capability::PushConstantBanksNV))) {
           return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
                  << _.VkErrorID(6673)
                  << "Entry-point has more than one variable with the "
@@ -690,7 +693,7 @@ spv_result_t ValidateStorageClass(ValidationState_t& _,
                   return false;
                 })) {
           return _.diag(SPV_ERROR_INVALID_ID, interface_var)
-                 << "FP8 E4M3/E5M2 OpVariable <id> "  // TODO VUID
+                 << _.VkErrorID(10823) << "FP8 E4M3/E5M2 OpVariable <id> "
                  << _.getIdName(interface_var->id()) << " must not be declared "
                  << "with a Storage Class of Input or Output.";
         }

+ 15 - 0
3rdparty/spirv-tools/source/val/validate_invalid_type.cpp

@@ -100,6 +100,21 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpIsInf:
     case spv::Op::OpIsFinite:
     case spv::Op::OpIsNormal:
+    case spv::Op::OpFOrdEqual:
+    case spv::Op::OpFUnordEqual:
+    case spv::Op::OpFOrdNotEqual:
+    case spv::Op::OpFUnordNotEqual:
+    case spv::Op::OpFOrdLessThan:
+    case spv::Op::OpFUnordLessThan:
+    case spv::Op::OpFOrdGreaterThan:
+    case spv::Op::OpFUnordGreaterThan:
+    case spv::Op::OpFOrdLessThanEqual:
+    case spv::Op::OpFUnordLessThanEqual:
+    case spv::Op::OpFOrdGreaterThanEqual:
+    case spv::Op::OpFUnordGreaterThanEqual:
+    case spv::Op::OpLessOrGreater:
+    case spv::Op::OpOrdered:
+    case spv::Op::OpUnordered:
     case spv::Op::OpSignBitSet: {
       const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
       if (_.IsBfloat16Type(operand_type)) {

+ 8 - 0
3rdparty/spirv-tools/source/val/validate_logical_pointers.cpp

@@ -308,6 +308,10 @@ spv_result_t ValidateLogicalPointerOperands(ValidationState_t& _,
     // SPV_ARM_graph
     case spv::Op::OpGraphEntryPointARM:
       return SPV_SUCCESS;
+    // SPV_EXT_descriptor_heap
+    case spv::Op::OpBufferPointerEXT:
+    case spv::Op::OpUntypedImageTexelPointerEXT:
+      return SPV_SUCCESS;
     // The following cases require a variable pointer capability. Since all
     // instructions are for variable pointers, the storage class and capability
     // are also checked.
@@ -371,6 +375,10 @@ spv_result_t ValidateLogicalPointerReturns(ValidationState_t& _,
     // SPV_AMD_shader_enqueue (spec bugs)
     case spv::Op::OpAllocateNodePayloadsAMDX:
       return SPV_SUCCESS;
+    // SPV_EXT_descriptor_heap
+    case spv::Op::OpBufferPointerEXT:
+    case spv::Op::OpUntypedImageTexelPointerEXT:
+      return SPV_SUCCESS;
     // Core spec with variable pointer capability. Check storage classes since
     // variable pointers can only be in certain storage classes.
     case spv::Op::OpSelect:

+ 4 - 0
3rdparty/spirv-tools/source/val/validate_logicals.cpp

@@ -184,6 +184,10 @@ spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst) {
             dimension = type_inst->word(3);
             break;
           }
+          case spv::Op::OpTypeVectorIdEXT: {
+            dimension = _.GetDimension(result_type);
+            break;
+          }
 
           case spv::Op::OpTypeBool:
           case spv::Op::OpTypeInt:

+ 187 - 42
3rdparty/spirv-tools/source/val/validate_memory.cpp

@@ -21,6 +21,7 @@
 
 #include "source/opcode.h"
 #include "source/spirv_target_env.h"
+#include "source/table2.h"
 #include "source/val/instruction.h"
 #include "source/val/validate.h"
 #include "source/val/validate_scopes.h"
@@ -450,7 +451,7 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
       inst->GetOperandAs<spv::StorageClass>(storage_class_index);
   uint32_t value_id = 0;
   if (untyped_pointer) {
-    const auto has_data_type = 3u < inst->operands().size();
+    const bool has_data_type = 3u < inst->operands().size();
     if (has_data_type) {
       value_id = inst->GetOperandAs<uint32_t>(3u);
       auto data_type = _.FindDef(value_id);
@@ -466,10 +467,23 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
                << "Data type must be specified for Function, Private, and "
                   "Workgroup storage classes";
       }
+      // Added from SPV_EXT_descriptor_heap
+      // Vulkan allows untyped pointer without |Data Type| but only for heap
+      // decorated variable that are in UniformConstant
       if (spvIsVulkanEnv(_.context()->target_env)) {
-        return _.diag(SPV_ERROR_INVALID_ID, inst)
-               << _.VkErrorID(11167)
-               << "Vulkan requires that data type be specified";
+        if (storage_class != spv::StorageClass::UniformConstant) {
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
+                 << _.VkErrorID(11167) << "Storage class is "
+                 << StorageClassToString(storage_class)
+                 << ", but Vulkan requires that Data Type be specified when "
+                    "not using UniformConstant storage class";
+        } else if (!(_.IsDescriptorHeapBaseVariable(inst))) {
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
+                 << _.VkErrorID(11347)
+                 << "Storage class is UniformConstant, but Vulkan requires "
+                    "that Data Type be specified if the variable is not "
+                    "decorated with SamplerHeapEXT or ResourceHeapEXT";
+        }
       }
     }
   }
@@ -851,18 +865,72 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
               "parameters";
   }
 
-  if ((storage_class != spv::StorageClass::Function &&
-       storage_class != spv::StorageClass::Private) &&
-      pointee &&
-      _.ContainsType(pointee->id(), [](const Instruction* type_inst) {
-        auto opcode = type_inst->opcode();
-        return opcode == spv::Op::OpTypeCooperativeVectorNV;
-      })) {
-    return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "Cooperative vector types (or types containing them) can only be "
-              "allocated "
-           << "in Function or Private storage classes or as function "
-              "parameters";
+  // Vulkan-specific validation for long vectors
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (_.HasCapability(spv::Capability::LongVectorEXT)) {
+      if ((storage_class != spv::StorageClass::Function &&
+           storage_class != spv::StorageClass::Private &&
+           storage_class != spv::StorageClass::StorageBuffer &&
+           storage_class != spv::StorageClass::PhysicalStorageBuffer &&
+           storage_class != spv::StorageClass::Workgroup &&
+           storage_class != spv::StorageClass::Uniform &&
+           storage_class != spv::StorageClass::PushConstant &&
+           storage_class != spv::StorageClass::ShaderRecordBufferKHR) &&
+          pointee &&
+          _.ContainsType(pointee->id(), [&](const Instruction* type_inst) {
+            auto opcode = type_inst->opcode();
+            if (opcode == spv::Op::OpTypeVector ||
+                opcode == spv::Op::OpTypeVectorIdEXT) {
+              uint32_t dim = _.GetDimension(type_inst->id());
+              return dim > 4;
+            }
+            return false;
+          })) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "Long vector types with more than 4 components (or types "
+                  "containing them) not supported in storage class "
+               << StorageClassToString(storage_class);
+      }
+
+      if (pointee &&
+          (storage_class == spv::StorageClass::StorageBuffer ||
+           storage_class == spv::StorageClass::PhysicalStorageBuffer ||
+           storage_class == spv::StorageClass::Uniform ||
+           storage_class == spv::StorageClass::PushConstant ||
+           storage_class == spv::StorageClass::ShaderRecordBufferKHR ||
+           (storage_class == spv::StorageClass::Workgroup &&
+            _.HasDecoration(pointee->id(), spv::Decoration::Block))) &&
+          _.ContainsType(pointee->id(), [&](const Instruction* type_inst) {
+            auto opcode = type_inst->opcode();
+            if (opcode == spv::Op::OpTypeVectorIdEXT) {
+              auto component_count =
+                  _.FindDef(type_inst->GetOperandAs<uint32_t>(2u));
+              return (bool)spvOpcodeIsSpecConstant(component_count->opcode());
+            }
+            return false;
+          })) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << _.VkErrorID(12294)
+               << "Long vector types with spec constant component count "
+                  "not supported in storage class with explicit layout "
+               << StorageClassToString(storage_class);
+      }
+    } else {
+      if ((storage_class != spv::StorageClass::Function &&
+           storage_class != spv::StorageClass::Private) &&
+          pointee &&
+          _.ContainsType(pointee->id(), [](const Instruction* type_inst) {
+            auto opcode = type_inst->opcode();
+            return opcode == spv::Op::OpTypeVectorIdEXT;
+          })) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "Cooperative vector types (or types containing them) can "
+                  "only be "
+                  "allocated "
+               << "in Function or Private storage classes or as function "
+                  "parameters";
+      }
+    }
   }
 
   if (_.HasCapability(spv::Capability::Shader)) {
@@ -1145,6 +1213,56 @@ spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
 
   _.RegisterQCOMImageProcessingTextureConsumer(pointer_id, inst, nullptr);
 
+  // EXT_descriptor_heap
+  if (spvIsVulkanEnv(_.context()->target_env) &&
+      _.IsDescriptorHeapBaseVariable(_.FindDef(pointer_id))) {
+    auto descBaseVariable = _.FindUntypedBaseVariable(_.FindDef(pointer_id));
+    auto descBaseVariableId = descBaseVariable->id();
+    if (!_.HasDecoration(descBaseVariableId, spv::Decoration::DescriptorSet) &&
+        !_.HasDecoration(descBaseVariableId, spv::Decoration::Binding)) {
+      switch (result_type->opcode()) {
+        case spv::Op::OpTypeSampler:
+          if (!_.IsBuiltin(descBaseVariableId, spv::BuiltIn::SamplerHeapEXT)) {
+            return _.diag(SPV_ERROR_INVALID_ID, inst)
+                   << _.VkErrorID(11336)
+                   << "OpTypeSampler pointer instruction has no descriptor set "
+                   << "or binding and is not derived from a variable decorated "
+                      "with "
+                      "SamplerHeapEXT";
+          }
+          break;
+        case spv::Op::OpTypeImage:
+          if (!_.IsBuiltin(descBaseVariableId, spv::BuiltIn::ResourceHeapEXT)) {
+            return _.diag(SPV_ERROR_INVALID_ID, inst)
+                   << _.VkErrorID(11337)
+                   << "OpTypeImage pointer instruction has no descriptor set "
+                   << "or binding and is not derived from a variable decorated "
+                      "with "
+                      "ResourceHeapEXT";
+          }
+          break;
+        case spv::Op::OpTypeAccelerationStructureKHR:
+          uint32_t data_type;
+          spv::StorageClass sc;
+          if (_.GetPointerTypeInfo(descBaseVariable->type_id(), &data_type,
+                                   &sc) &&
+              sc != spv::StorageClass::Private &&
+              sc != spv::StorageClass::Function &&
+              !_.IsBuiltin(descBaseVariableId, spv::BuiltIn::ResourceHeapEXT)) {
+            return _.diag(SPV_ERROR_INVALID_ID, inst)
+                   << _.VkErrorID(11339)
+                   << "OpTypeAccelerationStructureKHR pointer instruction has "
+                      "no "
+                   << "descriptor set or binding and is not derived from a "
+                      "variable decorated with ResourceHeapEXT";
+          }
+          break;
+        default:
+          break;
+      }
+    }
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -1786,14 +1904,14 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
     switch (type_pointee->opcode()) {
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeVector:
-      case spv::Op::OpTypeCooperativeVectorNV:
+      case spv::Op::OpTypeVectorIdEXT:
       case spv::Op::OpTypeCooperativeMatrixNV:
       case spv::Op::OpTypeCooperativeMatrixKHR:
       case spv::Op::OpTypeArray:
       case spv::Op::OpTypeRuntimeArray:
       case spv::Op::OpTypeNodePayloadArrayAMDX: {
         // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
-        // OpTypeCooperativeVectorNV, OpTypeArray, and OpTypeRuntimeArray, word
+        // OpTypeVectorIdEXT, OpTypeArray, and OpTypeRuntimeArray, word
         // 2 is the Element Type.
         type_pointee = _.FindDef(type_pointee->word(2));
         break;
@@ -2028,10 +2146,11 @@ spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
        base_type_storage_class == spv::StorageClass::PushConstant ||
        (_.HasCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR) &&
         base_type_storage_class == spv::StorageClass::Workgroup)) &&
-      !_.HasDecoration(base_type->id(), spv::Decoration::ArrayStride)) {
+      (!_.HasDecoration(base_type->id(), spv::Decoration::ArrayStride) &&
+       !_.HasDecoration(base_type->id(), spv::Decoration::ArrayStrideIdEXT))) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "OpPtrAccessChain must have a Base whose type is decorated "
-              "with ArrayStride";
+              "with ArrayStride or ArrayStrideIdEXT";
   }
 
   if (spvIsVulkanEnv(_.context()->target_env)) {
@@ -2072,11 +2191,15 @@ spv_result_t ValidateArrayLength(ValidationState_t& state,
   // Result type must be a 32- or 64-bit unsigned int.
   // 64-bit requires CapabilityShader64BitIndexingEXT or a pipeline/shader
   // flag and is validated in VVL.
-  auto result_type = state.FindDef(inst->type_id());
-  if (result_type->opcode() != spv::Op::OpTypeInt ||
-      !(result_type->GetOperandAs<uint32_t>(1) == 32 ||
-        result_type->GetOperandAs<uint32_t>(1) == 64) ||
-      result_type->GetOperandAs<uint32_t>(2) != 0) {
+  const uint32_t result_type_id = inst->type_id();
+  if (!state.IsIntScalarTypeWithSignedness(result_type_id, 0)) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The Result Type of Op" << spvOpcodeString(opcode) << " <id> "
+           << state.getIdName(inst->id())
+           << " must be OpTypeInt with width 32 or 64 and signedness 0.";
+  }
+  const uint32_t result_type_width = state.GetBitWidth(inst->type_id());
+  if (result_type_width != 32 && result_type_width != 64) {
     return state.diag(SPV_ERROR_INVALID_ID, inst)
            << "The Result Type of Op" << spvOpcodeString(opcode) << " <id> "
            << state.getIdName(inst->id())
@@ -2087,9 +2210,10 @@ spv_result_t ValidateArrayLength(ValidationState_t& state,
   auto pointer_ty_id = state.GetOperandTypeId(inst, (untyped ? 3 : 2));
   auto pointer_ty = state.FindDef(pointer_ty_id);
   if (untyped) {
-    if (pointer_ty->opcode() != spv::Op::OpTypeUntypedPointerKHR) {
+    if (!pointer_ty ||
+        pointer_ty->opcode() != spv::Op::OpTypeUntypedPointerKHR) {
       return state.diag(SPV_ERROR_INVALID_ID, inst)
-             << "Pointer must be an untyped pointer";
+             << "Pointer must be an untyped pointer object";
     }
   } else if (pointer_ty->opcode() != spv::Op::OpTypePointer) {
     return state.diag(SPV_ERROR_INVALID_ID, inst)
@@ -2150,10 +2274,9 @@ spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
                                                const Instruction* inst) {
   const spv::Op opcode = inst->opcode();
   // Result type must be a 32-bit unsigned int.
-  auto result_type = state.FindDef(inst->type_id());
-  if (result_type->opcode() != spv::Op::OpTypeInt ||
-      result_type->GetOperandAs<uint32_t>(1) != 32 ||
-      result_type->GetOperandAs<uint32_t>(2) != 0) {
+  const uint32_t result_type_id = inst->type_id();
+  if (!state.IsIntScalarTypeWithSignedness(result_type_id, 0) ||
+      state.GetBitWidth(inst->type_id()) != 32) {
     return state.diag(SPV_ERROR_INVALID_ID, inst)
            << "The Result Type of Op" << spvOpcodeString(opcode) << " <id> "
            << state.getIdName(inst->id())
@@ -2478,6 +2601,27 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateBufferPointerEXT(ValidationState_t& _,
+                                      const Instruction* inst) {
+  const auto storage_class_ptr = _.FindDef(inst->GetOperandAs<uint32_t>(0));
+  if (storage_class_ptr->opcode() != spv::Op::OpTypeUntypedPointerKHR &&
+      storage_class_ptr->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpBufferPointerEXT's Result Type should be "
+           << "a pointer type.";
+  } else {
+    // Buffer operand
+    auto buffer =
+        _.FindUntypedBaseVariable(_.FindDef(inst->GetOperandAs<uint32_t>(2)));
+    if (!_.IsBuiltin(buffer->id(), spv::BuiltIn::ResourceHeapEXT)) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "OpBufferPointerEXT's buffer must be an untyped pointer"
+             << " into a variable declared with the ResourceHeapEXT built-in";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
 // Returns the number of instruction words taken up by a tensor addressing
 // operands argument and its implied operands.
 int TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask) {
@@ -2700,7 +2844,7 @@ spv_result_t ValidateInt32Operand(ValidationState_t& _, const Instruction* inst,
                                   const char* operand_name) {
   const auto type_id =
       _.FindDef(inst->GetOperandAs<uint32_t>(operand_index))->type_id();
-  if (!_.IsIntScalarType(type_id) || _.GetBitWidth(type_id) != 32) {
+  if (!_.IsIntScalarType(type_id, 32)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << opcode_name << " " << operand_name << " type <id> "
            << _.getIdName(type_id) << " is not a 32 bit integer.";
@@ -2800,7 +2944,7 @@ spv_result_t ValidateCooperativeVectorLoadStoreNV(ValidationState_t& _,
 
   auto vector_type = _.FindDef(type_id);
 
-  if (vector_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
+  if (vector_type->opcode() != spv::Op::OpTypeVectorIdEXT) {
     if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "spv::Op::OpCooperativeVectorLoadNV Result Type <id> "
@@ -2852,7 +2996,7 @@ spv_result_t ValidateCooperativeVectorOuterProductNV(ValidationState_t& _,
   auto type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
   auto a_type = _.FindDef(type_id);
 
-  if (a_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
+  if (a_type->opcode() != spv::Op::OpTypeVectorIdEXT) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << opcode_name << " A type <id> " << _.getIdName(type_id)
            << " is not a cooperative vector type.";
@@ -2861,7 +3005,7 @@ spv_result_t ValidateCooperativeVectorOuterProductNV(ValidationState_t& _,
   type_id = _.FindDef(inst->GetOperandAs<uint32_t>(3))->type_id();
   auto b_type = _.FindDef(type_id);
 
-  if (b_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
+  if (b_type->opcode() != spv::Op::OpTypeVectorIdEXT) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << opcode_name << " B type <id> " << _.getIdName(type_id)
            << " is not a cooperative vector type.";
@@ -2915,7 +3059,7 @@ spv_result_t ValidateCooperativeVectorReduceSumNV(ValidationState_t& _,
   auto type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
   auto v_type = _.FindDef(type_id);
 
-  if (v_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
+  if (v_type->opcode() != spv::Op::OpTypeVectorIdEXT) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << opcode_name << " V type <id> " << _.getIdName(type_id)
            << " is not a cooperative vector type.";
@@ -2993,18 +3137,16 @@ spv_result_t ValidateCooperativeVectorMatrixMulNV(ValidationState_t& _,
 
   const auto result_type = _.FindDef(result_type_id);
 
-  if (result_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
+  if (result_type->opcode() != spv::Op::OpTypeVectorIdEXT) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << opcode_name << " result type <id> " << _.getIdName(result_type_id)
            << " is not a cooperative vector type.";
   }
 
   const auto result_component_type_id = result_type->GetOperandAs<uint32_t>(1u);
-  if (!(_.IsIntScalarType(result_component_type_id) &&
-        _.GetBitWidth(result_component_type_id) == 32) &&
-      !(_.IsFloatScalarType(result_component_type_id) &&
-        (_.GetBitWidth(result_component_type_id) == 32 ||
-         _.GetBitWidth(result_component_type_id) == 16))) {
+  if (!_.IsIntScalarType(result_component_type_id, 32) &&
+      !_.IsFloatScalarType(result_component_type_id, 32) &&
+      !_.IsFloatScalarType(result_component_type_id, 16)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << opcode_name << " result component type <id> "
            << _.getIdName(result_component_type_id)
@@ -3212,6 +3354,9 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpUntypedVariableKHR:
       if (auto error = ValidateVariable(_, inst)) return error;
       break;
+    case spv::Op::OpBufferPointerEXT:
+      if (auto error = ValidateBufferPointerEXT(_, inst)) return error;
+      break;
     case spv::Op::OpLoad:
       if (auto error = ValidateLoad(_, inst)) return error;
       break;

+ 10 - 13
3rdparty/spirv-tools/source/val/validate_ray_query.cpp

@@ -66,8 +66,7 @@ spv_result_t ValidateIntersectionId(ValidationState_t& _,
       inst->GetOperandAs<uint32_t>(intersection_index);
   const uint32_t intersection_type = _.GetTypeId(intersection_id);
   const spv::Op intersection_opcode = _.GetIdOpcode(intersection_id);
-  if (!_.IsIntScalarType(intersection_type) ||
-      _.GetBitWidth(intersection_type) != 32 ||
+  if (!_.IsIntScalarType(intersection_type, 32) ||
       !spvOpcodeIsConstant(intersection_opcode)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "expected Intersection ID to be a constant 32-bit int scalar";
@@ -94,13 +93,13 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
-      if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
+      if (!_.IsIntScalarType(ray_flags, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray Flags must be a 32-bit int scalar";
       }
 
       const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
-      if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
+      if (!_.IsIntScalarType(cull_mask, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Cull Mask must be a 32-bit int scalar";
       }
@@ -113,7 +112,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
-      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+      if (!_.IsFloatScalarType(ray_tmin, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMin must be a 32-bit float scalar";
       }
@@ -127,7 +126,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
-      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+      if (!_.IsFloatScalarType(ray_tmax, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMax must be a 32-bit float scalar";
       }
@@ -144,7 +143,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
 
       const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
-      if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
+      if (!_.IsFloatScalarType(hit_t_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Hit T must be a 32-bit float scalar";
       }
@@ -173,8 +172,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpRayQueryGetRayTMinKHR: {
       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
 
-      if (!_.IsFloatScalarType(result_type) ||
-          _.GetBitWidth(result_type) != 32) {
+      if (!_.IsFloatScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "expected Result Type to be 32-bit float scalar type";
       }
@@ -196,7 +194,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpRayQueryGetRayFlagsKHR: {
       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
 
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+      if (!_.IsIntScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "expected Result Type to be 32-bit int scalar type";
       }
@@ -278,7 +276,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
 
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+      if (!_.IsIntScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "expected Result Type to be 32-bit int scalar type";
       }
@@ -335,8 +333,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
 
-      if (!_.IsFloatScalarType(result_type) ||
-          _.GetBitWidth(result_type) != 32) {
+      if (!_.IsFloatScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "expected Result Type to be 32-bit floating point "
                   "scalar type";

+ 8 - 8
3rdparty/spirv-tools/source/val/validate_ray_tracing.cpp

@@ -52,31 +52,31 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_flags = _.GetOperandTypeId(inst, 1);
-      if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
+      if (!_.IsIntScalarType(ray_flags, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray Flags must be a 32-bit int scalar";
       }
 
       const uint32_t cull_mask = _.GetOperandTypeId(inst, 2);
-      if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
+      if (!_.IsIntScalarType(cull_mask, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Cull Mask must be a 32-bit int scalar";
       }
 
       const uint32_t sbt_offset = _.GetOperandTypeId(inst, 3);
-      if (!_.IsIntScalarType(sbt_offset) || _.GetBitWidth(sbt_offset) != 32) {
+      if (!_.IsIntScalarType(sbt_offset, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "SBT Offset must be a 32-bit int scalar";
       }
 
       const uint32_t sbt_stride = _.GetOperandTypeId(inst, 4);
-      if (!_.IsIntScalarType(sbt_stride) || _.GetBitWidth(sbt_stride) != 32) {
+      if (!_.IsIntScalarType(sbt_stride, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "SBT Stride must be a 32-bit int scalar";
       }
 
       const uint32_t miss_index = _.GetOperandTypeId(inst, 5);
-      if (!_.IsIntScalarType(miss_index) || _.GetBitWidth(miss_index) != 32) {
+      if (!_.IsIntScalarType(miss_index, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Miss Index must be a 32-bit int scalar";
       }
@@ -89,7 +89,7 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7);
-      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+      if (!_.IsFloatScalarType(ray_tmin, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMin must be a 32-bit float scalar";
       }
@@ -103,7 +103,7 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9);
-      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+      if (!_.IsFloatScalarType(ray_tmax, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMax must be a 32-bit float scalar";
       }
@@ -144,7 +144,7 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t hit = _.GetOperandTypeId(inst, 2);
-      if (!_.IsFloatScalarType(hit) || _.GetBitWidth(hit) != 32) {
+      if (!_.IsFloatScalarType(hit, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Hit must be a 32-bit int scalar";
       }

+ 40 - 50
3rdparty/spirv-tools/source/val/validate_ray_tracing_reorder.cpp

@@ -113,7 +113,7 @@ spv_result_t ValidateHitObjectInstructionCommonParameters(
 
   if (isValidId(instance_id_index)) {
     const uint32_t instance_id = _.GetOperandTypeId(inst, instance_id_index);
-    if (!_.IsIntScalarType(instance_id) || _.GetBitWidth(instance_id) != 32) {
+    if (!_.IsIntScalarType(instance_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Instance Id must be a 32-bit int scalar";
     }
@@ -121,7 +121,7 @@ spv_result_t ValidateHitObjectInstructionCommonParameters(
 
   if (isValidId(primtive_id_index)) {
     const uint32_t primitive_id = _.GetOperandTypeId(inst, primtive_id_index);
-    if (!_.IsIntScalarType(primitive_id) || _.GetBitWidth(primitive_id) != 32) {
+    if (!_.IsIntScalarType(primitive_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Primitive Id must be a 32-bit int scalar";
     }
@@ -129,8 +129,7 @@ spv_result_t ValidateHitObjectInstructionCommonParameters(
 
   if (isValidId(geometry_index)) {
     const uint32_t geometry_index_id = _.GetOperandTypeId(inst, geometry_index);
-    if (!_.IsIntScalarType(geometry_index_id) ||
-        _.GetBitWidth(geometry_index_id) != 32) {
+    if (!_.IsIntScalarType(geometry_index_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Geometry Index must be a 32-bit int scalar";
     }
@@ -214,7 +213,7 @@ spv_result_t ValidateHitObjectInstructionCommonParameters(
 
   if (isValidId(ray_tmin_index)) {
     const uint32_t ray_tmin_id = _.GetOperandTypeId(inst, ray_tmin_index);
-    if (!_.IsFloatScalarType(ray_tmin_id) || _.GetBitWidth(ray_tmin_id) != 32) {
+    if (!_.IsFloatScalarType(ray_tmin_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Ray TMin must be a 32-bit float scalar";
     }
@@ -233,7 +232,7 @@ spv_result_t ValidateHitObjectInstructionCommonParameters(
 
   if (isValidId(ray_tmax_index)) {
     const uint32_t ray_tmax_id = _.GetOperandTypeId(inst, ray_tmax_index);
-    if (!_.IsFloatScalarType(ray_tmax_id) || _.GetBitWidth(ray_tmax_id) != 32) {
+    if (!_.IsFloatScalarType(ray_tmax_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Ray TMax must be a 32-bit float scalar";
     }
@@ -241,7 +240,7 @@ spv_result_t ValidateHitObjectInstructionCommonParameters(
 
   if (isValidId(ray_flags_index)) {
     const uint32_t ray_flags_id = _.GetOperandTypeId(inst, ray_flags_index);
-    if (!_.IsIntScalarType(ray_flags_id) || _.GetBitWidth(ray_flags_id) != 32) {
+    if (!_.IsIntScalarType(ray_flags_id, 32)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Ray Flags must be a 32-bit int scalar";
     }
@@ -352,7 +351,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
       RegisterOpcodeForValidModel(_, inst);
       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
 
-      if (!_.IsIntScalarType(result_type) || !_.GetBitWidth(result_type))
+      if (!_.IsIntScalarType(result_type, 32))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected 32-bit integer type scalar as Result Type: "
                << spvOpcodeString(opcode);
@@ -365,7 +364,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
       RegisterOpcodeForValidModel(_, inst);
       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
 
-      if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32)
+      if (!_.IsFloatScalarType(result_type, 32))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected 32-bit floating-point type scalar as Result Type: "
                << spvOpcodeString(opcode);
@@ -481,7 +480,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 3);
-      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+      if (!_.IsFloatScalarType(ray_tmin, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMin must be a 32-bit float scalar";
       }
@@ -495,7 +494,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
       }
 
       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 5);
-      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+      if (!_.IsFloatScalarType(ray_tmax, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMax must be a 32-bit float scalar";
       }
@@ -563,8 +562,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
         return error;
       // Current Time
       const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
-      if (!_.IsFloatScalarType(current_time_id) ||
-          _.GetBitWidth(current_time_id) != 32) {
+      if (!_.IsFloatScalarType(current_time_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Current Times must be a 32-bit float scalar type";
       }
@@ -618,12 +616,12 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
 
         // Validate the optional opreands Hint and Bits
         const uint32_t hint_id = _.GetOperandTypeId(inst, 1);
-        if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+        if (!_.IsIntScalarType(hint_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Hint must be a 32-bit int scalar";
         }
         const uint32_t bits_id = _.GetOperandTypeId(inst, 2);
-        if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+        if (!_.IsIntScalarType(bits_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "bits must be a 32-bit int scalar";
         }
@@ -647,13 +645,13 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
               });
 
       const uint32_t hint_id = _.GetOperandTypeId(inst, 0);
-      if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+      if (!_.IsIntScalarType(hint_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Hint must be a 32-bit int scalar";
       }
 
       const uint32_t bits_id = _.GetOperandTypeId(inst, 1);
-      if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+      if (!_.IsIntScalarType(bits_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "bits must be a 32-bit int scalar";
       }
@@ -664,7 +662,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
       RegisterOpcodeForValidModel(_, inst);
       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
 
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32)
+      if (!_.IsIntScalarType(result_type, 32))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected 32-bit integer type scalar as Result Type: "
                << spvOpcodeString(opcode);
@@ -690,8 +688,7 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
       RegisterOpcodeForValidModel(_, inst);
       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
 
-      if (!_.IsFloatScalarType(result_type) ||
-          _.GetBitWidth(result_type) != 32) {
+      if (!_.IsFloatScalarType(result_type, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected 32-bit floating point scalar as Result Type: "
                << spvOpcodeString(opcode);
@@ -824,7 +821,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
       RegisterOpcodeForValidModel(_, inst);
       if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
 
-      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32)
+      if (!_.IsIntScalarType(result_type, 32))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected 32-bit integer type scalar as Result Type: "
                << spvOpcodeString(opcode);
@@ -837,7 +834,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
       RegisterOpcodeForValidModel(_, inst);
       if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
 
-      if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32)
+      if (!_.IsFloatScalarType(result_type, 32))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected 32-bit floating-point type scalar as Result Type: "
                << spvOpcodeString(opcode);
@@ -934,8 +931,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
       if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
 
       const uint32_t sbt_index_id = _.GetOperandTypeId(inst, 1);
-      if (!_.IsIntScalarType(sbt_index_id) ||
-          _.GetBitWidth(sbt_index_id) != 32) {
+      if (!_.IsIntScalarType(sbt_index_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "SBT Index must be a 32-bit integer scalar";
       }
@@ -979,8 +975,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Validate SBT Record Index (operand 2)
       const uint32_t sbt_record_index_id = _.GetOperandTypeId(inst, 2);
-      if (!_.IsIntScalarType(sbt_record_index_id) ||
-          _.GetBitWidth(sbt_record_index_id) != 32) {
+      if (!_.IsIntScalarType(sbt_record_index_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "SBT Record Index must be a 32-bit integer scalar";
       }
@@ -1005,8 +1000,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Ray Flags (operand 1)
       const uint32_t ray_flags_id = _.GetOperandTypeId(inst, 1);
-      if (!_.IsIntScalarType(ray_flags_id) ||
-          _.GetBitWidth(ray_flags_id) != 32) {
+      if (!_.IsIntScalarType(ray_flags_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray Flags must be a 32-bit int scalar";
       }
@@ -1029,7 +1023,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Ray TMin (operand 4)
       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 4);
-      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+      if (!_.IsFloatScalarType(ray_tmin, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMin must be a 32-bit float scalar";
       }
@@ -1045,7 +1039,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Ray TMax (operand 6)
       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 6);
-      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+      if (!_.IsFloatScalarType(ray_tmax, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMax must be a 32-bit float scalar";
       }
@@ -1058,8 +1052,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Ray Flags (operand 1)
       const uint32_t ray_flags_id = _.GetOperandTypeId(inst, 1);
-      if (!_.IsIntScalarType(ray_flags_id) ||
-          _.GetBitWidth(ray_flags_id) != 32) {
+      if (!_.IsIntScalarType(ray_flags_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray Flags must be a 32-bit int scalar";
       }
@@ -1082,7 +1075,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Ray TMin (operand 4)
       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 4);
-      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+      if (!_.IsFloatScalarType(ray_tmin, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMin must be a 32-bit float scalar";
       }
@@ -1098,15 +1091,14 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Ray TMax (operand 6)
       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 6);
-      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+      if (!_.IsFloatScalarType(ray_tmax, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Ray TMax must be a 32-bit float scalar";
       }
 
       // Current Time (operand 7)
       const uint32_t current_time_id = _.GetOperandTypeId(inst, 7);
-      if (!_.IsFloatScalarType(current_time_id) ||
-          _.GetBitWidth(current_time_id) != 32) {
+      if (!_.IsFloatScalarType(current_time_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Current Time must be a 32-bit float scalar";
       }
@@ -1129,13 +1121,13 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
               });
 
       const uint32_t hint_id = _.GetOperandTypeId(inst, 0);
-      if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+      if (!_.IsIntScalarType(hint_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Hint must be a 32-bit int scalar";
       }
 
       const uint32_t bits_id = _.GetOperandTypeId(inst, 1);
-      if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+      if (!_.IsIntScalarType(bits_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Bits must be a 32-bit int scalar";
       }
@@ -1168,12 +1160,12 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
         // Validate the optional operands Hint and Bits
         const uint32_t hint_id = _.GetOperandTypeId(inst, 1);
-        if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+        if (!_.IsIntScalarType(hint_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Hint must be a 32-bit int scalar";
         }
         const uint32_t bits_id = _.GetOperandTypeId(inst, 2);
-        if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+        if (!_.IsIntScalarType(bits_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Bits must be a 32-bit int scalar";
         }
@@ -1221,8 +1213,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Current Time (operand 11)
       const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
-      if (!_.IsFloatScalarType(current_time_id) ||
-          _.GetBitWidth(current_time_id) != 32) {
+      if (!_.IsFloatScalarType(current_time_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Current Time must be a 32-bit float scalar";
       }
@@ -1270,12 +1261,12 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
         // Validate optional Hint and Bits
         const uint32_t hint_id = _.GetOperandTypeId(inst, 2);
-        if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+        if (!_.IsIntScalarType(hint_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Hint must be a 32-bit int scalar";
         }
         const uint32_t bits_id = _.GetOperandTypeId(inst, 3);
-        if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+        if (!_.IsIntScalarType(bits_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Bits must be a 32-bit int scalar";
         }
@@ -1325,12 +1316,12 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
         // Validate optional Hint and Bits
         const uint32_t hint_id = _.GetOperandTypeId(inst, 12);
-        if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+        if (!_.IsIntScalarType(hint_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Hint must be a 32-bit int scalar";
         }
         const uint32_t bits_id = _.GetOperandTypeId(inst, 13);
-        if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+        if (!_.IsIntScalarType(bits_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Bits must be a 32-bit int scalar";
         }
@@ -1372,8 +1363,7 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
       // Current Time (operand 11)
       const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
-      if (!_.IsFloatScalarType(current_time_id) ||
-          _.GetBitWidth(current_time_id) != 32) {
+      if (!_.IsFloatScalarType(current_time_id, 32)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Current Time must be a 32-bit float scalar";
       }
@@ -1388,12 +1378,12 @@ spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
 
         // Validate optional Hint and Bits
         const uint32_t hint_id = _.GetOperandTypeId(inst, 13);
-        if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+        if (!_.IsIntScalarType(hint_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Hint must be a 32-bit int scalar";
         }
         const uint32_t bits_id = _.GetOperandTypeId(inst, 14);
-        if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+        if (!_.IsIntScalarType(bits_id, 32)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Bits must be a 32-bit int scalar";
         }

+ 1 - 2
3rdparty/spirv-tools/source/val/validate_tensor_layout.cpp

@@ -129,8 +129,7 @@ spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _,
   for (uint32_t i = 0; i < num_values; ++i) {
     const auto val_id = inst->GetOperandAs<uint32_t>(i + 3);
     const auto val = _.FindDef(val_id);
-    if (!val || !_.IsIntScalarType(val->type_id()) ||
-        _.GetBitWidth(val->type_id()) != 32) {
+    if (!val || !_.IsIntScalarType(val->type_id(), 32)) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << spvOpcodeString(inst->opcode()) << " operand <id> "
              << _.getIdName(val_id) << " is not a 32-bit integer.";

+ 87 - 51
3rdparty/spirv-tools/source/val/validate_type.cpp

@@ -200,6 +200,9 @@ spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
   auto num_components = inst->GetOperandAs<const uint32_t>(2);
   if (num_components == 2 || num_components == 3 || num_components == 4) {
     return SPV_SUCCESS;
+  } else if (num_components > 0 &&
+             _.HasCapability(spv::Capability::LongVectorEXT)) {
+    return SPV_SUCCESS;
   } else if (num_components == 8 || num_components == 16) {
     if (_.HasCapability(spv::Capability::Vector16)) {
       return SPV_SUCCESS;
@@ -217,15 +220,16 @@ spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
   return SPV_SUCCESS;
 }
 
-spv_result_t ValidateTypeCooperativeVectorNV(ValidationState_t& _,
-                                             const Instruction* inst) {
+spv_result_t ValidateTypeVectorIdEXT(ValidationState_t& _,
+                                     const Instruction* inst) {
   const auto component_index = 1;
   const auto component_type_id = inst->GetOperandAs<uint32_t>(component_index);
   const auto component_type = _.FindDef(component_type_id);
-  if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
-                          spv::Op::OpTypeInt != component_type->opcode())) {
+  if (!component_type || !_.IsScalarType(component_type_id) ||
+      (!_.HasCapability(spv::Capability::LongVectorEXT) &&
+       spv::Op::OpTypeBool == component_type->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeVectorNV Component Type <id> "
+           << "OpTypeVectorIdEXT Component Type <id> "
            << _.getIdName(component_type_id)
            << " is not a scalar numerical type.";
   }
@@ -236,32 +240,25 @@ spv_result_t ValidateTypeCooperativeVectorNV(ValidationState_t& _,
   const auto num_components = _.FindDef(num_components_id);
   if (!num_components || !spvOpcodeIsConstant(num_components->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeVectorNV component count <id> "
+           << "OpTypeVectorIdEXT component count <id> "
            << _.getIdName(num_components_id)
            << " is not a scalar constant type.";
   }
 
-  // NOTE: Check the initialiser value of the constant
-  const auto const_inst = num_components->words();
-  const auto const_result_type_index = 1;
-  const auto const_result_type = _.FindDef(const_inst[const_result_type_index]);
-  if (!const_result_type || spv::Op::OpTypeInt != const_result_type->opcode()) {
+  if (!_.IsIntScalarType(num_components->type_id(), 32)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeVectorNV component count <id> "
-           << _.getIdName(num_components_id)
-           << " is not a constant integer type.";
+           << "OpTypeVectorIdEXT component count type <id> "
+           << _.getIdName(num_components->type_id())
+           << " is not a 32-bit integer type.";
   }
 
-  int64_t num_components_value;
-  if (_.EvalConstantValInt64(num_components_id, &num_components_value)) {
-    auto& type_words = const_result_type->words();
-    const bool is_signed = type_words[3] > 0;
-    if (num_components_value == 0 || (num_components_value < 0 && is_signed)) {
+  uint64_t num_components_value;
+  if (_.EvalConstantValUint64(num_components_id, &num_components_value)) {
+    if (num_components_value == 0) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
-             << "OpTypeCooperativeVectorNV component count <id> "
+             << "OpTypeVectorIdEXT component count <id> "
              << _.getIdName(num_components_id)
-             << " default value must be at least 1: found "
-             << num_components_value;
+             << " default value must be at least 1: found 0.";
     }
   }
 
@@ -318,10 +315,11 @@ spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) {
     if (element_type->opcode() == spv::Op::OpTypeStruct &&
         (_.HasDecoration(element_type->id(), spv::Decoration::Block) ||
          _.HasDecoration(element_type->id(), spv::Decoration::BufferBlock))) {
-      if (_.HasDecoration(inst->id(), spv::Decoration::ArrayStride)) {
+      if (_.HasDecoration(inst->id(), spv::Decoration::ArrayStride) ||
+          _.HasDecoration(inst->id(), spv::Decoration::ArrayStrideIdEXT)) {
         return _.diag(SPV_ERROR_INVALID_ID, inst)
                << "Array containing a Block or BufferBlock must not be "
-                  "decorated with ArrayStride";
+                  "decorated with ArrayStride or ArrayStrideIdEXT";
       }
     }
   }
@@ -388,10 +386,11 @@ spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _,
     if (element_type->opcode() == spv::Op::OpTypeStruct &&
         (_.HasDecoration(element_type->id(), spv::Decoration::Block) ||
          _.HasDecoration(element_type->id(), spv::Decoration::BufferBlock))) {
-      if (_.HasDecoration(inst->id(), spv::Decoration::ArrayStride)) {
+      if (_.HasDecoration(inst->id(), spv::Decoration::ArrayStride) ||
+          _.HasDecoration(inst->id(), spv::Decoration::ArrayStrideIdEXT)) {
         return _.diag(SPV_ERROR_INVALID_ID, inst)
                << "Array containing a Block or BufferBlock must not be "
-                  "decorated with ArrayStride";
+                  "decorated with ArrayStride or ArrayStrideIdEXT";
       }
     }
   }
@@ -495,7 +494,9 @@ spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
   std::unordered_set<uint32_t> built_in_members;
   for (auto decoration : _.id_decorations(struct_id)) {
     if (decoration.dec_type() == spv::Decoration::BuiltIn &&
-        decoration.struct_member_index() != Decoration::kInvalidMember) {
+        decoration.struct_member_index() != Decoration::kInvalidMember &&
+        decoration.builtin() != spv::BuiltIn::ResourceHeapEXT &&
+        decoration.builtin() != spv::BuiltIn::SamplerHeapEXT) {
       built_in_members.insert(decoration.struct_member_index());
     }
   }
@@ -513,25 +514,32 @@ spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
     _.RegisterStructTypeWithBuiltInMember(struct_id);
   }
 
-  const auto isOpaqueType = [&_](const Instruction* opaque_inst) {
-    auto opcode = opaque_inst->opcode();
-    if (_.HasCapability(spv::Capability::BindlessTextureNV) &&
-        (opcode == spv::Op::OpTypeImage || opcode == spv::Op::OpTypeSampler ||
-         opcode == spv::Op::OpTypeSampledImage)) {
-      return false;
-    } else if (spvOpcodeIsBaseOpaqueType(opcode)) {
-      return true;
-    }
-    return false;
-  };
-
   if (spvIsVulkanEnv(_.context()->target_env) &&
-      !_.options()->before_hlsl_legalization &&
-      _.ContainsType(inst->id(), isOpaqueType)) {
-    return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << _.VkErrorID(4667) << "In "
-           << spvLogStringForEnv(_.context()->target_env)
-           << ", OpTypeStruct must not contain an opaque type.";
+      !_.options()->before_hlsl_legalization) {
+    // By default, without extensions, all opaque types are invalid in a struct.
+    // Check the exceptions allowed by the various capabilities
+    const auto IsInvalidOpaqueType = [&_](const Instruction* opaque_inst) {
+      const spv::Op opcode = opaque_inst->opcode();
+      if (_.HasCapability(spv::Capability::DescriptorHeapEXT) &&
+          _.IsDescriptorType(opcode)) {
+        return false;
+      } else if (_.HasCapability(spv::Capability::BindlessTextureNV) &&
+                 (opcode == spv::Op::OpTypeImage ||
+                  opcode == spv::Op::OpTypeSampler ||
+                  opcode == spv::Op::OpTypeSampledImage)) {
+        return false;
+      }
+      return spvOpcodeIsBaseOpaqueType(opcode);
+    };
+
+    if (_.ContainsType(inst->id(), IsInvalidOpaqueType)) {
+      const uint32_t vuid =
+          _.HasCapability(spv::Capability::DescriptorHeapEXT) ? 11482 : 4667;
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << _.VkErrorID(vuid) << "In "
+             << spvLogStringForEnv(_.context()->target_env)
+             << ", OpTypeStruct must not contain an invalid opaque type.";
+    }
   }
 
   return SPV_SUCCESS;
@@ -797,8 +805,23 @@ spv_result_t ValidateTypeUntypedPointerKHR(ValidationState_t& _,
       case spv::StorageClass::Uniform:
       case spv::StorageClass::PushConstant:
         break;
+      case spv::StorageClass::UniformConstant:
+        if (!_.HasCapability(spv::Capability::DescriptorHeapEXT)) {
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
+                 << "UniformConstant storage class untyped pointers in Vulkan "
+                    "require DescriptorHeapEXT be declared";
+        }
+        break;
+      case spv::StorageClass::Image:
+        if (!_.HasCapability(spv::Capability::DescriptorHeapEXT)) {
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
+                 << "Image storage class untyped pointers in Vulkan "
+                    "require DescriptorHeapEXT be declared";
+        }
+        break;
       default:
         return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << _.VkErrorID(11417)
                << "In Vulkan, untyped pointers can only be used in an "
                   "explicitly laid out storage class";
     }
@@ -810,8 +833,7 @@ spv_result_t ValidateTensorDim(ValidationState_t& _, const Instruction* inst) {
   const auto dim_index = 1;
   const auto dim_id = inst->GetOperandAs<uint32_t>(dim_index);
   const auto dim = _.FindDef(dim_id);
-  if (!dim || !_.IsIntScalarType(dim->type_id()) ||
-      _.GetBitWidth(dim->type_id()) != 32) {
+  if (!dim || !_.IsIntScalarType(dim->type_id(), 32)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
            << spvOpcodeString(inst->opcode()) << " Dim <id> "
            << _.getIdName(dim_id) << " is not a 32-bit integer.";
@@ -878,8 +900,7 @@ spv_result_t ValidateTypeTensorViewNV(ValidationState_t& _,
   for (size_t p_index = 3; p_index < inst->operands().size(); ++p_index) {
     auto p_id = inst->GetOperandAs<uint32_t>(p_index);
     const auto p = _.FindDef(p_id);
-    if (!p || !_.IsIntScalarType(p->type_id()) ||
-        _.GetBitWidth(p->type_id()) != 32) {
+    if (!p || !_.IsIntScalarType(p->type_id(), 32)) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << spvOpcodeString(inst->opcode()) << " Permutation <id> "
              << _.getIdName(p_id) << " is not a 32-bit integer.";
@@ -990,6 +1011,18 @@ spv_result_t ValidateTypeTensorARM(ValidationState_t& _,
 
   return SPV_SUCCESS;
 }
+
+spv_result_t ValidateTypeBufferEXT(ValidationState_t& _,
+                                   const Instruction* inst) {
+  auto sc = inst->GetOperandAs<spv::StorageClass>(1);
+  if (sc != spv::StorageClass::Uniform &&
+      sc != spv::StorageClass::StorageBuffer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << spvOpcodeString(inst->opcode())
+           << " StorageClass could only be StorageBuffer or Uniform.";
+  }
+  return SPV_SUCCESS;
+}
 }  // namespace
 
 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
@@ -1035,8 +1068,8 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpTypeCooperativeMatrixKHR:
       if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error;
       break;
-    case spv::Op::OpTypeCooperativeVectorNV:
-      if (auto error = ValidateTypeCooperativeVectorNV(_, inst)) return error;
+    case spv::Op::OpTypeVectorIdEXT:
+      if (auto error = ValidateTypeVectorIdEXT(_, inst)) return error;
       break;
     case spv::Op::OpTypeUntypedPointerKHR:
       if (auto error = ValidateTypeUntypedPointerKHR(_, inst)) return error;
@@ -1050,6 +1083,9 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpTypeTensorARM:
       if (auto error = ValidateTypeTensorARM(_, inst)) return error;
       break;
+    case spv::Op::OpTypeBufferEXT:
+      if (auto error = ValidateTypeBufferEXT(_, inst)) return error;
+      break;
     default:
       break;
   }

+ 155 - 21
3rdparty/spirv-tools/source/val/validation_state.cpp

@@ -17,6 +17,7 @@
 #include "source/val/validation_state.h"
 
 #include <cassert>
+#include <cstdint>
 #include <stack>
 #include <utility>
 
@@ -69,6 +70,7 @@ ModuleLayoutSection InstructionLayoutSection(
       return kLayoutDebug3;
     case spv::Op::OpDecorate:
     case spv::Op::OpMemberDecorate:
+    case spv::Op::OpMemberDecorateIdEXT:
     case spv::Op::OpGroupDecorate:
     case spv::Op::OpGroupMemberDecorate:
     case spv::Op::OpDecorationGroup:
@@ -935,7 +937,7 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
 
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
       return inst->word(2);
 
     case spv::Op::OpTypeTensorARM:
@@ -967,10 +969,18 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const {
 
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
       // Actual dimension isn't known, return 0
       return 0;
 
+    case spv::Op::OpTypeVectorIdEXT: {
+      uint64_t value = 0;
+      if (EvalConstantValUint64(inst->word(3), &value)) {
+        return static_cast<uint32_t>(value);
+      }
+
+      return 0;
+    }
+
     default:
       break;
   }
@@ -1014,6 +1024,16 @@ bool ValidationState_t::IsScalarType(uint32_t id) const {
   return IsIntScalarType(id) || IsFloatScalarType(id) || IsBoolScalarType(id);
 }
 
+bool ValidationState_t::IsVectorType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  if (!inst) {
+    return false;
+  }
+
+  return inst->opcode() == spv::Op::OpTypeVector ||
+         inst->opcode() == spv::Op::OpTypeVectorIdEXT;
+}
+
 bool ValidationState_t::IsArrayType(uint32_t id, uint64_t length) const {
   const Instruction* inst = FindDef(id);
   if (!inst || inst->opcode() != spv::Op::OpTypeArray) {
@@ -1050,7 +1070,7 @@ bool ValidationState_t::IsBfloat16VectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsBfloat16ScalarType(GetComponentType(id));
   }
 
@@ -1095,7 +1115,7 @@ bool ValidationState_t::IsFP8VectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsFP8ScalarType(GetComponentType(id));
   }
 
@@ -1119,9 +1139,16 @@ bool ValidationState_t::IsFP8Type(uint32_t id) const {
   return IsFP8ScalarType(id) || IsFP8VectorType(id) || IsFP8CoopMatType(id);
 }
 
-bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
+bool ValidationState_t::IsFloatScalarType(uint32_t id, uint32_t width) const {
   const Instruction* inst = FindDef(id);
-  return inst && inst->opcode() == spv::Op::OpTypeFloat;
+  bool is_float = inst && inst->opcode() == spv::Op::OpTypeFloat;
+  if (!is_float) {
+    return false;
+  }
+  if ((width != 0) && (width != inst->word(2))) {
+    return false;
+  }
+  return true;
 }
 
 bool ValidationState_t::IsFloatArrayType(uint32_t id) const {
@@ -1134,7 +1161,7 @@ bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsFloatScalarType(GetComponentType(id));
   }
 
@@ -1142,10 +1169,7 @@ bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
 }
 
 bool ValidationState_t::IsFloat16Vector2Or4Type(uint32_t id) const {
-  const Instruction* inst = FindDef(id);
-  assert(inst);
-
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     uint32_t vectorDim = GetDimension(id);
     return IsFloatScalarType(GetComponentType(id)) &&
            (vectorDim == 2 || vectorDim == 4) &&
@@ -1165,7 +1189,7 @@ bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
     return true;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsFloatScalarType(GetComponentType(id));
   }
 
@@ -1201,7 +1225,7 @@ bool ValidationState_t::IsIntVectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsIntScalarType(GetComponentType(id));
   }
 
@@ -1218,7 +1242,7 @@ bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
     return true;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsIntScalarType(GetComponentType(id));
   }
 
@@ -1235,7 +1259,7 @@ bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsUnsignedIntScalarType(GetComponentType(id));
   }
 
@@ -1252,7 +1276,7 @@ bool ValidationState_t::IsUnsignedIntScalarOrVectorType(uint32_t id) const {
     return inst->GetOperandAs<uint32_t>(2) == 0;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsUnsignedIntScalarType(GetComponentType(id));
   }
 
@@ -1270,7 +1294,7 @@ bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsSignedIntScalarType(GetComponentType(id));
   }
 
@@ -1288,7 +1312,7 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsBoolScalarType(GetComponentType(id));
   }
 
@@ -1305,7 +1329,7 @@ bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
     return true;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeVector) {
+  if (IsVectorType(id)) {
     return IsBoolScalarType(GetComponentType(id));
   }
 
@@ -1413,6 +1437,7 @@ uint32_t ValidationState_t::GetLargestScalarType(uint32_t id) const {
     case spv::Op::OpTypeArray:
       return GetLargestScalarType(inst->GetOperandAs<uint32_t>(1));
     case spv::Op::OpTypeVector:
+    case spv::Op::OpTypeVectorIdEXT:
       return GetLargestScalarType(inst->GetOperandAs<uint32_t>(1));
     default:
       return GetBitWidth(id) / 8;
@@ -1499,7 +1524,7 @@ bool ValidationState_t::IsUnsigned64BitHandle(uint32_t id) const {
 
 bool ValidationState_t::IsCooperativeVectorNVType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
-  return inst && inst->opcode() == spv::Op::OpTypeCooperativeVectorNV;
+  return inst && inst->opcode() == spv::Op::OpTypeVectorIdEXT;
 }
 
 bool ValidationState_t::IsFloatCooperativeVectorNVType(uint32_t id) const {
@@ -1523,6 +1548,90 @@ bool ValidationState_t::IsTensorType(uint32_t id) const {
   return inst && inst->opcode() == spv::Op::OpTypeTensorARM;
 }
 
+// Opaque handles from [Descriptor] section (added from SPV_EXT_descriptor_heap)
+bool ValidationState_t::IsDescriptorType(spv::Op opcode) const {
+  return opcode == spv::Op::OpTypeBufferEXT || opcode == spv::Op::OpTypeImage ||
+         opcode == spv::Op::OpTypeTensorARM ||
+         opcode == spv::Op::OpTypeSampler ||
+         opcode == spv::Op::OpTypeAccelerationStructureKHR;
+}
+
+// Opaque handles from [Descriptor] section (added from SPV_EXT_descriptor_heap)
+bool ValidationState_t::IsDescriptorType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  return inst && IsDescriptorType(inst->opcode());
+}
+
+const Instruction* ValidationState_t::FindUntypedBaseVariable(
+    const Instruction* inst) {
+  bool found_heap_base = false;
+  const Instruction* base_inst = inst;
+  while (!found_heap_base) {
+    switch (base_inst->opcode()) {
+      case spv::Op::OpUntypedAccessChainKHR:
+      case spv::Op::OpUntypedInBoundsAccessChainKHR:
+      case spv::Op::OpUntypedPtrAccessChainKHR:
+      case spv::Op::OpUntypedInBoundsPtrAccessChainKHR:
+      case spv::Op::OpUntypedArrayLengthKHR:
+        base_inst = FindDef(base_inst->GetOperandAs<uint32_t>(3));
+        break;
+      case spv::Op::OpLoad:
+      case spv::Op::OpAtomicLoad:
+        if (GetIdOpcode(GetOperandTypeId(base_inst, 2)) ==
+            spv::Op::OpTypeUntypedPointerKHR) {
+          base_inst = FindDef(base_inst->GetOperandAs<uint32_t>(2));
+        }
+        break;
+      case spv::Op::OpAtomicExchange:
+      case spv::Op::OpAtomicCompareExchange:
+      case spv::Op::OpAtomicCompareExchangeWeak:
+      case spv::Op::OpAtomicIIncrement:
+      case spv::Op::OpAtomicIDecrement:
+      case spv::Op::OpAtomicIAdd:
+      case spv::Op::OpAtomicISub:
+      case spv::Op::OpAtomicSMin:
+      case spv::Op::OpAtomicUMin:
+      case spv::Op::OpAtomicSMax:
+      case spv::Op::OpAtomicUMax:
+      case spv::Op::OpAtomicAnd:
+      case spv::Op::OpAtomicOr:
+      case spv::Op::OpAtomicXor:
+        base_inst = FindDef(base_inst->GetOperandAs<uint32_t>(2));
+        break;
+      case spv::Op::OpStore:
+      case spv::Op::OpAtomicStore:
+        if (GetIdOpcode(GetOperandTypeId(base_inst, 0)) ==
+            spv::Op::OpTypeUntypedPointerKHR) {
+          base_inst = FindDef(base_inst->GetOperandAs<uint32_t>(0));
+        }
+        break;
+      default:
+        found_heap_base = true;
+        break;
+    }
+
+    if (found_heap_base) {
+      break;
+    }
+  }
+
+  return base_inst;
+}
+
+bool ValidationState_t::IsDescriptorHeapBaseVariable(const Instruction* inst) {
+  if (!HasCapability(spv::Capability::DescriptorHeapEXT)) {
+    return false;
+  }
+  const Instruction* base_inst = FindUntypedBaseVariable(inst);
+  const bool is_heap_base =
+      IsBuiltin(base_inst->id(), spv::BuiltIn::SamplerHeapEXT) ||
+      IsBuiltin(base_inst->id(), spv::BuiltIn::ResourceHeapEXT);
+
+  return FindDef(base_inst->id())->opcode() == spv::Op::OpBufferPointerEXT ||
+         (FindDef(base_inst->id())->opcode() == spv::Op::OpUntypedVariableKHR &&
+          is_heap_base);
+}
+
 spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
     const Instruction* inst, uint32_t result_type_id, uint32_t m2,
     bool is_conversion, bool swap_row_col) {
@@ -1927,7 +2036,7 @@ bool ValidationState_t::ContainsType(
     case spv::Op::OpTypeSampledImage:
     case spv::Op::OpTypeCooperativeMatrixNV:
     case spv::Op::OpTypeCooperativeMatrixKHR:
-    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeVectorIdEXT:
       return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
                           traverse_all_types);
     case spv::Op::OpTypePointer:
@@ -2002,6 +2111,7 @@ bool ValidationState_t::ContainsUntypedPointer(uint32_t id) const {
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeRuntimeArray:
     case spv::Op::OpTypeVector:
+    case spv::Op::OpTypeVectorIdEXT:
     case spv::Op::OpTypeMatrix:
     case spv::Op::OpTypeImage:
     case spv::Op::OpTypeSampledImage:
@@ -2776,6 +2886,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-None-10684);
     case 10685:
       return VUID_WRAP(VUID-StandaloneSpirv-None-10685); // formally 04683/06426
+    case 10823:
+      return VUID_WRAP(VUID-StandaloneSpirv-OpTypeFloat-10823);
     case 10824:
       // This use to be a standalone, but maintenance9 will set allow_vulkan_32_bit_bitwise now
       return VUID_WRAP(VUID-RuntimeSpirv-None-10824);
@@ -2813,10 +2925,32 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-TessLevelInner-10880);
     case 11167:
       return VUID_WRAP(VUID-StandaloneSpirv-OpUntypedVariableKHR-11167);
+    case 11239:
+        return VUID_WRAP(VUID-SamplerHeapEXT-SamplerHeapEXT-11239);
+    case 11241:
+        return VUID_WRAP(VUID-ResourceHeapEXT-ResourceHeapEXT-11241);
+    case 11336:
+        return VUID_WRAP(VUID-StandaloneSpirv-Result-11336);
+    case 11337:
+        return VUID_WRAP(VUID-StandaloneSpirv-Result-11337);
+    case 11339:
+        return VUID_WRAP(VUID-StandaloneSpirv-Result-11339);
+    case 11346:
+        return VUID_WRAP(VUID-StandaloneSpirv-Result-11346);
+    case 11347:
+        return VUID_WRAP(VUID-StandaloneSpirv-OpUntypedVariableKHR-11347);
+    case 11416:
+        return VUID_WRAP(VUID-StandaloneSpirv-OpUntypedImageTexelPointerEXT-11416);
+    case 11417:
+        return VUID_WRAP(VUID-StandaloneSpirv-OpTypeUntypedPointerKHR-11417);
+    case 11482:
+      return VUID_WRAP(VUID-StandaloneSpirv-DescriptorHeapEXT-11482);
     case 11805:
       return VUID_WRAP(VUID-StandaloneSpirv-OpArrayLength-11805);
     case 12243:
       return VUID_WRAP(VUID-StandaloneSpirv-Scope-12243);
+    case 12294:
+      return VUID_WRAP(VUID-StandaloneSpirv-Function-12294);
     default:
       return "";  // unknown id
   }

+ 29 - 1
3rdparty/spirv-tools/source/val/validation_state.h

@@ -522,6 +522,29 @@ class ValidationState_t {
         [dec](const Decoration& d) { return dec == d.dec_type(); });
   }
 
+  /// Returns true if the given id <id> has the given built-in decoration <bt>,
+  /// otherwise returns false.
+  bool IsBuiltin(spv::Id id, spv::BuiltIn bt) {
+    for (auto& dec : id_decorations(id)) {
+      if (dec.dec_type() == spv::Decoration::BuiltIn) {
+        if (dec.builtin() == bt) return true;
+        break;
+      }
+    }
+    return false;
+  }
+
+  bool ContainsBuiltin(spv::Id id, spv::BuiltIn bt) {
+    const auto isHeapType = [&](const Instruction* inst) {
+      if (HasCapability(spv::Capability::DescriptorHeapEXT) &&
+          IsBuiltin(inst->id(), bt)) {
+        return true;
+      }
+      return false;
+    };
+    return ContainsType(uint32_t(id), isHeapType);
+  }
+
   /// Finds id's def, if it exists.  If found, returns the definition otherwise
   /// nullptr
   const Instruction* FindDef(uint32_t id) const;
@@ -664,6 +687,7 @@ class ValidationState_t {
   // Only works for types not for objects.
   bool IsVoidType(uint32_t id) const;
   bool IsScalarType(uint32_t id) const;
+  bool IsVectorType(uint32_t id) const;
   bool IsBfloat16ScalarType(uint32_t id) const;
   bool IsBfloat16VectorType(uint32_t id) const;
   bool IsBfloat16CoopMatType(uint32_t id) const;
@@ -672,7 +696,7 @@ class ValidationState_t {
   bool IsFP8VectorType(uint32_t id) const;
   bool IsFP8CoopMatType(uint32_t id) const;
   bool IsFP8Type(uint32_t id) const;
-  bool IsFloatScalarType(uint32_t id) const;
+  bool IsFloatScalarType(uint32_t id, uint32_t width = 0) const;
   bool IsFloatArrayType(uint32_t id) const;
   bool IsFloatVectorType(uint32_t id) const;
   bool IsFloat16Vector2Or4Type(uint32_t id) const;
@@ -707,6 +731,8 @@ class ValidationState_t {
   bool IsIntCooperativeVectorNVType(uint32_t id) const;
   bool IsUnsignedIntCooperativeVectorNVType(uint32_t id) const;
   bool IsTensorType(uint32_t id) const;
+  bool IsDescriptorType(spv::Op opcode) const;
+  bool IsDescriptorType(uint32_t id) const;
   // When |length| is not 0, return true only if the array length is equal to
   // |length| and the array length is not defined by a specialization constant.
   bool IsArrayType(uint32_t id, uint64_t length = 0) const;
@@ -736,6 +762,8 @@ class ValidationState_t {
   // This is designed to pass in the %type from a PSB pointer
   //   %ptr = OpTypePointer PhysicalStorageBuffer %type
   uint32_t GetLargestScalarType(uint32_t id) const;
+  bool IsDescriptorHeapBaseVariable(const Instruction* inst);
+  const Instruction* FindUntypedBaseVariable(const Instruction* inst);
 
   // Returns true if |id| is a type id that contains |type| (or integer or
   // floating point type) of |width| bits.

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott