2
0
Бранимир Караџић 3 сар өмнө
parent
commit
6f3fb79c0b
99 өөрчлөгдсөн 7729 нэмэгдсэн , 2714 устгасан
  1. 1 1
      3rdparty/spirv-tools/include/generated/build-version.inc
  2. 1482 1470
      3rdparty/spirv-tools/include/generated/core_tables_body.inc
  3. 4 0
      3rdparty/spirv-tools/include/generated/core_tables_header.inc
  4. 25 29
      3rdparty/spirv-tools/include/spirv-tools/libspirv.h
  5. 1 0
      3rdparty/spirv-tools/include/spirv-tools/libspirv.hpp
  6. 24 0
      3rdparty/spirv-tools/include/spirv-tools/linker.hpp
  7. 10 0
      3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp
  8. 4 1
      3rdparty/spirv-tools/source/binary.cpp
  9. 3 2
      3rdparty/spirv-tools/source/disassemble.cpp
  10. 3 0
      3rdparty/spirv-tools/source/ext_inst.cpp
  11. 10 4
      3rdparty/spirv-tools/source/extensions.cpp
  12. 1011 0
      3rdparty/spirv-tools/source/link/fnvar.cpp
  13. 244 0
      3rdparty/spirv-tools/source/link/fnvar.h
  14. 34 8
      3rdparty/spirv-tools/source/link/linker.cpp
  15. 15 0
      3rdparty/spirv-tools/source/mimalloc.cpp
  16. 10 0
      3rdparty/spirv-tools/source/opcode.cpp
  17. 13 0
      3rdparty/spirv-tools/source/operand.cpp
  18. 57 1
      3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp
  19. 6 0
      3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.h
  20. 516 0
      3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.cpp
  21. 115 0
      3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.h
  22. 7 0
      3rdparty/spirv-tools/source/opt/ccp_pass.cpp
  23. 44 11
      3rdparty/spirv-tools/source/opt/const_folding_rules.cpp
  24. 1 0
      3rdparty/spirv-tools/source/opt/constants.cpp
  25. 8 7
      3rdparty/spirv-tools/source/opt/debug_info_manager.cpp
  26. 9 10
      3rdparty/spirv-tools/source/opt/debug_info_manager.h
  27. 21 6
      3rdparty/spirv-tools/source/opt/desc_sroa.cpp
  28. 9 3
      3rdparty/spirv-tools/source/opt/feature_manager.cpp
  29. 34 25
      3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp
  30. 9 0
      3rdparty/spirv-tools/source/opt/folding_rules.cpp
  31. 130 28
      3rdparty/spirv-tools/source/opt/graphics_robust_access_pass.cpp
  32. 11 3
      3rdparty/spirv-tools/source/opt/instruction.cpp
  33. 2 1
      3rdparty/spirv-tools/source/opt/instruction.h
  34. 140 77
      3rdparty/spirv-tools/source/opt/interface_var_sroa.cpp
  35. 19 14
      3rdparty/spirv-tools/source/opt/interface_var_sroa.h
  36. 49 20
      3rdparty/spirv-tools/source/opt/invocation_interlock_placement_pass.cpp
  37. 5 5
      3rdparty/spirv-tools/source/opt/invocation_interlock_placement_pass.h
  38. 6 2
      3rdparty/spirv-tools/source/opt/ir_context.cpp
  39. 4 2
      3rdparty/spirv-tools/source/opt/ir_loader.cpp
  40. 11 3
      3rdparty/spirv-tools/source/opt/loop_fission.cpp
  41. 147 61
      3rdparty/spirv-tools/source/opt/loop_peeling.cpp
  42. 15 11
      3rdparty/spirv-tools/source/opt/loop_peeling.h
  43. 42 14
      3rdparty/spirv-tools/source/opt/loop_unswitch_pass.cpp
  44. 2 1
      3rdparty/spirv-tools/source/opt/loop_unswitch_pass.h
  45. 29 15
      3rdparty/spirv-tools/source/opt/loop_utils.cpp
  46. 2 0
      3rdparty/spirv-tools/source/opt/loop_utils.h
  47. 44 10
      3rdparty/spirv-tools/source/opt/merge_return_pass.cpp
  48. 7 5
      3rdparty/spirv-tools/source/opt/merge_return_pass.h
  49. 7 0
      3rdparty/spirv-tools/source/opt/optimizer.cpp
  50. 1 0
      3rdparty/spirv-tools/source/opt/passes.h
  51. 52 5
      3rdparty/spirv-tools/source/opt/remove_duplicates_pass.cpp
  52. 4 0
      3rdparty/spirv-tools/source/opt/remove_duplicates_pass.h
  53. 7 2
      3rdparty/spirv-tools/source/opt/remove_unused_interface_variables_pass.cpp
  54. 11 4
      3rdparty/spirv-tools/source/opt/scalar_replacement_pass.cpp
  55. 3 1
      3rdparty/spirv-tools/source/opt/scalar_replacement_pass.h
  56. 20 2
      3rdparty/spirv-tools/source/opt/split_combined_image_sampler_pass.cpp
  57. 9 6
      3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.cpp
  58. 1 1
      3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.h
  59. 14 10
      3rdparty/spirv-tools/source/opt/strength_reduction_pass.cpp
  60. 2 2
      3rdparty/spirv-tools/source/opt/strength_reduction_pass.h
  61. 19 17
      3rdparty/spirv-tools/source/opt/trim_capabilities_pass.cpp
  62. 1 0
      3rdparty/spirv-tools/source/opt/trim_capabilities_pass.h
  63. 90 1
      3rdparty/spirv-tools/source/opt/type_manager.cpp
  64. 91 0
      3rdparty/spirv-tools/source/opt/types.cpp
  65. 56 0
      3rdparty/spirv-tools/source/opt/types.h
  66. 58 12
      3rdparty/spirv-tools/source/opt/upgrade_memory_model.cpp
  67. 4 1
      3rdparty/spirv-tools/source/parsed_operand.cpp
  68. 1 1
      3rdparty/spirv-tools/source/text_handler.cpp
  69. 90 0
      3rdparty/spirv-tools/source/util/hex_float.h
  70. 9 1
      3rdparty/spirv-tools/source/util/parse_number.cpp
  71. 59 23
      3rdparty/spirv-tools/source/val/validate.cpp
  72. 5 2
      3rdparty/spirv-tools/source/val/validate.h
  73. 23 2
      3rdparty/spirv-tools/source/val/validate_annotation.cpp
  74. 0 21
      3rdparty/spirv-tools/source/val/validate_atomics.cpp
  75. 4 4
      3rdparty/spirv-tools/source/val/validate_barriers.cpp
  76. 4 2
      3rdparty/spirv-tools/source/val/validate_bitwise.cpp
  77. 411 166
      3rdparty/spirv-tools/source/val/validate_builtins.cpp
  78. 12 5
      3rdparty/spirv-tools/source/val/validate_capability.cpp
  79. 465 1
      3rdparty/spirv-tools/source/val/validate_composites.cpp
  80. 81 4
      3rdparty/spirv-tools/source/val/validate_conversion.cpp
  81. 174 119
      3rdparty/spirv-tools/source/val/validate_decorations.cpp
  82. 10 4
      3rdparty/spirv-tools/source/val/validate_extensions.cpp
  83. 4 6
      3rdparty/spirv-tools/source/val/validate_function.cpp
  84. 547 0
      3rdparty/spirv-tools/source/val/validate_graph.cpp
  85. 55 34
      3rdparty/spirv-tools/source/val/validate_id.cpp
  86. 3 1
      3rdparty/spirv-tools/source/val/validate_image.cpp
  87. 6 2
      3rdparty/spirv-tools/source/val/validate_instruction.cpp
  88. 84 72
      3rdparty/spirv-tools/source/val/validate_interfaces.cpp
  89. 8 11
      3rdparty/spirv-tools/source/val/validate_invalid_type.cpp
  90. 79 3
      3rdparty/spirv-tools/source/val/validate_layout.cpp
  91. 120 22
      3rdparty/spirv-tools/source/val/validate_memory.cpp
  92. 175 147
      3rdparty/spirv-tools/source/val/validate_memory_semantics.cpp
  93. 123 72
      3rdparty/spirv-tools/source/val/validate_mode_setting.cpp
  94. 1 1
      3rdparty/spirv-tools/source/val/validate_non_uniform.cpp
  95. 1 1
      3rdparty/spirv-tools/source/val/validate_scopes.cpp
  96. 2 4
      3rdparty/spirv-tools/source/val/validate_tensor.cpp
  97. 8 5
      3rdparty/spirv-tools/source/val/validate_type.cpp
  98. 216 58
      3rdparty/spirv-tools/source/val/validation_state.cpp
  99. 104 6
      3rdparty/spirv-tools/source/val/validation_state.h

+ 1 - 1
3rdparty/spirv-tools/include/generated/build-version.inc

@@ -1 +1 @@
-"v2025.2", "SPIRV-Tools v2025.2 v2025.2.rc2-58-g007a1f89"
+"v2025.3", "SPIRV-Tools v2025.3 v2025.3.rc1-110-g8fbe2387"

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 1482 - 1470
3rdparty/spirv-tools/include/generated/core_tables_body.inc


+ 4 - 0
3rdparty/spirv-tools/include/generated/core_tables_header.inc

@@ -13,6 +13,7 @@ enum class PrintingClass : uint32_t {
   kDevice_Side_Enqueue,
   kExtension,
   kFunction,
+  kGraph,
   kGroup,
   kImage,
   kMemory,
@@ -43,6 +44,7 @@ enum Extension : uint32_t {
   kSPV_AMD_texture_gather_bias_lod,
   kSPV_ARM_cooperative_matrix_layouts,
   kSPV_ARM_core_builtins,
+  kSPV_ARM_graph,
   kSPV_ARM_tensors,
   kSPV_EXT_arithmetic_fence,
   kSPV_EXT_demote_to_helper_invocation,
@@ -91,6 +93,7 @@ enum Extension : uint32_t {
   kSPV_INTEL_fpga_memory_attributes,
   kSPV_INTEL_fpga_reg,
   kSPV_INTEL_function_pointers,
+  kSPV_INTEL_function_variants,
   kSPV_INTEL_global_variable_fpga_decorations,
   kSPV_INTEL_global_variable_host_access,
   kSPV_INTEL_inline_assembly,
@@ -182,6 +185,7 @@ enum Extension : uint32_t {
   kSPV_NV_stereo_view_rendering,
   kSPV_NV_tensor_addressing,
   kSPV_NV_viewport_array2,
+  kSPV_QCOM_cooperative_matrix_conversion,
   kSPV_QCOM_image_processing,
   kSPV_QCOM_image_processing2,
   kSPV_QCOM_tile_shading,

+ 25 - 29
3rdparty/spirv-tools/include/spirv-tools/libspirv.h

@@ -80,6 +80,8 @@ typedef enum spv_result_t {
   SPV_ERROR_INVALID_DATA = -14,  // Indicates data rules validation failure.
   SPV_ERROR_MISSING_EXTENSION = -15,
   SPV_ERROR_WRONG_VERSION = -16,  // Indicates wrong SPIR-V version
+  SPV_ERROR_FNVAR =
+      -17,  // Error related to SPV_INTEL_function_variants extension
   SPV_FORCE_32_BIT_ENUM(spv_result_t)
 } spv_result_t;
 
@@ -189,36 +191,24 @@ typedef enum spv_operand_type_t {
   SPV_OPERAND_TYPE_MEMORY_ACCESS,          // SPIR-V Sec 3.26
   SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE,  // SPIR-V Sec 3.FSR
 
-// NOTE: New concrete enum values should be added at the end.
-
-// The "optional" and "variable"  operand types are only used internally by
-// the assembler and the binary parser.
-// There are two categories:
-//    Optional : expands to 0 or 1 operand, like ? in regular expressions.
-//    Variable : expands to 0, 1 or many operands or pairs of operands.
-//               This is similar to * in regular expressions.
-
-// NOTE: These FIRST_* and LAST_* enum values are DEPRECATED.
-// The concept of "optional" and "variable" operand types are only intended
-// for use as an implementation detail of parsing SPIR-V, either in text or
-// binary form.  Instead of using enum ranges, use characteristic function
-// spvOperandIsConcrete.
-// The use of enum value ranges in a public API makes it difficult to insert
-// new values into a range without also breaking binary compatibility.
-//
-// Macros for defining bounds on optional and variable operand types.
-// Any variable operand type is also optional.
-// TODO(dneto): Remove SPV_OPERAND_TYPE_FIRST_* and SPV_OPERAND_TYPE_LAST_*
-#define FIRST_OPTIONAL(ENUM) ENUM, SPV_OPERAND_TYPE_FIRST_OPTIONAL_TYPE = ENUM
-#define FIRST_VARIABLE(ENUM) ENUM, SPV_OPERAND_TYPE_FIRST_VARIABLE_TYPE = ENUM
-#define LAST_VARIABLE(ENUM)                         \
-  ENUM, SPV_OPERAND_TYPE_LAST_VARIABLE_TYPE = ENUM, \
-        SPV_OPERAND_TYPE_LAST_OPTIONAL_TYPE = ENUM
+  // NOTE: New concrete enum values should be added at the end.
+
+  // The "optional" and "variable"  operand types are only used internally by
+  // the assembler and the binary parser.
+  // There are two categories:
+  //    Optional : expands to 0 or 1 operand, like ? in regular expressions.
+  //    Variable : expands to 0, 1 or many operands or pairs of operands.
+  //               This is similar to * in regular expressions.
+
+  // Use characteristic function spvOperandIsConcrete to classify the
+  // operand types; when it returns false, the operand is optional or variable.
+  //
+  // Any variable operand type is also optional.
 
   // An optional operand represents zero or one logical operands.
   // In an instruction definition, this may only appear at the end of the
   // operand types.
-  FIRST_OPTIONAL(SPV_OPERAND_TYPE_OPTIONAL_ID),
+  SPV_OPERAND_TYPE_OPTIONAL_ID,
   // An optional image operand type.
   SPV_OPERAND_TYPE_OPTIONAL_IMAGE,
   // An optional memory access type.
@@ -243,7 +233,7 @@ typedef enum spv_operand_type_t {
   // A variable operand represents zero or more logical operands.
   // In an instruction definition, this may only appear at the end of the
   // operand types.
-  FIRST_VARIABLE(SPV_OPERAND_TYPE_VARIABLE_ID),
+  SPV_OPERAND_TYPE_VARIABLE_ID,
   SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER,
   // A sequence of zero or more pairs of (typed literal integer, Id).
   // Expands to zero or more:
@@ -251,7 +241,7 @@ typedef enum spv_operand_type_t {
   // where the literal number must always be an integer of some sort.
   SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID,
   // A sequence of zero or more pairs of (Id, Literal integer)
-  LAST_VARIABLE(SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER),
+  SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER,
 
   // The following are concrete enum types from the DebugInfo extended
   // instruction set.
@@ -344,6 +334,10 @@ typedef enum spv_operand_type_t {
   SPV_OPERAND_TYPE_TENSOR_OPERANDS,
   SPV_OPERAND_TYPE_OPTIONAL_TENSOR_OPERANDS,
 
+  // SPV_INTEL_function_variants
+  SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY,
+  SPV_OPERAND_TYPE_VARIABLE_CAPABILITY,
+
   // This is a sentinel value, and does not represent an operand type.
   // It should come last.
   SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
@@ -370,6 +364,7 @@ typedef enum spv_ext_inst_type_t {
   SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION,
   SPV_EXT_INST_TYPE_NONSEMANTIC_SHADER_DEBUGINFO_100,
   SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION,
+  SPV_EXT_INST_TYPE_TOSA_001000_1,
 
   // Multiple distinct extended instruction set types could return this
   // value, if they are prefixed with NonSemantic. and are otherwise
@@ -438,7 +433,7 @@ typedef enum spv_binary_to_text_options_t {
 
 // The default id bound is to the minimum value for the id limit
 // in the spir-v specification under the section "Universal Limits".
-const uint32_t kDefaultMaxIdBound = 0x3FFFFF;
+const static uint32_t kDefaultMaxIdBound = 0x3FFFFF;
 
 // Structures
 
@@ -772,6 +767,7 @@ SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetAllowOffsetTextureOperand(
     spv_validator_options options, bool val);
 
 // Allow base operands of some bit operations to be non-32-bit wide.
+// Was added for VK_KHR_maintenance9
 SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetAllowVulkan32BitBitwise(
     spv_validator_options options, bool val);
 

+ 1 - 0
3rdparty/spirv-tools/include/spirv-tools/libspirv.hpp

@@ -133,6 +133,7 @@ class SPIRV_TOOLS_EXPORT ValidatorOptions {
   }
 
   // Allow base operands of some bit operations to be non-32-bit wide.
+  // Was added for VK_KHR_maintenance9
   void SetAllowVulkan32BitBitwise(bool val) {
     spvValidatorOptionsSetAllowVulkan32BitBitwise(options_, val);
   }

+ 24 - 0
3rdparty/spirv-tools/include/spirv-tools/linker.hpp

@@ -67,12 +67,36 @@ class SPIRV_TOOLS_EXPORT LinkerOptions {
     allow_ptr_type_mismatch_ = allow_ptr_type_mismatch;
   }
 
+  std::string GetFnVarTargetsCsv() const { return fnvar_targets_csv_; }
+  void SetFnVarTargetsCsv(std::string fnvar_targets_csv) {
+    fnvar_targets_csv_ = fnvar_targets_csv;
+  }
+
+  std::string GetFnVarArchitecturesCsv() const {
+    return fnvar_architectures_csv_;
+  }
+  void SetFnVarArchitecturesCsv(std::string fnvar_architectures_csv) {
+    fnvar_architectures_csv_ = fnvar_architectures_csv;
+  }
+
+  bool GetHasFnVarCapabilities() const { return has_fnvar_capabilities_; }
+  void SetHasFnVarCapabilities(bool fnvar_capabilities) {
+    has_fnvar_capabilities_ = fnvar_capabilities;
+  }
+
+  std::vector<std::string> GetInFiles() const { return in_files_; }
+  void SetInFiles(std::vector<std::string> in_files) { in_files_ = in_files; }
+
  private:
   bool create_library_{false};
   bool verify_ids_{false};
   bool allow_partial_linkage_{false};
   bool use_highest_version_{false};
   bool allow_ptr_type_mismatch_{false};
+  std::string fnvar_targets_csv_{""};
+  std::string fnvar_architectures_csv_{""};
+  bool has_fnvar_capabilities_ = false;
+  std::vector<std::string> in_files_{{}};
 };
 
 // Links one or more SPIR-V modules into a new SPIR-V module. That is, combine

+ 10 - 0
3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp

@@ -1022,6 +1022,16 @@ Optimizer::PassToken CreateSplitCombinedImageSamplerPass();
 // This pass assumes binding numbers are not applid via decoration groups
 // (OpDecorationGroup).
 Optimizer::PassToken CreateResolveBindingConflictsPass();
+
+// Create a pass to canonicalize IDs to improve compression of SPIR-V binary
+// files. The resulting modules have an increased ID range (IDs are not as
+// tightly packed around zero), but will compress better when multiple modules
+// are compressed together, since the compressor's dictionary can find better
+// cross module commonality. This pass should be run after most optimization
+// passes except for
+// --strip-debug because this pass will use OpName to canonicalize IDs. i.e. Run
+// --strip-debug after this pass.
+Optimizer::PassToken CreateCanonicalizeIdsPass();
 }  // namespace spvtools
 
 #endif  // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_

+ 4 - 1
3rdparty/spirv-tools/source/binary.cpp

@@ -636,6 +636,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
     } break;
 
     case SPV_OPERAND_TYPE_CAPABILITY:
+    case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
     case SPV_OPERAND_TYPE_EXECUTION_MODEL:
     case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
     case SPV_OPERAND_TYPE_MEMORY_MODEL:
@@ -689,6 +690,8 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
         parsed_operand.type = SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT;
       if (type == SPV_OPERAND_TYPE_OPTIONAL_FPENCODING)
         parsed_operand.type = SPV_OPERAND_TYPE_FPENCODING;
+      if (type == SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY)
+        parsed_operand.type = SPV_OPERAND_TYPE_CAPABILITY;
 
       const spvtools::OperandDesc* entry = nullptr;
       if (spvtools::LookupOperand(type, word, &entry)) {
@@ -853,7 +856,7 @@ void Parser::recordNumberType(size_t inst_offset,
       info.type = SPV_NUMBER_FLOATING;
       info.bit_width = peekAt(inst_offset + 2);
       if (inst->num_words >= 4) {
-        const spvtools::OperandDesc* desc;
+        const spvtools::OperandDesc* desc = nullptr;
         spv_result_t status = spvtools::LookupOperand(
             SPV_OPERAND_TYPE_FPENCODING, peekAt(inst_offset + 3), &desc);
         if (status == SPV_SUCCESS) {

+ 3 - 2
3rdparty/spirv-tools/source/disassemble.cpp

@@ -694,12 +694,12 @@ void InstructionDisassembler::EmitInstructionImpl(
   }
 
   if (inst.result_id) {
-    SetBlue();
+    SetBlue(line);
     const std::string id_name = name_mapper_(inst.result_id);
     if (indent_)
       line << std::setw(std::max(0, indent_ - 3 - int(id_name.size())));
     line << "%" << id_name;
-    ResetColor();
+    ResetColor(line);
     line << " = ";
   } else {
     line << std::string(indent_, ' ');
@@ -907,6 +907,7 @@ void InstructionDisassembler::EmitOperand(std::ostream& stream,
       stream << '"';
     } break;
     case SPV_OPERAND_TYPE_CAPABILITY:
+    case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
     case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
     case SPV_OPERAND_TYPE_EXECUTION_MODEL:
     case SPV_OPERAND_TYPE_ADDRESSING_MODEL:

+ 3 - 0
3rdparty/spirv-tools/source/ext_inst.cpp

@@ -55,6 +55,9 @@ spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name) {
   if (!strncmp("NonSemantic.VkspReflection.", name, 27)) {
     return SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION;
   }
+  if (!strcmp("TOSA.001000.1", name)) {
+    return SPV_EXT_INST_TYPE_TOSA_001000_1;
+  }
   // ensure to add any known non-semantic extended instruction sets
   // above this point, and update spvExtInstIsNonSemantic()
   if (!strncmp("NonSemantic.", name, 12)) {

+ 10 - 4
3rdparty/spirv-tools/source/extensions.cpp

@@ -24,18 +24,24 @@
 namespace spvtools {
 
 std::string GetExtensionString(const spv_parsed_instruction_t* inst) {
-  if (inst->opcode != static_cast<uint16_t>(spv::Op::OpExtension)) {
+  if ((inst->opcode != static_cast<uint16_t>(spv::Op::OpExtension)) &&
+      (inst->opcode !=
+       static_cast<uint16_t>(spv::Op::OpConditionalExtensionINTEL))) {
     return "ERROR_not_op_extension";
   }
 
-  assert(inst->num_operands == 1);
+  const bool is_conditional =
+      inst->opcode ==
+      static_cast<uint16_t>(spv::Op::OpConditionalExtensionINTEL);
+  assert(inst->num_operands == (is_conditional ? 2 : 1));
+  const uint16_t op_i = is_conditional ? 1 : 0;
 
-  const auto& operand = inst->operands[0];
+  const auto& operand = inst->operands[op_i];
   assert(operand.type == SPV_OPERAND_TYPE_LITERAL_STRING);
   assert(inst->num_words > operand.offset);
   (void)operand; /* No unused variables in release builds. */
 
-  return spvDecodeLiteralStringOperand(*inst, 0);
+  return spvDecodeLiteralStringOperand(*inst, op_i);
 }
 
 std::string ExtensionSetToString(const ExtensionSet& extensions) {

+ 1011 - 0
3rdparty/spirv-tools/source/link/fnvar.cpp

@@ -0,0 +1,1011 @@
+// Copyright 2025 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fnvar.h"
+
+#include <initializer_list>
+#include <memory>
+#include <sstream>
+
+#include "source/opt/instruction.h"
+
+namespace spvtools {
+
+using opt::Function;
+using opt::Instruction;
+using opt::analysis::Type;
+
+namespace {
+// Helper functions
+
+// Parses a CSV source string for the purpose of this extension.
+//
+// Required columns must be known in advance and supplied as the required_cols
+// argument -- this is used for error checking. Values are assumed to be
+// separated by CSV_SEP. The input source string is assumed to be the output of
+// io::ReadTextFile and no other validation, apart from the CSV parsing, is
+// performed.
+//
+// Returns true on success, false on error (with error message stored in
+// err_msg).
+bool ParseCsv(const std::string& source,
+              const std::vector<std::string>& required_cols,
+              std::stringstream& err_msg,
+              std::vector<std::vector<std::string>>& result) {
+  std::stringstream fn_variants_csv_stream(source);
+  std::string line;
+  std::vector<std::string> columns;
+  constexpr char CSV_SEP = ',';
+  bool first_line = true;
+
+  while (std::getline(fn_variants_csv_stream, line, '\n')) {
+    if (line.empty()) {
+      continue;
+    }
+
+    std::vector<std::string> vals;
+    std::string val;
+    std::stringstream line_stream(line);
+    auto* vec = first_line ? &columns : &vals;
+
+    while (std::getline(line_stream, val, CSV_SEP)) {
+      vec->push_back(val);
+    }
+
+    if (!line_stream && val.empty()) {
+      vec->push_back("");
+    }
+
+    if (!first_line) {
+      if (vals.size() != columns.size()) {
+        err_msg << "Number of values does not match the number of columns. "
+                   "Offending line:\n"
+                << line;
+        return false;
+      }
+      result.push_back(vals);
+    }
+
+    first_line = false;
+  }
+
+  // check if required columns match actual columns (ordering matters)
+
+  if (columns.size() != required_cols.size()) {
+    err_msg << "Invalid number of CSV columns: " << columns.size()
+            << ", expected " << required_cols.size() << ".";
+    return false;
+  }
+
+  for (size_t i = 0; i < columns.size(); ++i) {
+    if (columns[i] != required_cols[i]) {
+      err_msg << "Invalid name of column " << i + 1 << ". Expected '"
+              << required_cols[i] << "', got '" << columns[i] << "'.";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Annotate ID with ConditionalINTEL decoration
+void DecorateConditional(IRContext* context, uint32_t id_to_decorate,
+                         uint32_t spec_const_id) {
+  auto decor_instr =
+      std::make_unique<Instruction>(context, spv::Op::OpDecorate);
+  decor_instr->AddOperand({SPV_OPERAND_TYPE_ID, {id_to_decorate}});
+  decor_instr->AddOperand({SPV_OPERAND_TYPE_DECORATION,
+                           {uint32_t(spv::Decoration::ConditionalINTEL)}});
+  decor_instr->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}});
+  context->module()->AddAnnotationInst(std::move(decor_instr));
+}
+
+// Finds entry point corresponding to a function
+//
+// Returns null if not found, otherwise returns pointer to the EP Instruction.
+Instruction* FindEntryPoint(const Instruction& fn_inst) {
+  auto* mod = fn_inst.context()->module();
+  for (auto& entry_point : mod->entry_points()) {
+    const int ep_i =
+        entry_point.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 2 : 1;
+    if (entry_point.GetOperand(ep_i).AsId() == fn_inst.result_id()) {
+      return &entry_point;
+    }
+  }
+  return nullptr;
+}
+
+// If the function has an entry point, converts it to a conditional one
+void ConvertEPToConditional(Module* module, const Function& fn,
+                            uint32_t spec_const_id) {
+  for (const auto& ep_inst : module->entry_points()) {
+    if (ep_inst.opcode() == spv::Op::OpEntryPoint) {
+      auto* entry_point = FindEntryPoint(fn.DefInst());
+      if (entry_point != nullptr) {
+        std::vector<opt::Operand> old_operands;
+        for (auto operand : *entry_point) {
+          old_operands.push_back(operand);
+        }
+        entry_point->ToNop();
+        entry_point->SetOpcode(spv::Op::OpConditionalEntryPointINTEL);
+        entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}});
+        for (auto old_operand : old_operands) {
+          entry_point->AddOperand(old_operand);
+        }
+      }
+    }
+  }
+}
+
+// Finds ID of a bool type (returns 0 if not found)
+uint32_t FindIdOfBoolType(const Module* const mod) {
+  return mod->context()->get_type_mgr()->GetBoolTypeId();
+}
+
+// Combines IDs using OpSpecConstantOp with the operation defined by cmp_op.
+//
+// Returns the ID of the final result. If there are no IDs, returns 0. If there
+// is one ID, does not generate any instructions and returns the ID.
+uint32_t CombineIds(IRContext* const context, const std::vector<uint32_t>& ids,
+                    spv::Op cmp_op) {
+  if (ids.empty()) {
+    return 0;
+  } else if (ids.size() == 1) {
+    return ids[0];
+  } else {
+    uint32_t bool_id = FindIdOfBoolType(context->module());
+    assert(bool_id != 0);
+
+    uint32_t prev_spec_const_id = ids[0];
+
+    for (size_t i = 1; i < ids.size(); ++i) {
+      const uint32_t id = ids[i];
+      const uint32_t spec_const_op_id = context->TakeNextId();
+
+      auto inst = std::make_unique<Instruction>(
+          context, spv::Op::OpSpecConstantOp, bool_id, spec_const_op_id,
+          std::initializer_list<opt::Operand>{
+              {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {(uint32_t)(cmp_op)}},
+              {SPV_OPERAND_TYPE_ID, {prev_spec_const_id}},
+              {SPV_OPERAND_TYPE_ID, {id}}});
+      context->module()->AddType(std::move(inst));
+
+      prev_spec_const_id = spec_const_op_id;
+    }
+
+    return prev_spec_const_id;
+  }
+}
+
+// Returns whether instruction can be shared between variant modules and
+// combined using spec constants (such as conditional capabilities).
+bool CanBeFnVarCombined(const Instruction* inst) {
+  const spv::Op opcode = inst->opcode();
+
+  if ((opcode != spv::Op::OpExtInstImport) &&
+      (opcode != spv::Op::OpCapability) && (opcode != spv::Op::OpExtension) &&
+      !spvOpcodeGeneratesType(opcode)) {
+    return false;
+  }
+
+  if ((opcode == spv::Op::OpCapability) &&
+      ((inst->GetSingleWordOperand(0) ==
+        static_cast<uint32_t>(spv::Capability::FunctionVariantsINTEL)) ||
+       (inst->GetSingleWordOperand(0) ==
+        static_cast<uint32_t>(spv::Capability::SpecConditionalINTEL)))) {
+    // Always enabled
+    return false;
+  }
+
+  if ((opcode == spv::Op::OpExtension) &&
+      (inst->GetOperand(0).AsString() == FNVAR_EXT_NAME)) {
+    // Always enabled
+    return false;
+  }
+
+  return true;
+}
+
+// Calculates hash of an instruction.
+//
+// Applicable only to instructions that can be combined (ie. with
+// CanBeFnVarCombined being true) and from those, hash can be only computed for
+// selected instructions. Computing hash from other instruction is unsupported.
+size_t HashInst(const Instruction* inst) {
+  if (CanBeFnVarCombined(inst)) {
+    if (spvOpcodeGeneratesType(inst->opcode())) {
+      const Type* t =
+          inst->context()->get_type_mgr()->GetType(inst->result_id());
+      assert(t != nullptr);
+      return t->HashValue();
+    }
+
+    if (inst->opcode() == spv::Op::OpExtension) {
+      const auto name = inst->GetOperand(0).AsString();
+      return std::hash<std::string>()(name);
+    }
+
+    if (inst->opcode() == spv::Op::OpCapability) {
+      const auto cap = inst->GetSingleWordOperand(0);
+      return std::hash<uint32_t>()(cap);
+    }
+
+    if (inst->opcode() == spv::Op::OpExtInstImport) {
+      const auto name = inst->GetOperand(1).AsString();
+      return std::hash<std::string>()(name);
+    }
+  }
+
+  assert(false && "Unsupported instruction hash");
+  return std::hash<const Instruction*>()(inst);
+}
+
+std::string GetFnName(const Instruction& fn_inst) {
+  // Check entry point
+  const auto* ep_inst = FindEntryPoint(fn_inst);
+  if (ep_inst != nullptr) {
+    const int name_i =
+        ep_inst->opcode() == spv::Op::OpConditionalEntryPointINTEL ? 3 : 2;
+    return ep_inst->GetOperand(name_i).AsString();
+  }
+
+  // Check name of export linkage attribute decoration
+  const auto* decor_mgr = fn_inst.context()->get_decoration_mgr();
+  for (const auto* inst :
+       decor_mgr->GetDecorationsFor(fn_inst.result_id(), true)) {
+    const auto decoration = inst->GetOperand(1);
+    if ((decoration.type == SPV_OPERAND_TYPE_DECORATION) &&
+        (decoration.words.size() == 1) &&
+        (decoration.words[0] ==
+         static_cast<uint32_t>(spv::Decoration::LinkageAttributes))) {
+      const auto linkage = inst->GetOperand(3);
+      if ((linkage.type == SPV_OPERAND_TYPE_LINKAGE_TYPE) &&
+          (linkage.words.size() == 1) &&
+          (linkage.words[0] ==
+           static_cast<uint32_t>(spv::LinkageType::Export))) {
+        // decorates fn with LinkageAttribute and Export linkage type -> get the
+        // name
+        return inst->GetOperand(2).AsString();
+      }
+    }
+  }
+
+  return "";
+}
+
+uint32_t FindSpecConstByName(const Module* mod, std::string name) {
+  for (const auto* const_inst : mod->context()->GetConstants()) {
+    if (opt::IsSpecConstantInst(const_inst->opcode())) {
+      const auto id = const_inst->result_id();
+      for (const auto& name_inst : mod->debugs2()) {
+        if ((name_inst.opcode() == spv::Op::OpName) &&
+            (name_inst.GetOperand(0).AsId() == id) &&
+            (name_inst.GetOperand(1).AsString() == name)) {
+          return id;
+        }
+      }
+    }
+  }
+  return 0;
+}
+
+uint32_t CombineVariantDefs(const std::vector<VariantDef>& variant_defs,
+                            const std::vector<size_t> var_ids,
+                            IRContext* context,
+                            std::map<std::vector<size_t>, uint32_t>& cache) {
+  assert(var_ids.size() <= variant_defs.size());
+  uint32_t spec_const_comb_id = 0;
+  if (var_ids.size() != variant_defs.size()) {
+    // if not used by all variants
+    if (cache.find(var_ids) == cache.end()) {
+      // cache variant combinations
+      std::vector<uint32_t> spec_const_ids;
+      for (const auto& var_id : var_ids) {
+        const auto var_name = variant_defs[var_id].GetName();
+        const auto var_spec_id =
+            FindSpecConstByName(context->module(), var_name);
+        spec_const_ids.push_back(var_spec_id);
+      }
+      spec_const_comb_id =
+          CombineIds(context, spec_const_ids, spv::Op::OpLogicalOr);
+      assert(spec_const_comb_id != 0);
+      cache.insert({var_ids, spec_const_comb_id});
+    } else {
+      spec_const_comb_id = cache[var_ids];
+    }
+  }
+  return spec_const_comb_id;
+}
+
+// Parses |s| as an unsigned 32-bit decimal integer into |*x|.
+//
+// Accepts only the characters 0-9 (possibly with leading zeros), as the
+// file-level comment in fnvar.h requires for CSV values. Rejects empty
+// strings, signs, whitespace, and values that overflow uint32_t.
+// Returns true on success, false otherwise.
+bool strToInt(const std::string& s, uint32_t* x) {
+  if (s.empty()) {
+    return false;
+  }
+  for (const char& c : s) {
+    if (c < '0' || c > '9') {
+      return false;
+    }
+  }
+  // The digit pre-check rules out signs and whitespace, so the stream
+  // extraction below can only fail on overflow.
+  if (!(std::stringstream(s) >> *x)) {
+    return false;
+  }
+  return true;
+}
+
+}  // anonymous namespace
+
+bool VariantDefs::ProcessFnVar(const LinkerOptions& options,
+                               const std::vector<Module*>& modules) {
+  // Parses the targets/architectures CSV files passed to the CLI, validates
+  // them against the input modules, and builds one VariantDef per module.
+  // Returns true on success; on failure a message is appended to err_.
+  assert(variant_defs_.empty());
+  assert(modules.size() == options.GetInFiles().size());
+
+  // Refuse input modules that already use SPV_INTEL_function_variants.
+  for (size_t i = 0; i < modules.size(); ++i) {
+    const auto* feat_mgr = modules[i]->context()->get_feature_mgr();
+    if ((feat_mgr->HasCapability(spv::Capability::FunctionVariantsINTEL)) ||
+        (feat_mgr->HasCapability(spv::Capability::SpecConditionalINTEL)) ||
+        (feat_mgr->HasExtension(kSPV_INTEL_function_variants))) {
+      // In principle, it can be done but it's complicated due to having to
+      // combine the existing conditionals with the new ones. For example,
+      // conditional capabilities would need to become "doubly-conditional".
+      err_ << "Creating multitarget modules from multitarget modules is not "
+              "supported. Offending file: "
+           << options.GetInFiles()[i];
+      return false;
+    }
+  }
+
+  // One row per OpSpecConstantTargetINTEL / OpSpecConstantArchitectureINTEL
+  // to be generated (column layout documented in fnvar.h).
+  std::vector<std::vector<std::string>> target_rows;
+  std::vector<std::vector<std::string>> architecture_rows;
+
+  if (!options.GetFnVarTargetsCsv().empty()) {
+    const std::vector<std::string> tgt_cols = {"module", "target", "features"};
+    if (!ParseCsv(options.GetFnVarTargetsCsv(), tgt_cols, err_, target_rows)) {
+      return false;
+    }
+  }
+
+  if (!options.GetFnVarArchitecturesCsv().empty()) {
+    const std::vector<std::string> arch_cols = {"module", "category", "family",
+                                                "op", "architecture"};
+    if (!ParseCsv(options.GetFnVarArchitecturesCsv(), arch_cols, err_,
+                  architecture_rows)) {
+      return false;
+    }
+  }
+
+  // check that all modules defined in the CSV exist
+
+  for (const auto& tgt_vals : target_rows) {
+    bool found = false;
+    for (const auto& in_file : options.GetInFiles()) {
+      if (tgt_vals[0] == in_file) {
+        found = true;
+      }
+    }
+    if (!found) {
+      err_ << "Module '" << tgt_vals[0]
+           << "' found in targets CSV not passed to the CLI.";
+      return false;
+    }
+  }
+
+  for (const auto& arch_vals : architecture_rows) {
+    bool found = false;
+    for (const auto& in_file : options.GetInFiles()) {
+      if (arch_vals[0] == in_file) {
+        found = true;
+      }
+    }
+    if (!found) {
+      err_ << "Module '" << arch_vals[0]
+           << "' found in architectures CSV not passed to the CLI.";
+      return false;
+    }
+  }
+
+  // create per-module variant defs
+
+  for (size_t i = 0; i < modules.size(); ++i) {
+    // first module passed to the CLI is considered the base module
+    bool is_base = i == 0;
+    const auto name = options.GetInFiles()[i];
+    auto variant_def = VariantDef(is_base, name, modules[i]);
+
+    // Each matching architectures-CSV row becomes one arch def; all values
+    // must be unsigned decimal strings (see strToInt).
+    for (const auto& arch_row : architecture_rows) {
+      const auto row_name = arch_row[0];
+      if (row_name == name) {
+        uint32_t category, family, op, architecture;
+
+        if (!strToInt(arch_row[1], &category)) {
+          err_ << "Error converting " << arch_row[1]
+               << " to architecture category.";
+          return false;
+        }
+        if (!strToInt(arch_row[2], &family)) {
+          err_ << "Error converting " << arch_row[2]
+               << " to architecture family.";
+          return false;
+        }
+        if (!strToInt(arch_row[3], &op)) {
+          err_ << "Error converting " << arch_row[3] << " to architecture op.";
+          return false;
+        }
+        if (!strToInt(arch_row[4], &architecture)) {
+          err_ << "Error converting " << arch_row[4] << " to architecture.";
+          return false;
+        }
+
+        variant_def.AddArchDef(category, family, op, architecture);
+      }
+    }
+
+    // Each matching targets-CSV row becomes one target def.
+    for (const auto& tgt_row : target_rows) {
+      const auto row_name = tgt_row[0];
+      if (row_name == name) {
+        uint32_t target;
+        std::vector<uint32_t> features;
+
+        if (!strToInt(tgt_row[1], &target)) {
+          err_ << "Error converting " << tgt_row[1] << " to target.";
+          return false;
+        }
+
+        // get features as FEAT_SEP-delimited integers
+
+        std::stringstream feat_stream(tgt_row[2]);
+        std::string feat;
+        while (std::getline(feat_stream, feat, FEAT_SEP)) {
+          uint32_t ufeat;
+          // if (!(std::stringstream(feat) >> ufeat)) {
+          if (!strToInt(feat, &ufeat)) {
+            err_ << "Error converting " << feat << " in " << tgt_row[2]
+                 << " to target feature.";
+            return false;
+          }
+          features.push_back(ufeat);
+        }
+
+        variant_def.AddTgtDef(target, features);
+      }
+    }
+
+    // --fnvar-capabilities: snapshot the module's capabilities so
+    // OpSpecConstantCapabilitiesINTEL can be generated later.
+    if (options.GetHasFnVarCapabilities()) {
+      variant_def.InferCapabilities();
+    }
+
+    variant_defs_.push_back(variant_def);
+  }
+
+  return true;
+}
+
+bool VariantDefs::ProcessVariantDefs() {
+  // Per-module preparation: make sure every module has a bool type for the
+  // spec constants, and record which combinable instructions appear where.
+  EnsureBoolType();
+  CollectVarInsts();
+
+  // Generating the per-variant spec constants can fail (e.g. when a function
+  // has no resolvable name); propagate that failure to the caller.
+  const bool ok = GenerateFnVarConstants();
+  if (ok) {
+    CollectBaseFnCalls();
+  }
+  return ok;
+}
+
+void VariantDefs::GenerateHeader(IRContext* linked_context) {
+  // Declare the capabilities and extension this feature relies on.
+  linked_context->AddCapability(spv::Capability::SpecConditionalINTEL);
+  linked_context->AddCapability(spv::Capability::FunctionVariantsINTEL);
+  linked_context->AddExtension(std::string(FNVAR_EXT_NAME));
+
+  // Record the targets-registry version via OpModuleProcessed.
+  std::stringstream line;
+  line << "SPV_INTEL_function_variants registry version "
+       << FNVAR_REGISTRY_VERSION;
+  auto inst =
+      std::make_unique<Instruction>(linked_context, spv::Op::OpModuleProcessed);
+  inst->AddOperand(
+      {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(line.str())});
+  linked_context->AddDebug3Inst(std::move(inst));
+}
+
+void VariantDefs::CombineVariantInstructions(IRContext* linked_context) {
+  // Merge variant function calls first, then rewrite instructions shared
+  // between variants (capabilities, extensions, ...) into conditional forms.
+  CombineBaseFnCalls(linked_context);
+  CombineInstructions(linked_context);
+}
+
+void VariantDefs::EnsureBoolType() {
+  // Boolean spec constants need OpTypeBool; synthesize one per module when
+  // it is missing.
+  for (auto& variant_def : variant_defs_) {
+    Module* module = variant_def.GetModule();
+    IRContext* context = module->context();
+
+    // Nothing to do when the module already declares a bool type.
+    if (FindIdOfBoolType(module) != 0) {
+      continue;
+    }
+
+    const uint32_t bool_id = context->TakeNextId();
+    module->AddType(std::make_unique<Instruction>(
+        context, spv::Op::OpTypeBool, 0, bool_id,
+        std::initializer_list<opt::Operand>{}));
+  }
+}
+
+void VariantDefs::CollectVarInsts() {
+  // Records, per combinable-instruction hash, the indices of the variants
+  // that contain the instruction (fnvar_usage_).
+  for (size_t i = 0; i < variant_defs_.size(); ++i) {
+    // Take a reference: VariantDef owns vectors and a capability set, so
+    // `const auto` (as originally written) copied it on every iteration.
+    const auto& variant_def = variant_defs_[i];
+    const auto* var_mod = variant_def.GetModule();
+
+    var_mod->ForEachInst([this, &i](const Instruction* inst) {
+      if (CanBeFnVarCombined(inst)) {
+        const size_t inst_hash = HashInst(inst);
+        if (fnvar_usage_.find(inst_hash) == fnvar_usage_.end()) {
+          fnvar_usage_.insert({inst_hash, {i}});
+        } else {
+          // Each variant index can be recorded at most once per hash.
+          assert(fnvar_usage_[inst_hash].size() < variant_defs_.size());
+          fnvar_usage_[inst_hash].push_back(i);
+        }
+      }
+    });
+  }
+}
+
+bool VariantDefs::GenerateFnVarConstants() {
+  // Per variant module: generates the boolean spec constants encoding the
+  // module's target/architecture/capability constraints, combines them into
+  // one constant named after the module (so it can be found by name after
+  // linking shifts IDs), and decorates module-specific instructions,
+  // functions and entry points with ConditionalINTEL.
+  // Returns true on success, false on error (message appended to err_).
+  assert(variant_defs_.size() > 0);
+  assert(variant_defs_[0].IsBase());
+
+  // A single module has nothing to vary on.
+  if (variant_defs_.size() == 1) {
+    return true;
+  }
+
+  for (auto& variant_def : variant_defs_) {
+    Module* module = variant_def.GetModule();
+    IRContext* context = module->context();
+
+    uint32_t bool_id = FindIdOfBoolType(module);
+    if (bool_id == 0) {
+      // add a bool type if not present already
+      // NOTE(review): EnsureBoolType() already ran in ProcessVariantDefs(),
+      // so this branch looks dead — confirm whether it can be removed.
+      bool_id = context->TakeNextId();
+      auto variant_bool = std::make_unique<Instruction>(
+          context, spv::Op::OpTypeBool, 0, bool_id,
+          std::initializer_list<opt::Operand>{});
+      module->AddType(std::move(variant_bool));
+    }
+
+    // Spec constant architecture and target
+
+    // One OpSpecConstantArchitectureINTEL per architectures-CSV row.
+    std::vector<uint32_t> spec_const_arch_ids;
+    for (const auto& arch_def : variant_def.GetArchDefs()) {
+      const uint32_t spec_const_arch_id = context->TakeNextId();
+      spec_const_arch_ids.push_back(spec_const_arch_id);
+
+      auto inst = std::make_unique<Instruction>(
+          context, spv::Op::OpSpecConstantArchitectureINTEL, bool_id,
+          spec_const_arch_id,
+          std::initializer_list<opt::Operand>{
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.category}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.family}},
+              // The comparison opcode is encoded as a plain literal: using
+              // the spec-constant-op operand type here would make the next
+              // operand be parsed as a type.
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.op}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.architecture}},
+          });
+      module->AddType(std::move(inst));
+    }
+
+    // One OpSpecConstantTargetINTEL per targets-CSV row.
+    std::vector<uint32_t> spec_const_tgt_ids;
+    for (const auto& tgt_def : variant_def.GetTgtDefs()) {
+      const uint32_t spec_const_tgt_id = context->TakeNextId();
+      spec_const_tgt_ids.push_back(spec_const_tgt_id);
+
+      auto inst = std::make_unique<Instruction>(
+          context, spv::Op::OpSpecConstantTargetINTEL, bool_id,
+          spec_const_tgt_id,
+          std::initializer_list<opt::Operand>{
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {tgt_def.target}},
+          });
+      for (const auto& feat : tgt_def.features) {
+        inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {feat}});
+      }
+      module->AddType(std::move(inst));
+    }
+
+    // Top-level conjuncts of the module's combined condition.
+    std::vector<uint32_t> spec_const_ids;
+
+    // Spec constant capabilities
+
+    const auto variant_capabilities = variant_def.GetCapabilities();
+    if (!variant_capabilities.empty()) {
+      const uint32_t spec_const_cap_id = context->TakeNextId();
+      auto inst = std::make_unique<Instruction>(
+          context, spv::Op::OpSpecConstantCapabilitiesINTEL, bool_id,
+          spec_const_cap_id, std::initializer_list<opt::Operand>{});
+      for (const auto& cap : variant_capabilities) {
+        inst->AddOperand({SPV_OPERAND_TYPE_CAPABILITY, {uint32_t(cap)}});
+      }
+      module->AddType(std::move(inst));
+      spec_const_ids.push_back(spec_const_cap_id);
+    }
+
+    // Combine architectures such that, for the same module, those with the same
+    // category and family are combined with AND and different cat/fam are
+    // combined with OR.
+    // This lets you create combinations like "architecture between X and Y".
+
+    // map (category, family) -> IDs
+    std::map<std::pair<uint32_t, uint32_t>, std::vector<uint32_t>> arch_map_and;
+
+    for (size_t i = 0; i < spec_const_arch_ids.size(); ++i) {
+      const auto& arch_def = variant_def.GetArchDefs()[i];
+      const auto id = spec_const_arch_ids[i];
+      const auto key = std::make_pair(arch_def.category, arch_def.family);
+      if (arch_map_and.find(key) == arch_map_and.end()) {
+        arch_map_and[key] = {id};
+      } else {
+        arch_map_and[key].push_back(id);
+      }
+    }
+
+    // AND within each (category, family) group ...
+    std::vector<uint32_t> arch_ids_or;
+    for (const auto& it : arch_map_and) {
+      const auto id = CombineIds(context, it.second, spv::Op::OpLogicalAnd);
+      if (id > 0) {
+        arch_ids_or.push_back(id);
+      }
+    }
+
+    // ... then OR across the groups.
+    const uint32_t spec_const_arch_id =
+        CombineIds(context, arch_ids_or, spv::Op::OpLogicalOr);
+    if (spec_const_arch_id > 0) {
+      spec_const_ids.push_back(spec_const_arch_id);
+    }
+
+    const uint32_t spec_const_tgt_id =
+        CombineIds(context, spec_const_tgt_ids, spv::Op::OpLogicalOr);
+    if (spec_const_tgt_id > 0) {
+      spec_const_ids.push_back(spec_const_tgt_id);
+    }
+
+    // The module is active when ALL of its constraint groups hold.
+    uint32_t combined_spec_const_id =
+        CombineIds(context, spec_const_ids, spv::Op::OpLogicalAnd);
+    if (combined_spec_const_id == 0) {
+      // If the variant module has no constraints, use SpecConstantTrue
+      combined_spec_const_id = context->TakeNextId();
+      auto inst = std::make_unique<Instruction>(
+          context, spv::Op::OpSpecConstantTrue, bool_id, combined_spec_const_id,
+          std::initializer_list<opt::Operand>{});
+      context->module()->AddType(std::move(inst));
+    }
+    assert(combined_spec_const_id != 0);
+
+    // Name the combined boolean ID after the module so we can look it up
+    // (FindSpecConstByName) after the IDs are shifted by linking.
+    auto inst = std::make_unique<Instruction>(context, spv::Op::OpName);
+    inst->AddOperand({SPV_OPERAND_TYPE_ID, {combined_spec_const_id}});
+    std::vector<uint32_t> str_words;
+    utils::AppendToVector(variant_def.GetName(), &str_words);
+    inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_STRING, {str_words}});
+    module->AddDebug2Inst(std::move(inst));
+
+    // Annotate all instructions in the types section (eg. constants) with
+    // ConditionalINTEL, unless they can be shared between variant_defs_ (eg.
+    // types). Spec constants are excluded because they might have been
+    // generated by this extension.
+    for (const auto& type_inst : module->types_values()) {
+      if (!CanBeFnVarCombined(&type_inst) &&
+          !spvOpcodeIsSpecConstant(type_inst.opcode())) {
+        DecorateConditional(context, type_inst.result_id(),
+                            combined_spec_const_id);
+      }
+    }
+  }
+
+  // Annotate functions with ConditionalINTEL
+
+  for (const auto& base_fn : *variant_defs_[0].GetModule()) {
+    // For each function of the base module, find matching variant functions in
+    // other modules
+
+    auto base_fn_name = GetFnName(base_fn.DefInst());
+    if (base_fn_name.empty()) {
+      err_ << "Could not find name of a function " << base_fn.result_id()
+           << " in a base module " << variant_defs_[0].GetName()
+           << ". To be usable by SPV_INTEL_function_variants, a function "
+              "must either have an entry point or an export "
+              "LinkAttribute decoration.";
+      return false;
+    }
+
+    bool base_fn_needs_conditional = false;
+    for (size_t i = 1; i < variant_defs_.size(); ++i) {
+      const auto& variant_def = variant_defs_[i];
+      auto* variant_module = variant_def.GetModule();
+      auto* variant_context = variant_module->context();
+
+      // NOTE(review): this inner loop re-decorates every variant function
+      // once per base-module function; presumably DecorateConditional /
+      // ConvertEPToConditional tolerate duplicates — confirm.
+      for (const auto& var_fn : *variant_module) {
+        auto var_fn_name = GetFnName(var_fn.DefInst());
+        if (var_fn_name.empty()) {
+          // NOTE(review): the message below says "base module" although this
+          // loop iterates variant modules — looks copy-pasted from the check
+          // above; confirm intended wording.
+          err_ << "Could not find name of a function " << var_fn.result_id()
+               << " in a base module " << variant_def.GetName()
+               << ". To be usable by SPV_INTEL_function_variants, a function "
+                  "must either have an entry point or an export "
+                  "LinkAttribute decoration.";
+          return false;
+        }
+
+        if (base_fn_name == var_fn_name) {
+          base_fn_needs_conditional = true;
+        }
+
+        // each function in a variant module gets a ConditionalINTEL decoration
+
+        uint32_t spec_const_id =
+            FindSpecConstByName(variant_module, variant_def.GetName());
+        assert(spec_const_id != 0);
+        DecorateConditional(variant_context, var_fn.result_id(), spec_const_id);
+        ConvertEPToConditional(variant_module, var_fn, spec_const_id);
+      }
+    }
+
+    if (base_fn_needs_conditional) {
+      // only a base function that has a variant in another module gets a
+      // ConditionalINTEL decoration, the others are common for all
+      // variant_defs_
+      auto* base_module = variant_defs_[0].GetModule();
+      auto* base_context = base_module->context();
+      uint32_t spec_const_id =
+          FindSpecConstByName(base_module, variant_defs_[0].GetName());
+      assert(spec_const_id != 0);
+      DecorateConditional(base_context, base_fn.result_id(), spec_const_id);
+      ConvertEPToConditional(base_module, base_fn, spec_const_id);
+    }
+  }
+
+  return true;
+}
+
+void VariantDefs::CollectBaseFnCalls() {
+  // Records, for every OpFunctionCall in the base module, the list of
+  // same-named functions defined in the variant modules. The result, keyed by
+  // the call's result ID, is consumed later by CombineBaseFnCalls().
+  auto* base_mod = variant_defs_[0].GetModule();
+  assert(variant_defs_[0].IsBase());
+  const auto* base_def_use_mgr = base_mod->context()->get_def_use_mgr();
+
+  base_mod->ForEachInst([this, &base_def_use_mgr](const Instruction* inst) {
+    if (inst->opcode() == spv::Op::OpFunctionCall) {
+      // For each function call in base module, get the function name
+      const auto fn_id = inst->GetOperand(2).AsId();
+      const auto* called_fn_inst = base_def_use_mgr->GetDef(fn_id);
+      assert(called_fn_inst != nullptr);
+      const auto called_fn_name = GetFnName(*called_fn_inst);
+      assert(!called_fn_name.empty());
+
+      // Pairs of (variant name, matching function in that variant's module).
+      std::vector<std::pair<std::string, const opt::Function*>> called_fns;
+      for (size_t i = 1; i < variant_defs_.size(); ++i) {
+        // ... then see in which variant the called function was defined
+        const auto& variant_def = variant_defs_[i];
+        assert(!variant_def.IsBase());
+
+        for (const auto& fn : *variant_def.GetModule()) {
+          const auto fn_name = GetFnName(fn.DefInst());
+          if (fn_name == called_fn_name) {
+            called_fns.push_back(std::make_pair(variant_def.GetName(), &fn));
+          }
+        }
+      }
+
+      // Only calls that actually have variant counterparts are recorded.
+      if (!called_fns.empty()) {
+        base_fn_calls_[inst->result_id()] = called_fns;
+      }
+    }
+  });
+}
+
+void VariantDefs::CombineBaseFnCalls(IRContext* linked_context) {
+  // For every recorded base-module call with variant counterparts: clone the
+  // call once per variant, guard the original and each clone with
+  // ConditionalINTEL, and merge the results via
+  // OpConditionalCopyObjectINTEL. Finally, rewrite the base module's spec
+  // constant to "base constraints AND no variant active".
+  //
+  // NOTE(review): `auto kv` copies the map entry (including the vector of
+  // called functions) every iteration, and `kv.second` copies it again;
+  // `const auto&` would avoid both — confirm no mutation is relied on.
+  for (auto kv : base_fn_calls_) {
+    const uint32_t call_id = kv.first;
+    const auto called_fns = kv.second;
+
+    // NOTE(review): this `return` (and the one below) aborts processing of
+    // ALL remaining recorded calls, not just the current one; `continue`
+    // looks like the intent — confirm. (CollectBaseFnCalls never stores
+    // empty vectors, so this particular branch should be unreachable.)
+    if (called_fns.empty()) {
+      return;
+    }
+
+    opt::BasicBlock* fn_call_bb = linked_context->get_instr_block(call_id);
+
+    // Locate the call instruction in its basic block by result ID.
+    Instruction* found_call_inst = nullptr;
+    auto bb_iter = fn_call_bb->begin();
+    while (bb_iter != fn_call_bb->end() && found_call_inst == nullptr) {
+      if (bb_iter->HasResultId() && bb_iter->result_id() == call_id) {
+        found_call_inst = &*bb_iter;
+      }
+      ++bb_iter;
+    }
+
+    if (found_call_inst == nullptr) {
+      return;
+    }
+
+    const auto base_spec_const_id = FindSpecConstByName(
+        variant_defs_[0].GetModule(), variant_defs_[0].GetName());
+    // Return-type opcode; void calls get no conditional copy (see below).
+    const auto base_type_op = found_call_inst->context()
+                                  ->get_def_use_mgr()
+                                  ->GetDef(found_call_inst->type_id())
+                                  ->opcode();
+    const auto base_call_id = found_call_inst->result_id();
+
+    // decorate the base call with ConditionalINTEL
+    DecorateConditional(linked_context, base_call_id, base_spec_const_id);
+
+    // Add OpFunctionCall for each variant
+    Instruction* last_inst = found_call_inst;
+    // Pairs of (variant spec constant ID, cloned call result ID).
+    std::vector<std::pair<uint32_t, uint32_t>> var_call_ids;
+    for (const auto& kv2 : called_fns) {
+      const std::string var_name = kv2.first;
+      const opt::Function* fn = kv2.second;
+      const uint32_t spec_const_id =
+          FindSpecConstByName(linked_context->module(), var_name);
+      assert(spec_const_id != 0);
+      const uint32_t var_call_id = linked_context->TakeNextId();
+      var_call_ids.push_back(std::make_pair(spec_const_id, var_call_id));
+
+      // Clone the base call and retarget it at the variant's function.
+      auto* var_call_inst = found_call_inst->Clone(linked_context);
+      var_call_inst->SetResultId(var_call_id);
+      var_call_inst->SetOperand(2, {fn->result_id()});
+      var_call_inst->InsertAfter(last_inst);
+      linked_context->set_instr_block(var_call_inst, fn_call_bb);
+      last_inst = var_call_inst;
+
+      // decorate the variant call with ConditionalINTEL
+      DecorateConditional(linked_context, var_call_id, spec_const_id);
+    }
+
+    if (base_type_op != spv::Op::OpTypeVoid) {
+      // Add OpConditionalCopyObjectINTEL combining the function calls
+      const uint32_t result_id = linked_context->TakeNextId();
+      auto conditional_copy_inst = new Instruction(
+          linked_context, spv::Op::OpConditionalCopyObjectINTEL,
+          found_call_inst->type_id(), result_id,
+          {{SPV_OPERAND_TYPE_ID, {base_spec_const_id}},
+           {SPV_OPERAND_TYPE_ID, {found_call_inst->result_id()}}});
+
+      // Append one (condition, value) pair per variant call.
+      for (const auto& kv3 : var_call_ids) {
+        const auto spec_const_id = kv3.first;
+        const auto var_call_id = kv3.second;
+        conditional_copy_inst->AddOperand(
+            {SPV_OPERAND_TYPE_ID, {spec_const_id}});
+        conditional_copy_inst->AddOperand({SPV_OPERAND_TYPE_ID, {var_call_id}});
+      }
+      conditional_copy_inst->InsertAfter(last_inst);
+      linked_context->set_instr_block(conditional_copy_inst, fn_call_bb);
+      last_inst = conditional_copy_inst;
+
+      // In all remaining instructions within the basic block, replace all
+      // usages of the base call ID with the result of
+      // OpConditionalCopyObjectINTEL
+      //
+      // NOTE(review): ForEachInId is invoked on last_inst BEFORE the
+      // `last_inst != nullptr` check; if NextNode() ever returned null this
+      // would dereference null. The block terminator should follow the
+      // inserted copy, so it appears unreachable in practice — confirm.
+      do {
+        last_inst = last_inst->NextNode();
+        last_inst->ForEachInId([base_call_id, result_id](uint32_t* id) {
+          if (*id == base_call_id) {
+            *id = result_id;
+          }
+        });
+      } while (last_inst != nullptr && *last_inst != *fn_call_bb->tail());
+    }
+  }
+
+  // Combine spec consts for the base module (base module is activated if all
+  // variant defs are inactive AND the base module constraints are satisfied)
+
+  std::vector<uint32_t> var_spec_const_ids;
+  for (const auto& variant_def : variant_defs_) {
+    if (variant_def.IsBase()) {
+      continue;
+    }
+
+    const auto id =
+        FindSpecConstByName(linked_context->module(), variant_def.GetName());
+    assert(id != 0);
+    var_spec_const_ids.push_back(id);
+  }
+  // OR of all variant conditions; 0 when there are no variants at all.
+  const uint32_t base_or_id =
+      CombineIds(linked_context, var_spec_const_ids, spv::Op::OpLogicalOr);
+
+  if (base_or_id != 0) {
+    const uint32_t bool_id = FindIdOfBoolType(linked_context->module());
+    assert(bool_id != 0);
+
+    // base_not = !(variant_1 || variant_2 || ...)
+    const uint32_t base_not_id = linked_context->TakeNextId();
+    auto spec_const_op_inst = std::make_unique<Instruction>(
+        linked_context, spv::Op::OpSpecConstantOp, bool_id, base_not_id,
+        std::initializer_list<opt::Operand>{
+            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
+             {(uint32_t)(spv::Op::OpLogicalNot)}},
+            {SPV_OPERAND_TYPE_ID, {base_or_id}}});
+    linked_context->module()->AddType(std::move(spec_const_op_inst));
+
+    // Update any ConditionalINTEL annotations, names and entry points
+    // referencing the old spec const ID to use the new one
+
+    const uint32_t old_base_spec_const_id = FindSpecConstByName(
+        linked_context->module(), variant_defs_[0].GetName());
+    assert(old_base_spec_const_id != 0);
+    const uint32_t base_spec_const_id =
+        CombineIds(linked_context, {old_base_spec_const_id, base_not_id},
+                   spv::Op::OpLogicalAnd);
+
+    for (auto& annot_inst : linked_context->module()->annotations()) {
+      if ((annot_inst.GetSingleWordOperand(1) ==
+           uint32_t(spv::Decoration::ConditionalINTEL)) &&
+          (annot_inst.GetOperand(2).AsId() == old_base_spec_const_id)) {
+        annot_inst.SetOperand(2, {base_spec_const_id});
+      }
+    }
+
+    for (auto& name_inst : linked_context->module()->debugs2()) {
+      if ((name_inst.opcode() == spv::Op::OpName) &&
+          (name_inst.GetOperand(0).AsId() == old_base_spec_const_id)) {
+        name_inst.SetOperand(0, {base_spec_const_id});
+      }
+    }
+
+    for (auto& ep_inst : linked_context->module()->entry_points()) {
+      if ((ep_inst.opcode() == spv::Op::OpConditionalEntryPointINTEL) &&
+          (ep_inst.GetOperand(0).AsId() == old_base_spec_const_id)) {
+        ep_inst.SetOperand(0, {base_spec_const_id});
+      }
+    }
+
+    // Conditional copies reference the condition as an input operand.
+    linked_context->module()->ForEachInst(
+        [old_base_spec_const_id, base_spec_const_id](Instruction* inst) {
+          if (inst->opcode() == spv::Op::OpConditionalCopyObjectINTEL) {
+            inst->ForEachInId(
+                [old_base_spec_const_id, base_spec_const_id](uint32_t* id) {
+                  if (*id == old_base_spec_const_id) {
+                    *id = base_spec_const_id;
+                  }
+                });
+          }
+        });
+  }
+}
+
+void VariantDefs::CombineInstructions(IRContext* linked_context) {
+  // Guards every instruction shared between some (but not all) variants with
+  // the appropriate condition: result-bearing instructions get a
+  // ConditionalINTEL decoration; OpCapability/OpExtension are rewritten in
+  // place into their conditional forms.
+
+  // cache for existing variant ID combinations
+  std::map<std::vector<size_t>, uint32_t> spec_const_comb_ids;
+
+  linked_context->module()->ForEachInst(
+      [this, &linked_context, &spec_const_comb_ids](Instruction* inst) {
+        if (!CanBeFnVarCombined(inst)) {
+          return;
+        }
+
+        // Single hash lookup; the original did find() followed by
+        // operator[], which searched the map twice and copied the vector.
+        const size_t inst_hash = HashInst(inst);
+        const auto usage_it = fnvar_usage_.find(inst_hash);
+        if (usage_it == fnvar_usage_.end()) {
+          return;
+        }
+
+        const std::vector<size_t>& var_ids = usage_it->second;
+        const uint32_t spec_const_comb_id = CombineVariantDefs(
+            variant_defs_, var_ids, linked_context, spec_const_comb_ids);
+        if (spec_const_comb_id == 0) {
+          // Used by all variants: stays unconditional.
+          return;
+        }
+
+        if (inst->HasResultId()) {
+          DecorateConditional(linked_context, inst->result_id(),
+                              spec_const_comb_id);
+        } else if (inst->opcode() == spv::Op::OpCapability) {
+          const uint32_t cap = inst->GetSingleWordOperand(0);
+          inst->SetOpcode(spv::Op::OpConditionalCapabilityINTEL);
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {spec_const_comb_id}},
+                               {SPV_OPERAND_TYPE_CAPABILITY, {cap}}});
+        } else if (inst->opcode() == spv::Op::OpExtension) {
+          const std::string ext_name = inst->GetOperand(0).AsString();
+          inst->SetOpcode(spv::Op::OpConditionalExtensionINTEL);
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {spec_const_comb_id}},
+                               {SPV_OPERAND_TYPE_LITERAL_STRING,
+                                {utils::MakeVector(ext_name)}}});
+        } else {
+          assert(false && "Unsupported");
+        }
+      });
+}
+
+}  // namespace spvtools

+ 244 - 0
3rdparty/spirv-tools/source/link/fnvar.h

@@ -0,0 +1,244 @@
+// Copyright 2025 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Implementation of generating multitarget modules according to the
+// *SPV_INTEL_function_variants* extension
+//
+// Multitarget module is generated by linking separate modules: a base module
+// and variant modules containing device-specific variants of the functions in
+// the base module. The behavior is controlled by Comma-Separated Values (CSV)
+// files passed to the following flags:
+// --fnvar-targets: Required columns:
+//   module   - module file name
+//   target   - device target ISA value
+//   features - feature values for the target separated by '/' (FEAT_SEP)
+// --fnvar-architectures: Required columns:
+//   module       - module file name
+//   category     - device category value
+//   family       - device family value
+//   op           - opcode of the comparison instruction
+//   architecture - device architecture
+// The values (except module) are decimal strings with their meaning defined in
+// the 'targets registry' as described in the extension spec. The decimal
+// strings may only encode unsigned 32-bit integers (characters 0-9), possibly
+// with leading zeros.
+//
+// In addition, --fnvar-capabilities generates OpSpecConstantCapabilitiesINTEL
+// for each module with operands corresponding to the module's capabilities.
+//
+// Each line in the targets/architectures CSV file defines one
+// OpSpecConstant<Target/Architecture>INTEL instruction, the columns correspond
+// to the operands of these instructions. One module can have multiple lines, in
+// which case they are combined into a single boolean spec constant using
+// OpSpecConstantOp and OpLogicalOr (except when category and family in the
+// architectures CSV are the same, then the lines are combined with
+// OpLogicalAnd). For example, the following architectures CSV
+//
+//     module,category,family,op,architecture
+//     foo.spv,1,7,174,1
+//     foo.spv,1,7,178,3
+//     foo.spv,1,8,170,1
+//
+// is combined as follows:
+//
+//          %53 = OpSpecConstantArchitectureINTEL %bool 1 7 174 1
+//          %54 = OpSpecConstantArchitectureINTEL %bool 1 7 178 3
+//          %55 = OpSpecConstantArchitectureINTEL %bool 1 8 170 1
+//          %56 = OpSpecConstantOp %bool LogicalAnd %53 %54
+//     %foo_spv = OpSpecConstantOp %bool LogicalOr %55 %56
+//
+// The %foo_spv is annotated with OpName "foo.spv" (the module's name) which
+// serves as an identifier to find the constant later. We cannot use IDs for it
+// because the IDs get shifted during linking.
+//
+// The first module passed to `spirv-link` is considered the 'base' module. For
+// example, if base module defines functions 'foo' and 'bar' and the other
+// modules define only 'foo', only the 'foo' is treated as a function variant
+// guarded by spec constants. The 'bar' function will be untouched and therefore
+// present for all variants. The function variants are matched by name, and
+// therefore they must either have an entry point, or an Export linkage
+// attribute.
+
+#ifndef FNVAR_H
+#define FNVAR_H
+
+#include <map>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "spirv-tools/linker.hpp"
+
+namespace spvtools {
+
+using opt::IRContext;
+using opt::Module;
+
+// Map of instruction hash -> which variants are using the instruction (denoted
+// by the index to the variants vector)
+using FnVarUsage = std::unordered_map<size_t, std::vector<size_t>>;
+
+// Map of base function call ID -> variant functions corresponding to the
+// called function (along with the variant name)
+using BaseFnCalls =
+    std::map<uint32_t,
+             std::vector<std::pair<std::string, const opt::Function*>>>;
+
+constexpr char FNVAR_EXT_NAME[] = "SPV_INTEL_function_variants";
+constexpr uint32_t FNVAR_REGISTRY_VERSION = 0;
+constexpr char FEAT_SEP = '/';
+
+// One row of the architectures CSV; the fields become the operands of one
+// OpSpecConstantArchitectureINTEL instruction (see file comment above).
+struct FnVarArchDef {
+  uint32_t category;      // device category value
+  uint32_t family;        // device family value
+  uint32_t op;            // opcode of the comparison instruction
+  uint32_t architecture;  // device architecture value
+};
+
+// One row of the targets CSV; the fields become the operands of one
+// OpSpecConstantTargetINTEL instruction (see file comment above).
+struct FnVarTargetDef {
+  uint32_t target;                 // device target ISA value
+  std::vector<uint32_t> features;  // feature values for the target
+};
+
+// Definition of a variant
+//
+// Stores architecture and target definitions inferred from lines in the CSV
+// files for a single module (as well as a pointer to the Module).
+// Definition of a variant
+//
+// Stores architecture and target definitions inferred from lines in the CSV
+// files for a single module (as well as a pointer to the Module).
+class VariantDef {
+ public:
+  // |nm| is taken by value and moved to avoid an extra copy.
+  VariantDef(bool isbase, std::string nm, Module* mod)
+      : is_base(isbase), name(std::move(nm)), module(mod) {}
+
+  // True for the base module (the first module passed to the CLI).
+  bool IsBase() const { return this->is_base; }
+  // Module file name; also used to name the variant's combined spec constant.
+  std::string GetName() const { return this->name; }
+  Module* GetModule() const { return this->module; }
+
+  // Records one row of the architectures CSV.
+  void AddArchDef(uint32_t category, uint32_t family, uint32_t op,
+                  uint32_t architecture) {
+    this->arch_defs.push_back({category, family, op, architecture});
+  }
+  const std::vector<FnVarArchDef>& GetArchDefs() const {
+    return this->arch_defs;
+  }
+
+  // Records one row of the targets CSV. |features| is taken by value and
+  // moved to avoid copying the vector.
+  void AddTgtDef(uint32_t target, std::vector<uint32_t> features) {
+    this->tgt_defs.push_back({target, std::move(features)});
+  }
+  const std::vector<FnVarTargetDef>& GetTgtDefs() const {
+    return this->tgt_defs;
+  }
+
+  // Snapshots the module's OpCapability operands; used with
+  // --fnvar-capabilities to generate OpSpecConstantCapabilitiesINTEL.
+  void InferCapabilities() {
+    for (const auto& cap_inst : module->capabilities()) {
+      capabilities.insert(spv::Capability(cap_inst.GetOperand(0).words[0]));
+    }
+  }
+  const std::set<spv::Capability>& GetCapabilities() const {
+    return this->capabilities;
+  }
+
+ private:
+  bool is_base;
+  std::string name;
+  Module* module;
+  std::vector<FnVarTargetDef> tgt_defs;
+  std::vector<FnVarArchDef> arch_defs;
+  std::set<spv::Capability> capabilities;
+};
+
+// Collection of VariantDef instances
+//
+// Apart from being a wrapper around a vector of VariantDef instances, it
+// defines the main API for generating SPV_INTEL_function_variants instructions
+// based on the CSV files.
+class VariantDefs {
+ public:
+  // Returns the accumulated error messages (err_ collects every error
+  // encountered during processing, not just the last one).
+  std::string GetErr() { return err_.str(); }
+
+  // Processes the CSV files passed to the CLI and populates variant_defs_.
+  //
+  // Returns true on success, false on error.
+  bool ProcessFnVar(const LinkerOptions& options,
+                    const std::vector<Module*>& modules);
+
+  // Analyses each variant def module and generates those instructions that are
+  // module-specific, ie., not requiring knowledge from other modules.
+  //
+  // Returns true on success, false on error.
+  bool ProcessVariantDefs();
+
+  // Generates basic instructions required for this extension to work
+  // (capabilities, extension declaration, registry-version note).
+  void GenerateHeader(IRContext* linked_context);
+
+  // Generates instructions from this extension that result from combining
+  // several variant def modules.
+  void CombineVariantInstructions(IRContext* linked_context);
+
+ private:
+  // Adds a boolean type to every module if there is none.
+  //
+  // These are necessary for spec constants.
+  void EnsureBoolType();
+
+  // Collects which combinable instructions are defined in which modules
+  // (fills fnvar_usage_).
+  void CollectVarInsts();
+
+  // Generates OpSpecConstant<Target/Architecture/Capabilities>INTEL and
+  // combines them as necessary. Also converts entry points to conditional ones
+  // and decorates module-specific instructions with ConditionalINTEL.
+  //
+  // Returns true on success, false on error.
+  bool GenerateFnVarConstants();
+
+  // Determines which functions in the base module are called by which function
+  // variants (fills base_fn_calls_).
+  void CollectBaseFnCalls();
+
+  // Combines OpFunctionCall instructions collected with CollectBaseFnCalls()
+  // using conditional copy.
+  void CombineBaseFnCalls(IRContext* linked_context);
+
+  // Decorates instructions shared between modules with ConditionalINTEL or
+  // generates conditional capabilities and extensions, depending on which
+  // variants are used by each.
+  void CombineInstructions(IRContext* linked_context);
+
+  // Accumulates all errors encountered during processing.
+  std::stringstream err_;
+
+  // Collection of VariantDef instances; index 0 is always the base module.
+  std::vector<VariantDef> variant_defs_;
+
+  // Used for combining OpFunctionCall instructions
+  BaseFnCalls base_fn_calls_;
+
+  // Used for determining which function variant uses which (applicable)
+  // instruction
+  FnVarUsage fnvar_usage_;
+};
+
+}  // namespace spvtools
+
+#endif  // FNVAR_H

+ 34 - 8
3rdparty/spirv-tools/source/link/linker.cpp

@@ -15,9 +15,10 @@
 #include "spirv-tools/linker.hpp"
 
 #include <algorithm>
+#include <cstdint>
 #include <cstdio>
 #include <cstring>
-#include <iostream>
+#include <functional>
 #include <memory>
 #include <numeric>
 #include <string>
@@ -26,18 +27,17 @@
 #include <utility>
 #include <vector>
 
+#include "fnvar.h"
 #include "source/diagnostic.h"
 #include "source/opt/build_module.h"
 #include "source/opt/compact_ids_pass.h"
 #include "source/opt/decoration_manager.h"
 #include "source/opt/ir_builder.h"
-#include "source/opt/ir_loader.h"
 #include "source/opt/pass_manager.h"
 #include "source/opt/remove_duplicates_pass.h"
 #include "source/opt/remove_unused_interface_variables_pass.h"
 #include "source/opt/type_manager.h"
 #include "source/spirv_constant.h"
-#include "source/spirv_target_env.h"
 #include "source/table2.h"
 #include "source/util/make_unique.h"
 #include "source/util/string_utils.h"
@@ -328,7 +328,10 @@ spv_result_t MergeModules(const MessageConsumer& consumer,
   for (const auto& module : input_modules)
     for (const auto& inst : module->entry_points()) {
       const uint32_t model = inst.GetSingleWordInOperand(0);
-      const std::string name = inst.GetInOperand(2).AsString();
+      const std::string name =
+          inst.opcode() == spv::Op::OpConditionalEntryPointINTEL
+              ? inst.GetOperand(3).AsString()
+              : inst.GetOperand(2).AsString();
       const auto i = std::find_if(
           entry_points.begin(), entry_points.end(),
           [model, name](const std::pair<uint32_t, std::string>& v) {
@@ -728,8 +731,7 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
   if (max_id_bound >= SPV_LIMIT_RESULT_ID_BOUND)
     DiagnosticStream({0u, 0u, 4u}, consumer, "", SPV_WARNING)
         << "The minimum limit of IDs, " << (SPV_LIMIT_RESULT_ID_BOUND - 1)
-        << ", was exceeded:"
-        << " " << max_id_bound << " is the current ID bound.\n"
+        << ", was exceeded: " << max_id_bound << " is the current ID bound.\n"
         << "The resulting module might not be supported by all "
            "implementations.";
 
@@ -740,8 +742,8 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
   if (num_global_values >= SPV_LIMIT_GLOBAL_VARIABLES_MAX)
     DiagnosticStream(position, consumer, "", SPV_WARNING)
         << "The minimum limit of global values, "
-        << (SPV_LIMIT_GLOBAL_VARIABLES_MAX - 1) << ", was exceeded;"
-        << " " << num_global_values << " global values were found.\n"
+        << (SPV_LIMIT_GLOBAL_VARIABLES_MAX - 1) << ", was exceeded; "
+        << num_global_values << " global values were found.\n"
         << "The resulting module might not be supported by all "
            "implementations.";
 
@@ -853,6 +855,22 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
     ir_contexts.push_back(std::move(ir_context));
   }
 
+  const bool make_multitarget = !options.GetFnVarArchitecturesCsv().empty() ||
+                                !options.GetFnVarTargetsCsv().empty();
+
+  VariantDefs variant_defs;
+
+  if (make_multitarget) {
+    if (!variant_defs.ProcessFnVar(options, modules)) {
+      return DiagnosticStream(position, consumer, "", SPV_ERROR_FNVAR)
+             << variant_defs.GetErr();
+    }
+    if (!variant_defs.ProcessVariantDefs()) {
+      return DiagnosticStream(position, consumer, "", SPV_ERROR_FNVAR)
+             << variant_defs.GetErr();
+    }
+  }
+
   // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
   //          range from the other binaries, and compute the new ID bound.
   uint32_t max_id_bound = 0u;
@@ -866,6 +884,10 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
   IRContext linked_context(c_context->target_env, consumer);
   linked_context.module()->SetHeader(header);
 
+  if (make_multitarget) {
+    variant_defs.GenerateHeader(&linked_context);
+  }
+
   // Phase 3: Merge all the binaries into a single one.
   res = MergeModules(consumer, modules, &linked_context);
   if (res != SPV_SUCCESS) return res;
@@ -882,6 +904,10 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
   opt::Pass::Status pass_res = manager.Run(&linked_context);
   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
 
+  if (make_multitarget) {
+    variant_defs.CombineVariantInstructions(&linked_context);
+  }
+
   // Phase 5: Find the import/export pairs
   LinkageTable linkings_to_do;
   res = GetImportExportPairs(consumer, linked_context,

+ 15 - 0
3rdparty/spirv-tools/source/mimalloc.cpp

@@ -0,0 +1,15 @@
+// Copyright (c) 2025 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mimalloc-new-delete.h"

+ 10 - 0
3rdparty/spirv-tools/source/opcode.cpp

@@ -120,6 +120,9 @@ int32_t spvOpcodeIsSpecConstant(const spv::Op opcode) {
     case spv::Op::OpSpecConstantComposite:
     case spv::Op::OpSpecConstantCompositeReplicateEXT:
     case spv::Op::OpSpecConstantOp:
+    case spv::Op::OpSpecConstantArchitectureINTEL:
+    case spv::Op::OpSpecConstantTargetINTEL:
+    case spv::Op::OpSpecConstantCapabilitiesINTEL:
       return true;
     default:
       return false;
@@ -144,6 +147,12 @@ int32_t spvOpcodeIsConstant(const spv::Op opcode) {
     case spv::Op::OpSpecConstantCompositeReplicateEXT:
     case spv::Op::OpSpecConstantOp:
     case spv::Op::OpSpecConstantStringAMDX:
+    case spv::Op::OpGraphConstantARM:
+    case spv::Op::OpAsmTargetINTEL:
+    case spv::Op::OpAsmINTEL:
+    case spv::Op::OpSpecConstantArchitectureINTEL:
+    case spv::Op::OpSpecConstantTargetINTEL:
+    case spv::Op::OpSpecConstantCapabilitiesINTEL:
       return true;
     default:
       return false;
@@ -264,6 +273,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
     case spv::Op::OpTypeTensorViewNV:
     case spv::Op::OpTypeTensorARM:
     case spv::Op::OpTypeTaskSequenceINTEL:
+    case spv::Op::OpTypeGraphARM:
       return true;
     default:
       // In particular, OpTypeForwardPointer does not generate a type,

+ 13 - 0
3rdparty/spirv-tools/source/operand.cpp

@@ -111,6 +111,7 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
       return "kernel profiling info";
     case SPV_OPERAND_TYPE_CAPABILITY:
+    case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
       return "capability";
     case SPV_OPERAND_TYPE_RAY_FLAGS:
       return "ray flags";
@@ -394,6 +395,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
     case SPV_OPERAND_TYPE_OPTIONAL_FPENCODING:
     case SPV_OPERAND_TYPE_OPTIONAL_TENSOR_OPERANDS:
+    case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
       return true;
     default:
       break;
@@ -408,6 +410,7 @@ bool spvOperandIsVariable(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER:
     case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID:
     case SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER:
+    case SPV_OPERAND_TYPE_VARIABLE_CAPABILITY:
       return true;
     default:
       break;
@@ -439,6 +442,10 @@ bool spvExpandOperandSequenceOnce(spv_operand_type_t type,
       pattern->push_back(SPV_OPERAND_TYPE_LITERAL_INTEGER);
       pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_ID);
       return true;
+    case SPV_OPERAND_TYPE_VARIABLE_CAPABILITY:
+      pattern->push_back(type);
+      pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY);
+      return true;
     default:
       break;
   }
@@ -521,6 +528,9 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
     case spv::Op::OpMemberDecorateStringGOOGLE:
     case spv::Op::OpBranch:
     case spv::Op::OpLoopMerge:
+    case spv::Op::OpConditionalEntryPointINTEL:
+    case spv::Op::OpConditionalCapabilityINTEL:
+    case spv::Op::OpConditionalExtensionINTEL:
       out = [](unsigned) { return true; };
       break;
     case spv::Op::OpGroupDecorate:
@@ -571,6 +581,9 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
       // approximate, due to variable operands
       out = [](unsigned index) { return index > 6; };
       break;
+    case spv::Op::OpGraphEntryPointARM:
+      out = [](unsigned index) { return index == 0; };
+      break;
     default:
       out = [](unsigned) { return false; };
       break;

+ 57 - 1
3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp

@@ -44,6 +44,9 @@ constexpr uint32_t kExtInstSetInIdx = 0;
 constexpr uint32_t kExtInstOpInIdx = 1;
 constexpr uint32_t kInterpolantInIdx = 2;
 constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0;
+constexpr uint32_t kDebugValueLocalVariable = 2;
+constexpr uint32_t kDebugValueValue = 3;
+constexpr uint32_t kDebugValueExpression = 4;
 
 // Sorting functor to present annotation instructions in an easy-to-process
 // order. The functor orders by opcode first and falls back on unique id
@@ -277,9 +280,53 @@ bool AggressiveDCEPass::AggressiveDCE(Function* func) {
   live_local_vars_.clear();
   InitializeWorkList(func, structured_order);
   ProcessWorkList(func);
+  ProcessDebugInformation(structured_order);
+  ProcessWorkList(func);
   return KillDeadInstructions(func, structured_order);
 }
 
+// Post-processes NonSemantic Shader.DebugInfo.100 instructions in
+// |structured_order| after the initial liveness marking: DebugDeclare
+// instructions are always kept live, and a DebugValue whose Value operand's
+// definition was found dead has that definition rewritten to OpUndef so the
+// debug record still references a valid id.  Instructions made live here are
+// pushed onto |worklist_|; the caller runs ProcessWorkList() again afterwards.
+void AggressiveDCEPass::ProcessDebugInformation(
+    std::list<BasicBlock*>& structured_order) {
+  for (auto bi = structured_order.begin(); bi != structured_order.end(); bi++) {
+    (*bi)->ForEachInst([this](Instruction* inst) {
+      // DebugDeclare is not dead. It must be converted to DebugValue in a
+      // later pass
+      if (inst->IsNonSemanticInstruction() &&
+          inst->GetShader100DebugOpcode() ==
+              NonSemanticShaderDebugInfo100DebugDeclare) {
+        AddToWorklist(inst);
+        return;
+      }
+
+      // If the Value of a DebugValue is killed, set Value operand to Undef
+      if (inst->IsNonSemanticInstruction() &&
+          inst->GetShader100DebugOpcode() ==
+              NonSemanticShaderDebugInfo100DebugValue) {
+        uint32_t id = inst->GetSingleWordInOperand(kDebugValueValue);
+        auto def = get_def_use_mgr()->GetDef(id);
+        // NOTE(review): Set() appears to mark |def| live and return its
+        // previous liveness, so this branch runs only when the referenced
+        // value was dead — confirm against the bit-vector Set() semantics.
+        if (!live_insts_.Set(def->unique_id())) {
+          AddToWorklist(inst);
+          context()->get_def_use_mgr()->UpdateDefUse(inst);
+          worklist_.push(def);
+          // Turn the dead definition into OpUndef so the DebugValue still
+          // points at a valid instruction.
+          def->SetOpcode(spv::Op::OpUndef);
+          def->SetInOperands({});
+          // Keep the DebugLocalVariable operand (and its own operands) alive.
+          id = inst->GetSingleWordInOperand(kDebugValueLocalVariable);
+          auto localVar = get_def_use_mgr()->GetDef(id);
+          AddToWorklist(localVar);
+          context()->get_def_use_mgr()->UpdateDefUse(localVar);
+          AddOperandsToWorkList(localVar);
+          context()->get_def_use_mgr()->UpdateDefUse(def);
+          // Keep the DebugExpression operand alive as well.
+          id = inst->GetSingleWordInOperand(kDebugValueExpression);
+          auto expression = get_def_use_mgr()->GetDef(id);
+          AddToWorklist(expression);
+          context()->get_def_use_mgr()->UpdateDefUse(expression);
+          return;
+        }
+      }
+    });
+  }
+}
+
 bool AggressiveDCEPass::KillDeadInstructions(
     const Function* func, std::list<BasicBlock*>& structured_order) {
   bool modified = false;
@@ -916,8 +963,17 @@ bool AggressiveDCEPass::ProcessGlobalValues() {
     }
     // Save debug build identifier even if no other instructions refer to it.
     if (dbg.GetShader100DebugOpcode() ==
-        NonSemanticShaderDebugInfo100DebugBuildIdentifier)
+        NonSemanticShaderDebugInfo100DebugBuildIdentifier) {
+      // The debug build identifier refers to other instructions that
+      // can potentially be removed, they also need to be kept alive.
+      dbg.ForEachInId([this](const uint32_t* id) {
+        Instruction* ref_inst = get_def_use_mgr()->GetDef(*id);
+        if (ref_inst) {
+          live_insts_.Set(ref_inst->unique_id());
+        }
+      });
       continue;
+    }
     to_kill_.push_back(&dbg);
     modified = true;
   }

+ 6 - 0
3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.h

@@ -150,6 +150,12 @@ class AggressiveDCEPass : public MemPass {
   // will be empty at the end.
   void ProcessWorkList(Function* func);
 
+  // Process each DebugDeclare and DebugValue in |func| that has not been
+  // marked as live in the work list. DebugDeclare's are marked live now, and
+  // DebugValue Value operands are set to OpUndef.  The work list will be empty
+  // at the end.
+  void ProcessDebugInformation(std::list<BasicBlock*>& structured_order);
+
   // Kills any instructions in |func| that have not been marked as live.
   bool KillDeadInstructions(const Function* func,
                             std::list<BasicBlock*>& structured_order);

+ 516 - 0
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.cpp

@@ -0,0 +1,516 @@
+// Copyright (c) 2025 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/canonicalize_ids_pass.h"
+
+#include <algorithm>
+#include <limits>
+
+namespace spvtools {
+namespace opt {
+
+// Runs the full canonicalization pipeline: the old->new id map is built in
+// stages (types/constants, OpName targets, function-local ids, then every
+// remaining id) and finally applied to the whole module in one rewrite.
+Pass::Status CanonicalizeIdsPass::Process() {
+  // Initialize the new ID map.
+  new_id_.resize(GetBound(), unused_);
+
+  // Scan the IDs and set to unmapped.
+  ScanIds();
+
+  // Create new IDs for types and consts.
+  CanonicalizeTypeAndConst();
+
+  // Create new IDs for names.
+  CanonicalizeNames();
+
+  // Create new IDs for functions.
+  CanonicalizeFunctions();
+
+  // Create new IDs for everything else.
+  CanonicalizeRemainders();
+
+  // Apply the new IDs to the module.
+  auto const modified = ApplyMap();
+
+  // Update bound in the header.
+  if (modified) {
+    UpdateBound();
+  }
+
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+// Single pass over every instruction to record which result ids exist and to
+// bucket them for the later canonicalization stages.  Every discovered id is
+// initialized to unmapped_ in new_id_.  The trailing |true| presumably makes
+// ForEachInst also visit debug-line instructions — confirm against its
+// signature.
+void CanonicalizeIdsPass::ScanIds() {
+  get_module()->ForEachInst(
+      [this](Instruction* inst) {
+        // Look for types and constants.
+        if (spvOpcodeGeneratesType(inst->opcode()) ||
+            spvOpcodeIsConstant(inst->opcode())) {
+          type_and_const_ids_.push_back(inst->result_id());
+          SetNewId(inst->result_id(), unmapped_);
+        }
+        // Look for names.
+        else if (inst->opcode() == spv::Op::OpName) {
+          // store name string in map so that we can compute the hash later
+          // NOTE(review): duplicate OpName strings overwrite earlier targets
+          // here — confirm that is acceptable for canonicalization quality.
+          auto const name = inst->GetOperand(1).AsString();
+          auto const target = inst->GetSingleWordInOperand(0);
+          name_ids_[name] = target;
+          SetNewId(target, unmapped_);
+        }
+        // Look for function IDs.
+        else if (inst->opcode() == spv::Op::OpFunction) {
+          auto const res_id = inst->result_id();
+          function_ids_.push_back(res_id);
+          SetNewId(res_id, unmapped_);
+        }
+        // Look for remaining result IDs.
+        else if (inst->HasResultId()) {
+          auto const res_id = inst->result_id();
+          SetNewId(res_id, unmapped_);
+        }
+      },
+      true);
+}
+
+// Assigns hash-derived ids to every type and constant recorded by ScanIds().
+// Ids whose hash comes back as unmapped_ are left for CanonicalizeRemainders;
+// collisions with already-claimed ids are resolved inside SetNewId().
+void CanonicalizeIdsPass::CanonicalizeTypeAndConst() {
+  // Remap type IDs.
+  static constexpr std::uint32_t soft_type_id_limit = 3011;  // small prime.
+  static constexpr std::uint32_t first_mapped_id = 8;  // offset into ID space
+  for (auto const id : type_and_const_ids_) {
+    if (!IsOldIdUnmapped(id)) {
+      continue;
+    }
+
+    // Compute the hash value.
+    auto const hash_value = HashTypeAndConst(id);
+    if (hash_value != unmapped_) {
+      SetNewId(id, hash_value % soft_type_id_limit + first_mapped_id);
+    }
+  }
+}
+
+// Hash types to canonical values.  This can return ID collisions (it's a bit
+// inevitable): it's up to the caller to handle that gracefully.
+// The hash recurses through referenced ids (component/column/element/member
+// types and constants' result types), so structurally identical types and
+// constants in different modules hash to the same value.  Opcodes in the
+// "don't map" groups below return unmapped_ and keep their original ids
+// until CanonicalizeRemainders runs.  The magic constants mirror glslang's
+// remapper for backward compatibility (see file header comment).
+spv::Id CanonicalizeIdsPass::HashTypeAndConst(spv::Id const id) const {
+  spv::Id value = 0;
+
+  auto const inst = get_def_use_mgr()->GetDef(id);
+  auto const op_code = inst->opcode();
+  switch (op_code) {
+    case spv::Op::OpTypeVoid:
+      value = 0;
+      break;
+    case spv::Op::OpTypeBool:
+      value = 1;
+      break;
+    case spv::Op::OpTypeInt: {
+      auto const signedness = inst->GetSingleWordOperand(2);
+      value = 3 + signedness;
+      break;
+    }
+    case spv::Op::OpTypeFloat:
+      value = 5;
+      break;
+    case spv::Op::OpTypeVector: {
+      auto const component_type = inst->GetSingleWordOperand(1);
+      auto const component_count = inst->GetSingleWordOperand(2);
+      value = 6 + HashTypeAndConst(component_type) * (component_count - 1);
+      break;
+    }
+    case spv::Op::OpTypeMatrix: {
+      auto const column_type = inst->GetSingleWordOperand(1);
+      auto const column_count = inst->GetSingleWordOperand(2);
+      value = 30 + HashTypeAndConst(column_type) * (column_count - 1);
+      break;
+    }
+    case spv::Op::OpTypeImage: {
+      // TODO: Why isn't the format used to compute the hash value?
+      auto const sampled_type = inst->GetSingleWordOperand(1);
+      auto const dim = inst->GetSingleWordOperand(2);
+      auto const depth = inst->GetSingleWordOperand(3);
+      auto const arrayed = inst->GetSingleWordOperand(4);
+      auto const ms = inst->GetSingleWordOperand(5);
+      auto const sampled = inst->GetSingleWordOperand(6);
+      value = 120 + HashTypeAndConst(sampled_type) + dim + depth * 8 * 16 +
+              arrayed * 4 * 16 + ms * 2 * 16 + sampled * 1 * 16;
+      break;
+    }
+    case spv::Op::OpTypeSampler:
+      value = 500;
+      break;
+    case spv::Op::OpTypeSampledImage:
+      value = 502;
+      break;
+    case spv::Op::OpTypeArray: {
+      auto const element_type = inst->GetSingleWordOperand(1);
+      auto const length = inst->GetSingleWordOperand(2);
+      value = 501 + HashTypeAndConst(element_type) * length;
+      break;
+    }
+    case spv::Op::OpTypeRuntimeArray: {
+      auto const element_type = inst->GetSingleWordOperand(1);
+      value = 5000 + HashTypeAndConst(element_type);
+      break;
+    }
+    case spv::Op::OpTypeStruct:
+      // Fold in every member type, weighted by member position.
+      value = 10000;
+      for (uint32_t w = 1; w < inst->NumOperandWords(); ++w) {
+        value += (w + 1) * HashTypeAndConst(inst->GetSingleWordOperand(w));
+      }
+      break;
+    case spv::Op::OpTypeOpaque: {
+      // TODO: Name is a literal that may have more than one word.
+      auto const name = inst->GetSingleWordOperand(1);
+      value = 6000 + name;
+      break;
+    }
+    case spv::Op::OpTypePointer: {
+      auto const type = inst->GetSingleWordOperand(2);
+      value = 100000 + HashTypeAndConst(type);
+      break;
+    }
+    case spv::Op::OpTypeFunction:
+      // Fold in return type and every parameter type, weighted by position.
+      value = 200000;
+      for (uint32_t w = 1; w < inst->NumOperandWords(); ++w) {
+        value += (w + 1) * HashTypeAndConst(inst->GetSingleWordOperand(w));
+      }
+      break;
+    case spv::Op::OpTypeEvent:
+      value = 300000;
+      break;
+    case spv::Op::OpTypeDeviceEvent:
+      value = 300001;
+      break;
+    case spv::Op::OpTypeReserveId:
+      value = 300002;
+      break;
+    case spv::Op::OpTypeQueue:
+      value = 300003;
+      break;
+    case spv::Op::OpTypePipe:
+      value = 300004;
+      break;
+    case spv::Op::OpTypePipeStorage:
+      value = 300005;
+      break;
+    case spv::Op::OpTypeNamedBarrier:
+      value = 300006;
+      break;
+    case spv::Op::OpConstantTrue:
+      value = 300007;
+      break;
+    case spv::Op::OpConstantFalse:
+      value = 300008;
+      break;
+    case spv::Op::OpTypeRayQueryKHR:
+      value = 300009;
+      break;
+    case spv::Op::OpTypeAccelerationStructureKHR:
+      value = 300010;
+      break;
+    // Don't map the following types.
+    // TODO: These types were not remapped in the glslang version of the
+    // remapper. Support should be added as necessary.
+    case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
+    case spv::Op::OpTypeCooperativeVectorNV:
+    case spv::Op::OpTypeHitObjectNV:
+    case spv::Op::OpTypeUntypedPointerKHR:
+    case spv::Op::OpTypeNodePayloadArrayAMDX:
+    case spv::Op::OpTypeTensorLayoutNV:
+    case spv::Op::OpTypeTensorViewNV:
+    case spv::Op::OpTypeTensorARM:
+    case spv::Op::OpTypeTaskSequenceINTEL:
+      value = unmapped_;
+      break;
+    case spv::Op::OpConstant: {
+      // Fold in the result type and every word of the literal value.
+      auto const result_type = inst->GetSingleWordOperand(0);
+      value = 400011 + HashTypeAndConst(result_type);
+      auto const literal = inst->GetOperand(2);
+      for (uint32_t w = 0; w < literal.words.size(); ++w) {
+        value += (w + 3) * literal.words[w];
+      }
+      break;
+    }
+    case spv::Op::OpConstantComposite: {
+      auto const result_type = inst->GetSingleWordOperand(0);
+      value = 300011 + HashTypeAndConst(result_type);
+      for (uint32_t w = 2; w < inst->NumOperandWords(); ++w) {
+        value += (w + 1) * HashTypeAndConst(inst->GetSingleWordOperand(w));
+      }
+      break;
+    }
+    case spv::Op::OpConstantNull: {
+      auto const result_type = inst->GetSingleWordOperand(0);
+      value = 500009 + HashTypeAndConst(result_type);
+      break;
+    }
+    case spv::Op::OpConstantSampler: {
+      auto const result_type = inst->GetSingleWordOperand(0);
+      value = 600011 + HashTypeAndConst(result_type);
+      for (uint32_t w = 2; w < inst->NumOperandWords(); ++w) {
+        value += (w + 1) * inst->GetSingleWordOperand(w);
+      }
+      break;
+    }
+    // Don't map the following constants.
+    // TODO: These constants were not remapped in the glslang version of the
+    // remapper. Support should be added as necessary.
+    case spv::Op::OpConstantCompositeReplicateEXT:
+    case spv::Op::OpConstantFunctionPointerINTEL:
+    case spv::Op::OpConstantStringAMDX:
+    case spv::Op::OpSpecConstantTrue:
+    case spv::Op::OpSpecConstantFalse:
+    case spv::Op::OpSpecConstant:
+    case spv::Op::OpSpecConstantComposite:
+    case spv::Op::OpSpecConstantCompositeReplicateEXT:
+    case spv::Op::OpSpecConstantOp:
+    case spv::Op::OpSpecConstantStringAMDX:
+      value = unmapped_;
+      break;
+    // TODO: Add additional types/constants as needed. See
+    // spvOpcodeGeneratesType and spvOpcodeIsConstant.
+    default:
+      context()->consumer()(SPV_MSG_WARNING, "", {0, 0, 0},
+                            "unhandled opcode will not be canonicalized");
+      break;
+  }
+
+  return value;
+}
+
+// Assigns hash-derived ids to the targets of OpName instructions, hashing the
+// name string itself so identically-named objects in different modules land
+// on the same canonical id.
+void CanonicalizeIdsPass::CanonicalizeNames() {
+  static constexpr std::uint32_t soft_type_id_limit = 3011;  // Small prime.
+  static constexpr std::uint32_t first_mapped_id =
+      3019;  // Offset into ID space.
+
+  for (auto const& [name, target] : name_ids_) {
+    if (!IsOldIdUnmapped(target)) {
+      continue;
+    }
+
+    // Multiplicative string hash; constants preserved from glslang's remapper.
+    spv::Id hash_value = 1911;
+    for (const char c : name) {
+      hash_value = hash_value * 1009 + c;
+    }
+
+    // NOTE(review): this re-check is redundant — the loop already continued
+    // when |target| was mapped, and nothing in between remaps it.
+    if (IsOldIdUnmapped(target)) {
+      SetNewId(target, hash_value % soft_type_id_limit + first_mapped_id);
+    }
+  }
+}
+
+// Assigns hash-derived ids to result ids inside each function body.  The hash
+// mixes the owning function's id with the opcode hashes of a small window of
+// surrounding instructions, so similar code sequences in different modules
+// map to similar canonical ids.
+void CanonicalizeIdsPass::CanonicalizeFunctions() {
+  // NOTE(review): comment said "Small prime" but 19071 = 9 * 2119; value is
+  // kept as-is from glslang for backward compatibility.
+  static constexpr std::uint32_t soft_type_id_limit = 19071;
+  static constexpr std::uint32_t first_mapped_id =
+      6203;  // Offset into ID space.
+  // Window size for context-sensitive canonicalization values
+  // Empirical best size from a single data set.  TODO: Would be a good tunable.
+  // We essentially perform a little convolution around each instruction,
+  // to capture the flavor of nearby code, to hopefully match to similar
+  // code in other modules.
+  static const int32_t window_size = 2;
+
+  for (auto const func_id : function_ids_) {
+    // Store the instructions and opcode hash values in vectors so that the
+    // window of instructions can be easily accessed and avoid having to
+    // recompute the hash value repeatedly in overlapping windows.
+    std::vector<Instruction*> insts;
+    std::vector<uint32_t> opcode_hashvals;
+    auto const func = context()->GetFunction(func_id);
+    func->WhileEachInst([&](Instruction* inst) {
+      insts.emplace_back(inst);
+      opcode_hashvals.emplace_back(HashOpCode(inst));
+      return true;
+    });
+
+    // For every instruction in the function, compute the hash value using the
+    // instruction and a small window of surrounding instructions.
+    assert(insts.size() < (size_t)std::numeric_limits<int32_t>::max());
+    for (int32_t i = 0; i < (int32_t)insts.size(); ++i) {
+      auto const inst = insts[i];
+      if (!inst->HasResultId()) {
+        continue;
+      }
+
+      auto const old_id = inst->result_id();
+      if (!IsOldIdUnmapped(old_id)) {
+        continue;
+      }
+
+      int32_t const lower_bound = std::max(0, i - window_size);
+      int32_t const upper_bound =
+          std::min((int32_t)insts.size() - 1, i + window_size);
+      spv::Id hash_value = func_id * 17;  // Small prime.
+      // Include the hash value of the preceding instructions in the hash but
+      // don't include instructions before the OpFunction.
+      for (int32_t j = i - 1; j >= lower_bound; --j) {
+        auto const local_inst = insts[j];
+        if (local_inst->opcode() == spv::Op::OpFunction) {
+          break;
+        }
+
+        hash_value = hash_value * 30103 +
+                     opcode_hashvals[j];  // 30103 is a semi-arbitrary prime.
+      }
+
+      // Include the hash value of the subsequent instructions in the hash but
+      // don't include instructions past OpFunctionEnd.
+      for (int32_t j = i; j <= upper_bound; ++j) {
+        auto const local_inst = insts[j];
+        if (local_inst->opcode() == spv::Op::OpFunctionEnd) {
+          break;
+        }
+
+        hash_value = hash_value * 30103 +
+                     opcode_hashvals[j];  // 30103 is a semi-arbitrary prime.
+      }
+
+      SetNewId(old_id, hash_value % soft_type_id_limit + first_mapped_id);
+    }
+  }
+}
+
+// Hashes an instruction's opcode; OpExtInst additionally folds in the
+// extended-instruction number so different ext-inst operations hash
+// differently.
+spv::Id CanonicalizeIdsPass::HashOpCode(Instruction const* const inst) const {
+  auto const op_code = inst->opcode();
+  std::uint32_t offset = 0;
+  if (op_code == spv::Op::OpExtInst) {
+    // offset is literal instruction
+    offset = inst->GetSingleWordOperand(3);
+  }
+
+  return (std::uint32_t)op_code * 19 + offset;  // 19 is a small prime.
+}
+
+// Assign remaining IDs sequentially from remaining holes in the new ID space.
+void CanonicalizeIdsPass::CanonicalizeRemainders() {
+  spv::Id next_id = 1;
+  for (uint32_t old_id = 0; old_id < new_id_.size(); ++old_id) {
+    if (IsOldIdUnmapped(old_id)) {
+      // SetNewId returns the id actually claimed.  NOTE(review): next_id is
+      // not incremented afterwards, so every later claim walks past the
+      // just-taken id via ClaimNewId's collision search — appears correct,
+      // but confirm the extra per-id set search is intended.
+      next_id = SetNewId(old_id, next_id);
+    }
+  }
+}
+
+// Rewrites every id-typed operand in the module according to new_id_
+// (operands still marked unused_ are skipped).  Keeps the instruction's
+// cached result-id/type-id fields in sync with the rewritten operand words.
+// Returns true iff any operand actually changed.
+bool CanonicalizeIdsPass::ApplyMap() {
+  bool modified = false;
+  context()->module()->ForEachInst(
+      [this, &modified](Instruction* inst) {
+        for (auto operand = inst->begin(); operand != inst->end(); ++operand) {
+          const auto type = operand->type;
+          if (spvIsIdType(type)) {
+            uint32_t& id = operand->words[0];
+            uint32_t const new_id = GetNewId(id);
+            if (new_id == unused_) {
+              continue;
+            }
+
+            assert(new_id != unmapped_ && "new_id should not be unmapped_");
+
+            if (id != new_id) {
+              modified = true;
+              // Write through the reference to patch the operand word, then
+              // update the instruction's cached fields where applicable.
+              id = new_id;
+              if (type == SPV_OPERAND_TYPE_RESULT_ID) {
+                inst->SetResultId(new_id);
+              } else if (type == SPV_OPERAND_TYPE_TYPE_ID) {
+                inst->SetResultType(new_id);
+              }
+            }
+          }
+        }
+      },
+      true);
+
+  return modified;
+}
+
+// Returns the module's current id bound (all valid ids are below this value).
+spv::Id CanonicalizeIdsPass::GetBound() const {
+  return context()->module()->id_bound();
+}
+
+// Recomputes and stores the id bound after remapping, since canonical ids
+// can exceed the original bound.
+void CanonicalizeIdsPass::UpdateBound() {
+  context()->module()->SetIdBound(context()->module()->ComputeIdBound());
+
+  // NOTE(review): presumably invalidates state cached against the old ids —
+  // confirm ResetFeatureManager is the intended invalidation here.
+  context()->ResetFeatureManager();
+}
+
+// Set a new ID. If the new ID is already claimed, the next consecutive ID
+// will be claimed, mapped, and returned to the caller.
+// Passing the unmapped_ or unused_ sentinel as |new_id| just records that
+// state for |old_id| without claiming anything.
+spv::Id CanonicalizeIdsPass::SetNewId(spv::Id const old_id, spv::Id new_id) {
+  assert(old_id < GetBound() && "don't remap an ID that is out of bounds");
+
+  // Grow the map on demand; slots default to unused_.
+  if (old_id >= new_id_.size()) {
+    new_id_.resize(old_id + 1, unused_);
+  }
+
+  if (new_id != unmapped_ && new_id != unused_) {
+    assert(!IsOldIdUnused(old_id) && "don't remap unused IDs");
+    assert(IsOldIdUnmapped(old_id) && "don't remap already mapped IDs");
+
+    // Resolve collisions with ids that earlier stages already claimed.
+    new_id = ClaimNewId(new_id);
+  }
+
+  new_id_[old_id] = new_id;
+
+  return new_id;
+}
+
+// Helper function for SetNewID. Claim a new ID. If the new ID is already
+// claimed, the next consecutive ID will be claimed and returned to the caller.
+spv::Id CanonicalizeIdsPass::ClaimNewId(spv::Id new_id) {
+  // Return the ID if it's not taken.
+  auto iter = claimed_new_ids_.find(new_id);
+  if (iter != claimed_new_ids_.end()) {
+    // Otherwise, search for the next unused ID using our current iterator.
+    // Technically, it's a linear search across the set starting at the
+    // iterator, but it's not as bad as it would appear in practice assuming the
+    // hash values are well distributed.
+    // NOTE(review): the comparator takes int while spv::Id is unsigned; ids
+    // above INT_MAX would narrow — confirm the id bound keeps this safe.
+    iter = std::adjacent_find(iter, claimed_new_ids_.end(), [](int a, int b) {
+      return a + 1 != b;  // Stop at the first non-consecutive pair.
+    });
+    if (iter != claimed_new_ids_.end()) {
+      new_id =
+          *iter + 1;  // We need the next ID after where the search stopped.
+    } else {
+      new_id = *(--iter) + 1;  // We reached the end so we use the next ID.
+    }
+  }
+
+  assert(!IsNewIdClaimed(new_id) &&
+         "don't remap to an ID that is already claimed");
+  // Insert with the hint iterator we already have.
+  iter = claimed_new_ids_.insert(iter, new_id);
+  assert(*iter == new_id);
+
+  return new_id;
+}
+
+// Formats an id for diagnostics, spelling out the two sentinel values.
+std::string CanonicalizeIdsPass::IdAsString(spv::Id const id) const {
+  if (id == unused_) {
+    return "unused";
+  } else if (id == unmapped_) {
+    return "unmapped";
+  } else {
+    return std::to_string(id);
+  }
+}
+
+// Debug aid: dumps the entire old->new id map through the message consumer
+// at SPV_MSG_INFO level.
+void CanonicalizeIdsPass::PrintNewIds() const {
+  for (spv::Id id = 0; id < new_id_.size(); ++id) {
+    auto const message =
+        "new id[" + IdAsString(id) + "]: " + IdAsString(new_id_[id]);
+    context()->consumer()(SPV_MSG_INFO, "", {0, 0, 0}, message.c_str());
+  }
+}
+
+}  // namespace opt
+}  // namespace spvtools

+ 115 - 0
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.h

@@ -0,0 +1,115 @@
+// Copyright (c) 2025 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <map>
+#include <set>
+#include <vector>
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// The canonicalize IDs pass is an optimization to improve compression of SPIR-V
+// binary files via entropy reduction. It transforms SPIR-V to SPIR-V, remapping
+// IDs. The resulting modules have an increased ID range (IDs are not as tightly
+// packed around zero), but will compress better when multiple modules are
+// compressed together, since the compressor's dictionary can find better cross
+// module commonality. Remapping is accomplished via canonicalization. Thus,
+// modules can be compressed one at a time with no loss of quality relative to
+// operating on many modules at once.
+
+// This pass should be run after most optimization passes except for
+// --strip-debug because this pass will use OpName to canonicalize IDs. i.e. Run
+// --strip-debug after this pass.
+
+// This is a port of the remap utility in glslang. There is a great deal of
+// numbers that are present throughout this code. The general goal is to replace
+// the IDs with a hash value such that the distribution of IDs is deterministic
+// and minimizes collisions. The magic numbers in the glslang version were
+// chosen semi-arbitrarily and have been preserved in this port in order to
+// maintain backward compatibility.
+
+class CanonicalizeIdsPass : public Pass {
+ public:
+  CanonicalizeIdsPass() = default;
+  virtual ~CanonicalizeIdsPass() = default;
+
+  Pass::Status Process() override;
+
+  const char* name() const override { return "canonicalize-ids"; }
+
+ private:
+  // Special values for IDs.
+  static constexpr spv::Id unmapped_{spv::Id(-10000)};
+  static constexpr spv::Id unused_{spv::Id(-10001)};
+
+  // Scans the module for IDs and sets them to unmapped_.
+  void ScanIds();
+
+  // Functions to compute new IDs.
+  void CanonicalizeTypeAndConst();
+  spv::Id HashTypeAndConst(
+      spv::Id const id) const;  // Helper for CanonicalizeTypeAndConst.
+  void CanonicalizeNames();
+  void CanonicalizeFunctions();
+  spv::Id HashOpCode(Instruction const* const inst)
+      const;  // Helper for CanonicalizeFunctions.
+  void CanonicalizeRemainders();
+
+  // Applies the new IDs.
+  bool ApplyMap();
+
+  // Methods to manage the bound field in header.
+  spv::Id GetBound() const;  // All IDs must satisfy 0 < ID < bound.
+  void UpdateBound();
+
+  // Methods to map from old IDs to new IDs.
+  spv::Id GetNewId(spv::Id const old_id) const { return new_id_[old_id]; }
+  spv::Id SetNewId(spv::Id const old_id, spv::Id new_id);
+
+  // Methods to manage claimed IDs.
+  spv::Id ClaimNewId(spv::Id new_id);
+  bool IsNewIdClaimed(spv::Id const new_id) const {
+    return claimed_new_ids_.find(new_id) != claimed_new_ids_.end();
+  }
+
+  // Queries for old IDs.
+  bool IsOldIdUnmapped(spv::Id const old_id) const {
+    return GetNewId(old_id) == unmapped_;
+  }
+  bool IsOldIdUnused(spv::Id const old_id) const {
+    return GetNewId(old_id) == unused_;
+  }
+
+  // Container to map old IDs to new IDs. e.g. new_id_[old_id] = new_id
+  std::vector<spv::Id> new_id_;
+
+  // IDs from the new ID space that have been claimed (faster than searching
+  // through new_id_).
+  std::set<spv::Id> claimed_new_ids_;
+
+  // Helper functions for printing IDs (useful for debugging).
+  std::string IdAsString(spv::Id const id) const;
+  void PrintNewIds() const;
+
+  // Containers to track IDs we want to canonicalize.
+  std::vector<spv::Id> type_and_const_ids_;
+  std::map<std::string, spv::Id> name_ids_;
+  std::vector<spv::Id> function_ids_;
+};
+
+}  // namespace opt
+}  // namespace spvtools

+ 7 - 0
3rdparty/spirv-tools/source/opt/ccp_pass.cpp

@@ -360,6 +360,13 @@ void CCPPass::Initialize() {
     }
   }
 
+  // Mark the extended instruction imports as `kVarying`. We know they
+  // will not be constants, and will be used by `OpExtInst` instructions.
+  // This allows those instructions to be fully processed.
+  for (const auto& inst : get_module()->ext_inst_imports()) {
+    values_[inst.result_id()] = kVaryingSSAId;
+  }
+
   original_id_bound_ = context()->module()->IdBound();
 }
 

+ 44 - 11
3rdparty/spirv-tools/source/opt/const_folding_rules.cpp

@@ -1395,9 +1395,12 @@ ConstantFoldingRule FoldFMix() {
     if (base_type->AsFloat()->width() == 32) {
       one = const_mgr->GetConstant(base_type,
                                    utils::FloatProxy<float>(1.0f).GetWords());
-    } else {
+    } else if (base_type->AsFloat()->width() == 64) {
       one = const_mgr->GetConstant(base_type,
                                    utils::FloatProxy<double>(1.0).GetWords());
+    } else {
+      // We won't support folding half types.
+      return nullptr;
     }
 
     if (is_vector) {
@@ -1433,14 +1436,29 @@ const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                   const analysis::Constant* b,
                                   analysis::ConstantManager*) {
   if (const analysis::Integer* int_type = result_type->AsInteger()) {
-    if (int_type->width() == 32) {
+    if (int_type->width() <= 32) {
+      assert(
+          (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
+          "Must be an integer or null constant.");
+      assert(
+          (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
+          "Must be an integer or null constant.");
+
       if (int_type->IsSigned()) {
-        int32_t va = a->GetS32();
-        int32_t vb = b->GetS32();
+        int32_t va = (a->AsIntConstant() != nullptr)
+                         ? a->AsIntConstant()->GetS32BitValue()
+                         : 0;
+        int32_t vb = (b->AsIntConstant() != nullptr)
+                         ? b->AsIntConstant()->GetS32BitValue()
+                         : 0;
         return (va < vb ? a : b);
       } else {
-        uint32_t va = a->GetU32();
-        uint32_t vb = b->GetU32();
+        uint32_t va = (a->AsIntConstant() != nullptr)
+                          ? a->AsIntConstant()->GetU32BitValue()
+                          : 0;
+        uint32_t vb = (b->AsIntConstant() != nullptr)
+                          ? b->AsIntConstant()->GetU32BitValue()
+                          : 0;
         return (va < vb ? a : b);
       }
     } else if (int_type->width() == 64) {
@@ -1473,14 +1491,29 @@ const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                   const analysis::Constant* b,
                                   analysis::ConstantManager*) {
   if (const analysis::Integer* int_type = result_type->AsInteger()) {
-    if (int_type->width() == 32) {
+    if (int_type->width() <= 32) {
+      assert(
+          (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
+          "Must be an integer or null constant.");
+      assert(
+          (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
+          "Must be an integer or null constant.");
+
       if (int_type->IsSigned()) {
-        int32_t va = a->GetS32();
-        int32_t vb = b->GetS32();
+        int32_t va = (a->AsIntConstant() != nullptr)
+                         ? a->AsIntConstant()->GetS32BitValue()
+                         : 0;
+        int32_t vb = (b->AsIntConstant() != nullptr)
+                         ? b->AsIntConstant()->GetS32BitValue()
+                         : 0;
         return (va > vb ? a : b);
       } else {
-        uint32_t va = a->GetU32();
-        uint32_t vb = b->GetU32();
+        uint32_t va = (a->AsIntConstant() != nullptr)
+                          ? a->AsIntConstant()->GetU32BitValue()
+                          : 0;
+        uint32_t vb = (b->AsIntConstant() != nullptr)
+                          ? b->AsIntConstant()->GetU32BitValue()
+                          : 0;
         return (va > vb ? a : b);
       }
     } else if (int_type->width() == 64) {

+ 1 - 0
3rdparty/spirv-tools/source/opt/constants.cpp

@@ -315,6 +315,7 @@ const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
     case spv::Op::OpConstant:
     case spv::Op::OpConstantComposite:
     case spv::Op::OpSpecConstantComposite:
+    case spv::Op::OpSpecConstantCompositeReplicateEXT:
       break;
     default:
       return nullptr;

+ 8 - 7
3rdparty/spirv-tools/source/opt/debug_info_manager.cpp

@@ -558,11 +558,11 @@ bool DebugInfoManager::IsDeclareVisibleToInstr(Instruction* dbg_declare,
   return false;
 }
 
-bool DebugInfoManager::AddDebugValueForVariable(Instruction* scope_and_line,
+bool DebugInfoManager::AddDebugValueForVariable(Instruction* line,
                                                 uint32_t variable_id,
                                                 uint32_t value_id,
                                                 Instruction* insert_pos) {
-  assert(scope_and_line != nullptr);
+  assert(line != nullptr);
 
   auto dbg_decl_itr = var_id_to_dbg_decl_.find(variable_id);
   if (dbg_decl_itr == var_id_to_dbg_decl_.end()) return false;
@@ -577,14 +577,15 @@ bool DebugInfoManager::AddDebugValueForVariable(Instruction* scope_and_line,
       insert_before = insert_before->NextNode();
     }
     modified |= AddDebugValueForDecl(dbg_decl_or_val, value_id, insert_before,
-                                     scope_and_line) != nullptr;
+                                     line) != nullptr;
   }
   return modified;
 }
 
-Instruction* DebugInfoManager::AddDebugValueForDecl(
-    Instruction* dbg_decl, uint32_t value_id, Instruction* insert_before,
-    Instruction* scope_and_line) {
+Instruction* DebugInfoManager::AddDebugValueForDecl(Instruction* dbg_decl,
+                                                    uint32_t value_id,
+                                                    Instruction* insert_before,
+                                                    Instruction* line) {
   if (dbg_decl == nullptr || !IsDebugDeclare(dbg_decl)) return nullptr;
 
   std::unique_ptr<Instruction> dbg_val(dbg_decl->Clone(context()));
@@ -593,7 +594,7 @@ Instruction* DebugInfoManager::AddDebugValueForDecl(
   dbg_val->SetOperand(kDebugDeclareOperandVariableIndex, {value_id});
   dbg_val->SetOperand(kDebugValueOperandExpressionIndex,
                       {GetEmptyDebugExpression()->result_id()});
-  dbg_val->UpdateDebugInfoFrom(scope_and_line);
+  dbg_val->UpdateDebugInfoFrom(dbg_decl, line);
 
   auto* added_dbg_val = insert_before->InsertBefore(std::move(dbg_val));
   AnalyzeDebugInst(added_dbg_val);

+ 9 - 10
3rdparty/spirv-tools/source/opt/debug_info_manager.h

@@ -143,22 +143,21 @@ class DebugInfoManager {
   bool KillDebugDeclares(uint32_t variable_id);
 
   // Generates a DebugValue instruction with value |value_id| for every local
-  // variable that is in the scope of |scope_and_line| and whose memory is
-  // |variable_id| and inserts it after the instruction |insert_pos|.
+  // variable that is in the scope of |line| and whose memory is |variable_id|
+  // and inserts it after the instruction |insert_pos|.
   // Returns whether a DebugValue is added or not.
-  bool AddDebugValueForVariable(Instruction* scope_and_line,
-                                uint32_t variable_id, uint32_t value_id,
-                                Instruction* insert_pos);
+  bool AddDebugValueForVariable(Instruction* line, uint32_t variable_id,
+                                uint32_t value_id, Instruction* insert_pos);
 
   // Creates a DebugValue for DebugDeclare |dbg_decl| and inserts it before
-  // |insert_before|. The new DebugValue has the same line and scope as
-  // |scope_and_line|, or no scope and line information if |scope_and_line|
-  // is nullptr. The new DebugValue has the same operands as DebugDeclare
-  // but it uses |value_id| for the value. Returns the created DebugValue,
+  // |insert_before|. The new DebugValue has the same line as |line| and the
+  // same scope as |dbg_decl|. The new DebugValue has the same operands as
+  // DebugDeclare but it uses |value_id| for the value. Returns the created
+  // DebugValue,
   // or nullptr if fails to create one.
   Instruction* AddDebugValueForDecl(Instruction* dbg_decl, uint32_t value_id,
                                     Instruction* insert_before,
-                                    Instruction* scope_and_line);
+                                    Instruction* line);
 
   // Erases |instr| from data structures of this class.
   void ClearDebugInfo(Instruction* instr);

+ 21 - 6
3rdparty/spirv-tools/source/opt/desc_sroa.cpp

@@ -58,7 +58,7 @@ bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
   std::vector<Instruction*> access_chain_work_list;
   std::vector<Instruction*> load_work_list;
   std::vector<Instruction*> entry_point_work_list;
-  bool failed = !get_def_use_mgr()->WhileEachUser(
+  bool ok = get_def_use_mgr()->WhileEachUser(
       var->result_id(), [this, &access_chain_work_list, &load_work_list,
                          &entry_point_work_list](Instruction* use) {
         if (use->opcode() == spv::Op::OpName) {
@@ -88,7 +88,7 @@ bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
         return true;
       });
 
-  if (failed) {
+  if (!ok) {
     return false;
   }
 
@@ -128,6 +128,9 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
 
   uint32_t idx = const_index->GetU32();
   uint32_t replacement_var = GetReplacementVariable(var, idx);
+  if (replacement_var == 0) {
+    return false;
+  }
 
   if (use->NumInOperands() == 2) {
     // We are not indexing into the replacement variable.  We can replaces the
@@ -186,8 +189,11 @@ bool DescriptorScalarReplacement::ReplaceEntryPoint(Instruction* var,
   uint32_t num_replacement_vars =
       descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
   for (uint32_t i = 0; i < num_replacement_vars; i++) {
-    new_operands.push_back(
-        {SPV_OPERAND_TYPE_ID, {GetReplacementVariable(var, i)}});
+    uint32_t replacement_var_id = GetReplacementVariable(var, i);
+    if (replacement_var_id == 0) {
+      return false;
+    }
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var_id}});
   }
 
   use->ReplaceOperands(new_operands);
@@ -310,7 +316,10 @@ uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
       element_type_id, storage_class);
 
   // Create the variable.
-  uint32_t id = TakeNextId();
+  uint32_t id = context()->TakeNextId();
+  if (id == 0) {
+    return 0;
+  }
   std::unique_ptr<Instruction> variable(
       new Instruction(context(), spv::Op::OpVariable, ptr_element_type_id, id,
                       std::initializer_list<Operand>{
@@ -444,10 +453,16 @@ bool DescriptorScalarReplacement::ReplaceCompositeExtract(
 
   uint32_t replacement_var =
       GetReplacementVariable(var, extract->GetSingleWordInOperand(1));
+  if (replacement_var == 0) {
+    return false;
+  }
 
   // The result type of the OpLoad is the same as the result type of the
   // OpCompositeExtract.
-  uint32_t load_id = TakeNextId();
+  uint32_t load_id = context()->TakeNextId();
+  if (load_id == 0) {
+    return false;
+  }
   std::unique_ptr<Instruction> load(
       new Instruction(context(), spv::Op::OpLoad, extract->type_id(), load_id,
                       std::initializer_list<Operand>{

+ 9 - 3
3rdparty/spirv-tools/source/opt/feature_manager.cpp

@@ -34,10 +34,13 @@ void FeatureManager::AddExtensions(Module* module) {
 }
 
 void FeatureManager::AddExtension(Instruction* ext) {
-  assert(ext->opcode() == spv::Op::OpExtension &&
+  assert((ext->opcode() == spv::Op::OpExtension ||
+          ext->opcode() == spv::Op::OpConditionalExtensionINTEL) &&
          "Expecting an extension instruction.");
 
-  const std::string name = ext->GetInOperand(0u).AsString();
+  const uint32_t name_i =
+      ext->opcode() == spv::Op::OpConditionalExtensionINTEL ? 1u : 0u;
+  const std::string name = ext->GetInOperand(name_i).AsString();
   Extension extension;
   if (GetExtensionFromString(name.c_str(), &extension)) {
     extensions_.insert(extension);
@@ -72,7 +75,10 @@ void FeatureManager::RemoveCapability(spv::Capability cap) {
 
 void FeatureManager::AddCapabilities(Module* module) {
   for (Instruction& inst : module->capabilities()) {
-    AddCapability(static_cast<spv::Capability>(inst.GetSingleWordInOperand(0)));
+    const uint32_t i_cap =
+        inst.opcode() == spv::Op::OpConditionalCapabilityINTEL ? 1 : 0;
+    AddCapability(
+        static_cast<spv::Capability>(inst.GetSingleWordInOperand(i_cap)));
   }
 }
 

+ 34 - 25
3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp

@@ -1,4 +1,5 @@
 // Copyright (c) 2016 Google Inc.
+// Copyright (c) 2025 Arm Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -31,21 +32,20 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
   // instructions, records their values in two internal maps: id_to_const_val_
   // and const_val_to_id_ so that we can use them to infer the value of Spec
   // Constants later.
-  // For Spec Constants defined with OpSpecConstantComposite instructions, if
-  // all of their components are Normal Constants, they will be turned into
-  // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
-  // instructions, we check if they only depends on Normal Constants and fold
-  // them when possible. The two maps for Normal Constants: id_to_const_val_
-  // and const_val_to_id_ will be updated along the traversal so that the new
-  // Normal Constants generated from folding can be used to fold following Spec
-  // Constants.
-  // This algorithm depends on the SSA property of SPIR-V when
-  // defining constants. The dependent constants must be defined before the
-  // dependee constants. So a dependent Spec Constant must be defined and
-  // will be processed before its dependee Spec Constant. When we encounter
-  // the dependee Spec Constants, all its dependent constants must have been
-  // processed and all its dependent Spec Constants should have been folded if
-  // possible.
+  // For Spec Constants defined with OpSpecConstantComposite or
+  // OpSpecConstantCompositeReplicateEXT instructions, if all of their
+  // components are Normal Constants, they will be turned into Normal Constants
+  // too. For Spec Constants defined with OpSpecConstantOp instructions, we
+  // check if they only depend on Normal Constants and fold them when possible.
+  // The two maps for Normal Constants: id_to_const_val_ and const_val_to_id_
+  // will be updated along the traversal so that the new Normal Constants
+  // generated from folding can be used to fold following Spec Constants. This
+  // algorithm depends on the SSA property of SPIR-V when defining constants.
+  // The dependent constants must be defined before the dependee constants. So a
+  // dependent Spec Constant must be defined and will be processed before its
+  // dependee Spec Constant. When we encounter the dependee Spec Constants, all
+  // its dependent constants must have been processed and all its dependent Spec
+  // Constants should have been folded if possible.
   Module::inst_iterator next_inst = context()->types_values_begin();
   for (Module::inst_iterator inst_iter = next_inst;
        // Need to re-evaluate the end iterator since we may modify the list of
@@ -54,8 +54,9 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
     ++next_inst;
     Instruction* inst = &*inst_iter;
     // Collect constant values of normal constants and process the
-    // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
-    // The constant values will be stored in analysis::Constant instances.
+    // OpSpecConstantOp, OpSpecConstantComposite, and
+    // OpSpecConstantCompositeReplicateEXT instructions if possible. The
+    // constant values will be stored in analysis::Constant instances.
     // OpConstantSampler instruction is not collected here because it cannot be
     // used in OpSpecConstant{Composite|Op} instructions.
     // TODO(qining): If the constant or its type has decoration, we may need
@@ -70,21 +71,29 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
       case spv::Op::OpConstant:
       case spv::Op::OpConstantNull:
       case spv::Op::OpConstantComposite:
-      case spv::Op::OpSpecConstantComposite: {
+      case spv::Op::OpSpecConstantComposite:
+      case spv::Op::OpSpecConstantCompositeReplicateEXT: {
         // A Constant instance will be created if the given instruction is a
         // Normal Constant whose value(s) are fixed. Note that for a composite
-        // Spec Constant defined with OpSpecConstantComposite instruction, if
-        // all of its components are Normal Constants already, the Spec
-        // Constant will be turned in to a Normal Constant. In that case, a
-        // Constant instance should also be created successfully and recorded
-        // in the id_to_const_val_ and const_val_to_id_ mapps.
+        // Spec Constant defined with OpSpecConstantComposite or
+        // OpSpecConstantCompositeReplicateEXT instruction, if all of its
+        // components are Normal Constants already, the Spec Constant will be
+        // turned into a Normal Constant. In that case, a Constant instance
+        // should also be created successfully and recorded in the
+        // id_to_const_val_ and const_val_to_id_ maps.
         if (auto const_value = const_mgr->GetConstantFromInst(inst)) {
-          // Need to replace the OpSpecConstantComposite instruction with a
-          // corresponding OpConstantComposite instruction.
+          // Need to replace the OpSpecConstantComposite or
+          // OpSpecConstantCompositeReplicateEXT instruction with a
+          // corresponding OpConstantComposite or
+          // OpConstantCompositeReplicateEXT instruction.
           if (opcode == spv::Op::OpSpecConstantComposite) {
             inst->SetOpcode(spv::Op::OpConstantComposite);
             modified = true;
           }
+          if (opcode == spv::Op::OpSpecConstantCompositeReplicateEXT) {
+            inst->SetOpcode(spv::Op::OpConstantCompositeReplicateEXT);
+            modified = true;
+          }
           const_mgr->MapConstantToInst(const_value, inst);
         }
         break;

+ 9 - 0
3rdparty/spirv-tools/source/opt/folding_rules.cpp

@@ -1998,6 +1998,15 @@ FoldingRule FMixFeedingExtract() {
     bool use_x = false;
 
     assert(a_const->type()->AsFloat());
+
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) {
+      // We won't support folding half float values.
+      return false;
+    }
+
     double element_value = a_const->GetValueAsDouble();
     if (element_value == 0.0) {
       use_x = true;

+ 130 - 28
3rdparty/spirv-tools/source/opt/graphics_robust_access_pass.cpp

@@ -283,9 +283,14 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
   // use 0 for %min_value).
   auto clamp_index = [&inst, type_mgr, this, &replace_index](
                          uint32_t operand_index, Instruction* old_value,
-                         Instruction* min_value, Instruction* max_value) {
+                         Instruction* min_value,
+                         Instruction* max_value) -> spv_result_t {
     auto* clamp_inst =
         MakeSClampInst(*type_mgr, old_value, min_value, max_value, &inst);
+    if (clamp_inst == nullptr) {
+      Fail();
+      return SPV_ERROR_INTERNAL;
+    }
     return replace_index(operand_index, clamp_inst);
   };
 
@@ -304,7 +309,11 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
 
     if (count <= 1) {
       // Replace the index with 0.
-      return replace_index(operand_index, GetValueForType(0, index_type));
+      Instruction* new_value = GetValueForType(0, index_type);
+      if (new_value == nullptr) {
+        return Fail();
+      }
+      return replace_index(operand_index, new_value);
     }
 
     uint64_t maxval = count - 1;
@@ -318,8 +327,15 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
     // Determine the type for |maxval|.
     uint32_t next_id = context()->module()->IdBound();
     analysis::Integer signed_type_for_query(maxval_width, true);
-    auto* maxval_type =
-        type_mgr->GetRegisteredType(&signed_type_for_query)->AsInteger();
+    auto* maxval_type_registered =
+        type_mgr->GetRegisteredType(&signed_type_for_query);
+    if (maxval_type_registered == nullptr) {
+      return Fail();
+    }
+    auto* maxval_type = maxval_type_registered->AsInteger();
+    if (maxval_type == nullptr) {
+      return Fail();
+    }
     if (next_id != context()->module()->IdBound()) {
       module_status_.modified = true;
     }
@@ -352,15 +368,22 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
         value = int_index_constant->GetS64BitValue();
       }
       if (value < 0) {
-        return replace_index(operand_index, GetValueForType(0, index_type));
+        Instruction* new_value = GetValueForType(0, index_type);
+        if (new_value == nullptr) {
+          return Fail();
+        }
+        return replace_index(operand_index, new_value);
       } else if (uint64_t(value) <= maxval) {
         // Nothing to do.
         return SPV_SUCCESS;
       } else {
         // Replace with maxval.
         assert(count > 0);  // Already took care of this case above.
-        return replace_index(operand_index,
-                             GetValueForType(maxval, maxval_type));
+        Instruction* new_value = GetValueForType(maxval, maxval_type);
+        if (new_value == nullptr) {
+          return Fail();
+        }
+        return replace_index(operand_index, new_value);
       }
     } else {
       // Generate a clamp instruction.
@@ -389,6 +412,9 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
         }
         index_inst = WidenInteger(index_type->IsSigned(), maxval_width,
                                   index_inst, &inst);
+        if (index_inst == nullptr) {
+          return Fail();
+        }
       }
 
       // Finally, clamp the index.
@@ -438,28 +464,51 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
       if (index_type->width() < target_width) {
         // Access chain indices are treated as signed integers.
         index_inst = WidenInteger(true, target_width, index_inst, &inst);
+        if (index_inst == nullptr) {
+          return Fail();
+        }
       } else if (count_type->width() < target_width) {
         // Assume type sizes are treated as unsigned.
         count_inst = WidenInteger(false, target_width, count_inst, &inst);
+        if (count_inst == nullptr) {
+          return Fail();
+        }
       }
       // Compute count - 1.
       // It doesn't matter if 1 is signed or unsigned.
       auto* one = GetValueForType(1, wider_type);
-      auto* count_minus_1 = InsertInst(
-          &inst, spv::Op::OpISub, type_mgr->GetId(wider_type), TakeNextId(),
-          {{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
-           {SPV_OPERAND_TYPE_ID, {one->result_id()}}});
+      if (!one) {
+        return Fail();
+      }
+      auto* count_minus_1 =
+          InsertInst(&inst, spv::Op::OpISub, type_mgr->GetId(wider_type),
+                     context()->TakeNextId(),
+                     {{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
+                      {SPV_OPERAND_TYPE_ID, {one->result_id()}}});
+      if (count_minus_1 == nullptr) {
+        return Fail();
+      }
       auto* zero = GetValueForType(0, wider_type);
+      if (!zero) {
+        return Fail();
+      }
       // Make sure we clamp to an upper bound that is at most the signed max
       // for the target type.
       const uint64_t max_signed_value =
           ((uint64_t(1) << (target_width - 1)) - 1);
+      Instruction* max_signed_inst =
+          GetValueForType(max_signed_value, wider_type);
+      if (!max_signed_inst) {
+        return Fail();
+      }
       // Use unsigned-min to ensure that the result is always non-negative.
       // That ensures we satisfy the invariant for SClamp, where the "min"
       // argument we give it (zero), is no larger than the third argument.
       auto* upper_bound =
-          MakeUMinInst(*type_mgr, count_minus_1,
-                       GetValueForType(max_signed_value, wider_type), &inst);
+          MakeUMinInst(*type_mgr, count_minus_1, max_signed_inst, &inst);
+      if (upper_bound == nullptr) {
+        return Fail();
+      }
       // Now clamp the index to this upper bound.
       return clamp_index(operand_index, index_inst, zero, upper_bound);
     }
@@ -485,7 +534,7 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
       case spv::Op::OpTypeVector:  // Use component count
       {
         const uint32_t count = pointee_type->GetSingleWordOperand(2);
-        clamp_to_literal_count(idx, count);
+        if (clamp_to_literal_count(idx, count) != SPV_SUCCESS) return;
         pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
       } break;
 
@@ -493,7 +542,7 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
         // The array length can be a spec constant, so go through the general
         // case.
         Instruction* array_len = GetDef(pointee_type->GetSingleWordOperand(2));
-        clamp_to_count(idx, array_len);
+        if (clamp_to_count(idx, array_len) != SPV_SUCCESS) return;
         pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
       } break;
 
@@ -537,7 +586,7 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
         if (!array_len) {  // We've already signaled an error.
           return;
         }
-        clamp_to_count(idx, array_len);
+        if (clamp_to_count(idx, array_len) != SPV_SUCCESS) return;
         if (module_status_.failed) return;
         pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
       } break;
@@ -563,7 +612,10 @@ uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
     }
     if (module_status_.glsl_insts_id == 0) {
       // Make a new import instruction.
-      module_status_.glsl_insts_id = TakeNextId();
+      module_status_.glsl_insts_id = context()->TakeNextId();
+      if (module_status_.glsl_insts_id == 0) {
+        return 0;
+      }
       std::vector<uint32_t> words = spvtools::utils::MakeVector(glsl);
       auto import_inst = MakeUnique<Instruction>(
           context(), spv::Op::OpExtInstImport, 0, module_status_.glsl_insts_id,
@@ -602,7 +654,10 @@ opt::Instruction* opt::GraphicsRobustAccessPass::WidenInteger(
   auto* type_mgr = context()->get_type_mgr();
   auto* unsigned_type = type_mgr->GetRegisteredType(&unsigned_type_for_query);
   auto type_id = context()->get_type_mgr()->GetId(unsigned_type);
-  auto conversion_id = TakeNextId();
+  auto conversion_id = context()->TakeNextId();
+  if (conversion_id == 0) {
+    return nullptr;
+  }
   auto* conversion = InsertInst(
       before_inst, (sign_extend ? spv::Op::OpSConvert : spv::Op::OpUConvert),
       type_id, conversion_id, {{SPV_OPERAND_TYPE_ID, {value->result_id()}}});
@@ -616,7 +671,13 @@ Instruction* GraphicsRobustAccessPass::MakeUMinInst(
   // the function so we force a deterministic ordering in case both of them need
   // to take a new ID.
   const uint32_t glsl_insts_id = GetGlslInsts();
-  uint32_t smin_id = TakeNextId();
+  if (glsl_insts_id == 0) {
+    return nullptr;
+  }
+  uint32_t smin_id = context()->TakeNextId();
+  if (smin_id == 0) {
+    return nullptr;
+  }
   const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
   const auto ywidth = tm.GetType(y->type_id())->AsInteger()->width();
   assert(xwidth == ywidth);
@@ -640,7 +701,13 @@ Instruction* GraphicsRobustAccessPass::MakeSClampInst(
   // the function so we force a deterministic ordering in case both of them need
   // to take a new ID.
   const uint32_t glsl_insts_id = GetGlslInsts();
-  uint32_t clamp_id = TakeNextId();
+  if (glsl_insts_id == 0) {
+    return nullptr;
+  }
+  uint32_t clamp_id = context()->TakeNextId();
+  if (clamp_id == 0) {
+    return nullptr;
+  }
   const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
   const auto minwidth = tm.GetType(min->type_id())->AsInteger()->width();
   const auto maxwidth = tm.GetType(max->type_id())->AsInteger()->width();
@@ -755,7 +822,11 @@ Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
               base_ptr_type->storage_class());
 
           // Create the instruction and insert it.
-          const auto new_access_chain_id = TakeNextId();
+          const auto new_access_chain_id = context()->TakeNextId();
+          if (new_access_chain_id == 0) {
+            Fail();
+            return nullptr;
+          }
           auto* new_access_chain =
               InsertInst(current_access_chain, current_access_chain->opcode(),
                          new_access_chain_type_id, new_access_chain_id, ops);
@@ -784,7 +855,11 @@ Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
       uint32_t(struct_type->element_types().size() - 1);
   // Create the length-of-array instruction before the original access chain,
   // but after the generation of the pointer to the struct.
-  const auto array_len_id = TakeNextId();
+  const auto array_len_id = context()->TakeNextId();
+  if (array_len_id == 0) {
+    Fail();
+    return nullptr;
+  }
   analysis::Integer uint_type_for_query(32, false);
   auto* uint_type = type_mgr->GetRegisteredType(&uint_type_for_query);
   auto* array_len = InsertInst(
@@ -935,12 +1010,18 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
     return type_mgr->GetRegisteredType(&proposed);
   }();
 
-  const uint32_t image_id = TakeNextId();
+  const uint32_t image_id = context()->TakeNextId();
+  if (image_id == 0) {
+    return Fail();
+  }
   auto* image =
       InsertInst(image_texel_pointer, spv::Op::OpLoad, image_type_id, image_id,
                  {{SPV_OPERAND_TYPE_ID, {image_ptr->result_id()}}});
 
-  const uint32_t query_size_id = TakeNextId();
+  const uint32_t query_size_id = context()->TakeNextId();
+  if (query_size_id == 0) {
+    return Fail();
+  }
   auto* query_size =
       InsertInst(image_texel_pointer, spv::Op::OpImageQuerySize,
                  type_mgr->GetTypeInstruction(query_size_type), query_size_id,
@@ -968,7 +1049,10 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
         query_size_type, {component_1_id, component_1_id, component_6_id});
     auto* multiplicand_inst =
         constant_mgr->GetDefiningInstruction(multiplicand);
-    const auto query_size_including_faces_id = TakeNextId();
+    const auto query_size_including_faces_id = context()->TakeNextId();
+    if (query_size_including_faces_id == 0) {
+      return Fail();
+    }
     query_size_including_faces = InsertInst(
         image_texel_pointer, spv::Op::OpIMul,
         type_mgr->GetTypeInstruction(query_size_type),
@@ -992,7 +1076,10 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
                 query_size_type,
                 std::vector<uint32_t>(query_num_components, component_0_id));
 
-  const uint32_t query_max_including_faces_id = TakeNextId();
+  const uint32_t query_max_including_faces_id = context()->TakeNextId();
+  if (query_max_including_faces_id == 0) {
+    return Fail();
+  }
   auto* query_max_including_faces = InsertInst(
       image_texel_pointer, spv::Op::OpISub,
       type_mgr->GetTypeInstruction(query_size_type),
@@ -1005,18 +1092,27 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
   auto* clamp_coord = MakeSClampInst(
       *type_mgr, coord, constant_mgr->GetDefiningInstruction(coordinate_0),
       query_max_including_faces, image_texel_pointer);
+  if (clamp_coord == nullptr) {
+    return Fail();
+  }
   image_texel_pointer->SetInOperand(1, {clamp_coord->result_id()});
 
   // Clamp the sample index
   if (multisampled) {
     // Get the sample count via OpImageQuerySamples
-    const auto query_samples_id = TakeNextId();
+    const auto query_samples_id = context()->TakeNextId();
+    if (query_samples_id == 0) {
+      return Fail();
+    }
     auto* query_samples = InsertInst(
         image_texel_pointer, spv::Op::OpImageQuerySamples,
         constant_mgr->GetDefiningInstruction(component_0)->type_id(),
         query_samples_id, {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});
 
-    const auto max_samples_id = TakeNextId();
+    const auto max_samples_id = context()->TakeNextId();
+    if (max_samples_id == 0) {
+      return Fail();
+    }
     auto* max_samples = InsertInst(image_texel_pointer, spv::Op::OpImageQuerySamples,
                                    query_samples->type_id(), max_samples_id,
                                    {{SPV_OPERAND_TYPE_ID, {query_samples_id}},
@@ -1025,6 +1121,9 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
     auto* clamp_samples = MakeSClampInst(
         *type_mgr, samples, constant_mgr->GetDefiningInstruction(coordinate_0),
         max_samples, image_texel_pointer);
+    if (clamp_samples == nullptr) {
+      return Fail();
+    }
     image_texel_pointer->SetInOperand(2, {clamp_samples->result_id()});
 
   } else {
@@ -1041,6 +1140,9 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
 opt::Instruction* GraphicsRobustAccessPass::InsertInst(
     opt::Instruction* where_inst, spv::Op opcode, uint32_t type_id,
     uint32_t result_id, const Instruction::OperandList& operands) {
+  if (result_id == 0) {
+    return nullptr;
+  }
   module_status_.modified = true;
   auto* result = where_inst->InsertBefore(
       MakeUnique<Instruction>(context(), opcode, type_id, result_id, operands));

+ 11 - 3
3rdparty/spirv-tools/source/opt/instruction.cpp

@@ -546,11 +546,13 @@ void Instruction::ClearDbgLineInsts() {
   clear_dbg_line_insts();
 }
 
-void Instruction::UpdateDebugInfoFrom(const Instruction* from) {
+void Instruction::UpdateDebugInfoFrom(const Instruction* from,
+                                      const Instruction* line) {
   if (from == nullptr) return;
   ClearDbgLineInsts();
-  if (!from->dbg_line_insts().empty())
-    AddDebugLine(&from->dbg_line_insts().back());
+  const Instruction* fromLine = line != nullptr ? line : from;
+  if (!fromLine->dbg_line_insts().empty())
+    AddDebugLine(&fromLine->dbg_line_insts().back());
   SetDebugScope(from->GetDebugScope());
   if (!IsLineInst() &&
       context()->AreAnalysesValid(IRContext::kAnalysisDebugInfo)) {
@@ -1033,6 +1035,12 @@ bool Instruction::IsOpcodeSafeToDelete() const {
     return true;
   }
 
+  if (IsNonSemanticInstruction() &&
+      (GetShader100DebugOpcode() == NonSemanticShaderDebugInfo100DebugDeclare ||
+       GetShader100DebugOpcode() == NonSemanticShaderDebugInfo100DebugValue)) {
+    return true;
+  }
+
   switch (opcode()) {
     case spv::Op::OpDPdx:
     case spv::Op::OpDPdy:

+ 2 - 1
3rdparty/spirv-tools/source/opt/instruction.h

@@ -338,7 +338,8 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // Updates lexical scope of DebugScope and OpLine.
   void UpdateLexicalScope(uint32_t scope);
   // Updates OpLine and DebugScope based on the information of |from|.
-  void UpdateDebugInfoFrom(const Instruction* from);
+  void UpdateDebugInfoFrom(const Instruction* from,
+                           const Instruction* line = nullptr);
   // Remove the |index|-th operand
   void RemoveOperand(uint32_t index) {
     operands_.erase(operands_.begin() + index);

+ 140 - 77
3rdparty/spirv-tools/source/opt/interface_var_sroa.cpp

@@ -239,28 +239,34 @@ void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
       });
 }
 
-bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
+Pass::Status
+InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
     Instruction* interface_var, Instruction* interface_var_type,
     uint32_t location, uint32_t component, uint32_t extra_array_length) {
-  NestedCompositeComponents scalar_interface_vars =
+  std::optional<NestedCompositeComponents> scalar_interface_vars =
       CreateScalarInterfaceVarsForReplacement(interface_var_type,
                                               GetStorageClass(interface_var),
                                               extra_array_length);
 
-  AddLocationAndComponentDecorations(scalar_interface_vars, &location,
+  if (!scalar_interface_vars) {
+    return Status::Failure;
+  }
+
+  AddLocationAndComponentDecorations(*scalar_interface_vars, &location,
                                      component);
   KillLocationAndComponentDecorations(interface_var->result_id());
 
-  if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
-                               scalar_interface_vars)) {
-    return false;
+  Status status = ReplaceInterfaceVarWith(interface_var, extra_array_length,
+                                          *scalar_interface_vars);
+  if (status == Status::Failure) {
+    return status;
   }
 
   context()->KillInst(interface_var);
-  return true;
+  return status;
 }
 
-bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
+Pass::Status InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
     Instruction* interface_var, uint32_t extra_array_length,
     const NestedCompositeComponents& scalar_interface_vars) {
   std::vector<Instruction*> users;
@@ -276,21 +282,24 @@ bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
     // interface variable.
     for (uint32_t index = 0; index < extra_array_length; ++index) {
       std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
-      if (!ReplaceComponentsOfInterfaceVarWith(
-              interface_var, users, scalar_interface_vars,
-              interface_var_component_indices, &index,
-              &loads_to_component_values,
-              &loads_for_access_chain_to_composites)) {
-        return false;
+      Status status = ReplaceComponentsOfInterfaceVarWith(
+          interface_var, users, scalar_interface_vars,
+          interface_var_component_indices, &index, &loads_to_component_values,
+          &loads_for_access_chain_to_composites);
+      if (status == Status::Failure) {
+        return Status::Failure;
       }
       AddComponentsToCompositesForLoads(loads_to_component_values,
                                         &loads_to_composites, 0);
     }
-  } else if (!ReplaceComponentsOfInterfaceVarWith(
-                 interface_var, users, scalar_interface_vars,
-                 interface_var_component_indices, nullptr, &loads_to_composites,
-                 &loads_for_access_chain_to_composites)) {
-    return false;
+  } else {
+    Status status = ReplaceComponentsOfInterfaceVarWith(
+        interface_var, users, scalar_interface_vars,
+        interface_var_component_indices, nullptr, &loads_to_composites,
+        &loads_for_access_chain_to_composites);
+    if (status == Status::Failure) {
+      return Status::Failure;
+    }
   }
 
   ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
@@ -298,7 +307,7 @@ bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
                                     loads_for_access_chain_to_composites);
 
   KillInstructionsAndUsers(users);
-  return true;
+  return Status::SuccessWithChange;
 }
 
 void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
@@ -318,7 +327,8 @@ void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
   }
 }
 
-bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
+Pass::Status
+InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
     Instruction* interface_var,
     const std::vector<Instruction*>& interface_var_users,
     const NestedCompositeComponents& scalar_interface_vars,
@@ -329,15 +339,16 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
         loads_for_access_chain_to_composites) {
   if (!scalar_interface_vars.HasMultipleComponents()) {
     for (Instruction* interface_var_user : interface_var_users) {
-      if (!ReplaceComponentOfInterfaceVarWith(
-              interface_var, interface_var_user,
-              scalar_interface_vars.GetComponentVariable(),
-              interface_var_component_indices, extra_array_index,
-              loads_to_composites, loads_for_access_chain_to_composites)) {
-        return false;
+      Status status = ReplaceComponentOfInterfaceVarWith(
+          interface_var, interface_var_user,
+          scalar_interface_vars.GetComponentVariable(),
+          interface_var_component_indices, extra_array_index,
+          loads_to_composites, loads_for_access_chain_to_composites);
+      if (status == Status::Failure) {
+        return Status::Failure;
       }
     }
-    return true;
+    return Status::SuccessWithChange;
   }
   return ReplaceMultipleComponentsOfInterfaceVarWith(
       interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
@@ -345,27 +356,28 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
       loads_for_access_chain_to_composites);
 }
 
-bool InterfaceVariableScalarReplacement::
-    ReplaceMultipleComponentsOfInterfaceVarWith(
-        Instruction* interface_var,
-        const std::vector<Instruction*>& interface_var_users,
-        const std::vector<NestedCompositeComponents>& components,
-        std::vector<uint32_t>& interface_var_component_indices,
-        const uint32_t* extra_array_index,
-        std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
-        std::unordered_map<Instruction*, Instruction*>*
-            loads_for_access_chain_to_composites) {
+Pass::Status
+InterfaceVariableScalarReplacement::ReplaceMultipleComponentsOfInterfaceVarWith(
+    Instruction* interface_var,
+    const std::vector<Instruction*>& interface_var_users,
+    const std::vector<NestedCompositeComponents>& components,
+    std::vector<uint32_t>& interface_var_component_indices,
+    const uint32_t* extra_array_index,
+    std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
+    std::unordered_map<Instruction*, Instruction*>*
+        loads_for_access_chain_to_composites) {
   for (uint32_t i = 0; i < components.size(); ++i) {
     interface_var_component_indices.push_back(i);
     std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
     std::unordered_map<Instruction*, Instruction*>
         loads_for_access_chain_to_component_values;
-    if (!ReplaceComponentsOfInterfaceVarWith(
-            interface_var, interface_var_users, components[i],
-            interface_var_component_indices, extra_array_index,
-            &loads_to_component_values,
-            &loads_for_access_chain_to_component_values)) {
-      return false;
+    Status status = ReplaceComponentsOfInterfaceVarWith(
+        interface_var, interface_var_users, components[i],
+        interface_var_component_indices, extra_array_index,
+        &loads_to_component_values,
+        &loads_for_access_chain_to_component_values);
+    if (status == Status::Failure) {
+      return Status::Failure;
     }
     interface_var_component_indices.pop_back();
 
@@ -378,10 +390,11 @@ bool InterfaceVariableScalarReplacement::
     AddComponentsToCompositesForLoads(loads_to_component_values,
                                       loads_to_composites, depth_to_component);
   }
-  return true;
+  return Status::SuccessWithChange;
 }
 
-bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
+Pass::Status
+InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
     Instruction* interface_var, Instruction* interface_var_user,
     Instruction* scalar_var,
     const std::vector<uint32_t>& interface_var_component_indices,
@@ -395,42 +408,49 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
     StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
                                      scalar_var, extra_array_index,
                                      interface_var_user);
-    return true;
+    return Status::SuccessWithChange;
   }
   if (opcode == spv::Op::OpLoad) {
     Instruction* scalar_load =
         LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
+    if (scalar_load == nullptr) {
+      return Status::Failure;
+    }
     loads_to_component_values->insert({interface_var_user, scalar_load});
-    return true;
+    return Status::SuccessWithChange;
   }
 
   // Copy OpName and annotation instructions only once. Therefore, we create
   // them only for the first element of the extra array.
-  if (extra_array_index && *extra_array_index != 0) return true;
+  if (extra_array_index && *extra_array_index != 0)
+    return Status::SuccessWithChange;
 
   if (opcode == spv::Op::OpDecorateId || opcode == spv::Op::OpDecorateString ||
       opcode == spv::Op::OpDecorate) {
     CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
-    return true;
+    return Status::SuccessWithChange;
   }
 
   if (opcode == spv::Op::OpName) {
     std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
     new_inst->SetInOperand(0, {scalar_var->result_id()});
     context()->AddDebug2Inst(std::move(new_inst));
-    return true;
+    return Status::SuccessWithChange;
   }
 
   if (opcode == spv::Op::OpEntryPoint) {
-    return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
-                                           scalar_var->result_id());
+    if (ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
+                                        scalar_var->result_id())) {
+      return Status::SuccessWithChange;
+    }
+    return Status::Failure;
   }
 
   if (opcode == spv::Op::OpAccessChain) {
     ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
                            scalar_var,
                            loads_for_access_chain_to_component_values);
-    return true;
+    return Status::SuccessWithChange;
   }
 
   std::string message("Unhandled instruction");
@@ -440,7 +460,7 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
       "\nfor interface variable scalar replacement\n  " +
       interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
   context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
-  return false;
+  return Status::Failure;
 }
 
 void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
@@ -470,10 +490,14 @@ Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
   uint32_t ptr_type_id =
       GetPointerType(*component_type_id, GetStorageClass(var));
 
-  std::unique_ptr<Instruction> new_access_chain(new Instruction(
-      context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
-      std::initializer_list<Operand>{
-          {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
+  uint32_t new_id = TakeNextId();
+  if (new_id == 0) {
+    return nullptr;
+  }
+  std::unique_ptr<Instruction> new_access_chain(
+      new Instruction(context(), spv::Op::OpAccessChain, ptr_type_id, new_id,
+                      std::initializer_list<Operand>{
+                          {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
   for (uint32_t index_id : index_ids) {
     new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
   }
@@ -490,12 +514,16 @@ Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
   uint32_t ptr_type_id =
       GetPointerType(component_type_id, GetStorageClass(var));
   uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index);
-  std::unique_ptr<Instruction> new_access_chain(new Instruction(
-      context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
-      std::initializer_list<Operand>{
-          {SPV_OPERAND_TYPE_ID, {var->result_id()}},
-          {SPV_OPERAND_TYPE_ID, {index_id}},
-      }));
+  uint32_t new_id = TakeNextId();
+  if (new_id == 0) {
+    return nullptr;
+  }
+  std::unique_ptr<Instruction> new_access_chain(
+      new Instruction(context(), spv::Op::OpAccessChain, ptr_type_id, new_id,
+                      std::initializer_list<Operand>{
+                          {SPV_OPERAND_TYPE_ID, {var->result_id()}},
+                          {SPV_OPERAND_TYPE_ID, {index_id}},
+                      }));
   Instruction* inst = new_access_chain.get();
   context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
   insert_before->InsertBefore(std::move(new_access_chain));
@@ -617,6 +645,9 @@ void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
     component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
     ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
                                      *extra_array_index, insert_before);
+    if (ptr == nullptr) {
+      return;
+    }
   }
 
   StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
@@ -635,6 +666,9 @@ Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
     component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
     ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
                                      *extra_array_index, insert_before);
+    if (ptr == nullptr) {
+      return nullptr;
+    }
   }
 
   return CreateLoad(component_type_id, ptr, insert_before);
@@ -642,8 +676,12 @@ Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
 
 Instruction* InterfaceVariableScalarReplacement::CreateLoad(
     uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
+  uint32_t new_id = TakeNextId();
+  if (new_id == 0) {
+    return nullptr;
+  }
   std::unique_ptr<Instruction> load(
-      new Instruction(context(), spv::Op::OpLoad, type_id, TakeNextId(),
+      new Instruction(context(), spv::Op::OpLoad, type_id, new_id,
                       std::initializer_list<Operand>{
                           {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
   Instruction* load_inst = load.get();
@@ -658,6 +696,9 @@ void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
     const uint32_t* extra_array_index, Instruction* insert_before) {
   std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
       component_type_id, value_id, component_indices, extra_array_index));
+  if (composite_extract == nullptr) {
+    return;
+  }
 
   std::unique_ptr<Instruction> new_store(
       new Instruction(context(), spv::Op::OpStore));
@@ -677,6 +718,9 @@ Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
     uint32_t type_id, uint32_t composite_id,
     const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
   uint32_t component_id = TakeNextId();
+  if (component_id == 0) {
+    return nullptr;
+  }
   Instruction* composite_extract = new Instruction(
       context(), spv::Op::OpCompositeExtract, type_id, component_id,
       std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
@@ -716,6 +760,9 @@ Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
   if (!indexes.empty()) {
     ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
                                  &component_type_id);
+    if (ptr == nullptr) {
+      return nullptr;
+    }
   }
 
   return CreateLoad(component_type_id, ptr, insert_before);
@@ -730,7 +777,10 @@ InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
     type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
                                             depth_to_component);
   }
-  uint32_t new_id = context()->TakeNextId();
+  uint32_t new_id = TakeNextId();
+  if (new_id == 0) {
+    return nullptr;
+  }
   std::unique_ptr<Instruction> new_composite_construct(new Instruction(
       context(), spv::Op::OpCompositeConstruct, type_id, new_id, {}));
   Instruction* composite_construct = new_composite_construct.get();
@@ -767,6 +817,10 @@ void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
     if (itr == loads_to_composites->end()) {
       composite_construct =
           CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
+      if (composite_construct == nullptr) {
+        assert(false && "Could not create composite construct");
+        return;
+      }
       loads_to_composites->insert({load, composite_construct});
     } else {
       composite_construct = itr->second;
@@ -795,7 +849,7 @@ uint32_t InterfaceVariableScalarReplacement::GetPointerType(
   return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
 }
 
-InterfaceVariableScalarReplacement::NestedCompositeComponents
+std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
 InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
     Instruction* interface_var_type, spv::StorageClass storage_class,
     uint32_t extra_array_length) {
@@ -807,16 +861,19 @@ InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
 
   NestedCompositeComponents scalar_vars;
   while (array_length > 0) {
-    NestedCompositeComponents scalar_vars_for_element =
+    std::optional<NestedCompositeComponents> scalar_vars_for_element =
         CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
                                                 extra_array_length);
-    scalar_vars.AddComponent(scalar_vars_for_element);
+    if (!scalar_vars_for_element) {
+      return std::nullopt;
+    }
+    scalar_vars.AddComponent(*scalar_vars_for_element);
     --array_length;
   }
   return scalar_vars;
 }
 
-InterfaceVariableScalarReplacement::NestedCompositeComponents
+std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
 InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
     Instruction* interface_var_type, spv::StorageClass storage_class,
     uint32_t extra_array_length) {
@@ -830,16 +887,19 @@ InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
 
   NestedCompositeComponents scalar_vars;
   while (column_count > 0) {
-    NestedCompositeComponents scalar_vars_for_column =
+    std::optional<NestedCompositeComponents> scalar_vars_for_column =
         CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
                                                 extra_array_length);
-    scalar_vars.AddComponent(scalar_vars_for_column);
+    if (!scalar_vars_for_column) {
+      return std::nullopt;
+    }
+    scalar_vars.AddComponent(*scalar_vars_for_column);
     --column_count;
   }
   return scalar_vars;
 }
 
-InterfaceVariableScalarReplacement::NestedCompositeComponents
+std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
 InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
     Instruction* interface_var_type, spv::StorageClass storage_class,
     uint32_t extra_array_length) {
@@ -864,6 +924,9 @@ InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
   uint32_t ptr_type_id =
       context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
   uint32_t id = TakeNextId();
+  if (id == 0) {
+    return std::nullopt;
+  }
   std::unique_ptr<Instruction> variable(
       new Instruction(context(), spv::Op::OpVariable, ptr_type_id, id,
                       std::initializer_list<Operand>{
@@ -953,9 +1016,9 @@ InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
       continue;
     }
 
-    if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
-                                             location, component,
-                                             extra_array_length)) {
+    if (ReplaceInterfaceVariableWithScalars(
+            interface_var, interface_var_type, location, component,
+            extra_array_length) == Pass::Status::Failure) {
       return Pass::Status::Failure;
     }
     status = Pass::Status::SuccessWithChange;

+ 19 - 14
3rdparty/spirv-tools/source/opt/interface_var_sroa.h

@@ -15,6 +15,7 @@
 #ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_
 #define SOURCE_OPT_INTERFACE_VAR_SROA_H_
 
+#include <optional>
 #include <unordered_set>
 
 #include "source/opt/pass.h"
@@ -100,25 +101,26 @@ class InterfaceVariableScalarReplacement : public Pass {
   // If |extra_array_length| is 0, it means |interface_var| has a Patch
   // decoration. Otherwise, |extra_array_length| denotes the length of the extra
   // array of |interface_var|.
-  bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
-                                           Instruction* interface_var_type,
-                                           uint32_t location,
-                                           uint32_t component,
-                                           uint32_t extra_array_length);
+  Status ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
+                                             Instruction* interface_var_type,
+                                             uint32_t location,
+                                             uint32_t component,
+                                             uint32_t extra_array_length);
 
   // Creates scalar variables with the storage classe |storage_class| to replace
   // an interface variable whose type is |interface_var_type|. If
   // |extra_array_length| is not zero, adds the extra arrayness to the created
   // scalar variables.
-  NestedCompositeComponents CreateScalarInterfaceVarsForReplacement(
-      Instruction* interface_var_type, spv::StorageClass storage_class,
-      uint32_t extra_array_length);
+  std::optional<NestedCompositeComponents>
+  CreateScalarInterfaceVarsForReplacement(Instruction* interface_var_type,
+                                          spv::StorageClass storage_class,
+                                          uint32_t extra_array_length);
 
   // Creates scalar variables with the storage classe |storage_class| to replace
   // the interface variable whose type is OpTypeArray |interface_var_type| with.
   // If |extra_array_length| is not zero, adds the extra arrayness to all the
   // scalar variables.
-  NestedCompositeComponents CreateScalarInterfaceVarsForArray(
+  std::optional<NestedCompositeComponents> CreateScalarInterfaceVarsForArray(
       Instruction* interface_var_type, spv::StorageClass storage_class,
       uint32_t extra_array_length);
 
@@ -126,7 +128,7 @@ class InterfaceVariableScalarReplacement : public Pass {
   // the interface variable whose type is OpTypeMatrix |interface_var_type|
   // with. If |extra_array_length| is not zero, adds the extra arrayness to all
   // the scalar variables.
-  NestedCompositeComponents CreateScalarInterfaceVarsForMatrix(
+  std::optional<NestedCompositeComponents> CreateScalarInterfaceVarsForMatrix(
       Instruction* interface_var_type, spv::StorageClass storage_class,
       uint32_t extra_array_length);
 
@@ -142,7 +144,7 @@ class InterfaceVariableScalarReplacement : public Pass {
   // |extra_arrayness| is the extra arrayness of the interface variable.
   // |scalar_interface_vars| contains the nested variables to replace the
   // interface variable with.
-  bool ReplaceInterfaceVarWith(
+  Status ReplaceInterfaceVarWith(
       Instruction* interface_var, uint32_t extra_arrayness,
       const NestedCompositeComponents& scalar_interface_vars);
 
@@ -155,7 +157,7 @@ class InterfaceVariableScalarReplacement : public Pass {
   // construct instructions to be replaced with load instructions of access
   // chain instructions in |interface_var_users| via
   // |loads_for_access_chain_to_composites|.
-  bool ReplaceComponentsOfInterfaceVarWith(
+  Status ReplaceComponentsOfInterfaceVarWith(
       Instruction* interface_var,
       const std::vector<Instruction*>& interface_var_users,
       const NestedCompositeComponents& scalar_interface_vars,
@@ -174,7 +176,7 @@ class InterfaceVariableScalarReplacement : public Pass {
   // via |loads_to_composites|. Returns composite construct instructions to be
   // replaced with load instructions of access chain instructions in
   // |interface_var_users| via |loads_for_access_chain_to_composites|.
-  bool ReplaceMultipleComponentsOfInterfaceVarWith(
+  Status ReplaceMultipleComponentsOfInterfaceVarWith(
       Instruction* interface_var,
       const std::vector<Instruction*>& interface_var_users,
       const std::vector<NestedCompositeComponents>& components,
@@ -192,7 +194,7 @@ class InterfaceVariableScalarReplacement : public Pass {
   // |loads_to_component_values|. If |interface_var_user| is an access chain,
   // returns the component value for loads of |interface_var_user| via
   // |loads_for_access_chain_to_component_values|.
-  bool ReplaceComponentOfInterfaceVarWith(
+  Status ReplaceComponentOfInterfaceVarWith(
       Instruction* interface_var, Instruction* interface_var_user,
       Instruction* scalar_var,
       const std::vector<uint32_t>& interface_var_component_indices,
@@ -389,6 +391,9 @@ class InterfaceVariableScalarReplacement : public Pass {
   // A set of interface variables without the extra arrayness for any of the
   // entry points.
   std::unordered_set<Instruction*> vars_without_extra_arrayness;
+
+  // Returns the next available id, or 0 if the id overflows.
+  uint32_t TakeNextId() { return context()->TakeNextId(); }
 };
 
 }  // namespace opt

+ 49 - 20
3rdparty/spirv-tools/source/opt/invocation_interlock_placement_pass.cpp

@@ -294,8 +294,12 @@ bool InvocationInterlockPlacementPass::removeUnneededInstructions(
 BasicBlock* InvocationInterlockPlacementPass::splitEdge(BasicBlock* block,
                                                         uint32_t succ_id) {
   // Create a new block to replace the critical edge.
+  uint32_t new_id = context()->TakeNextId();
+  if (new_id == 0) {
+    return nullptr;
+  }
   auto new_succ_temp = MakeUnique<BasicBlock>(
-      MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, TakeNextId(),
+      MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, new_id,
                               std::initializer_list<Operand>{}));
   auto* new_succ = new_succ_temp.get();
 
@@ -325,7 +329,7 @@ BasicBlock* InvocationInterlockPlacementPass::splitEdge(BasicBlock* block,
   return new_succ;
 }
 
-bool InvocationInterlockPlacementPass::placeInstructionsForEdge(
+Pass::Status InvocationInterlockPlacementPass::placeInstructionsForEdge(
     BasicBlock* block, uint32_t next_id, BlockSet& inside,
     BlockSet& previous_inside, spv::Op opcode, bool reverse_cfg) {
   bool modified = false;
@@ -372,31 +376,45 @@ bool InvocationInterlockPlacementPass::placeInstructionsForEdge(
         new_branch = splitEdge(cfg()->block(next_id), block->id());
       }
 
+      if (!new_branch) {
+        return Status::Failure;
+      }
+
       auto inst = new Instruction(context(), opcode);
       inst->InsertBefore(&*new_branch->tail());
     }
   }
 
-  return modified;
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
-bool InvocationInterlockPlacementPass::placeInstructions(BasicBlock* block) {
-  bool modified = false;
+Pass::Status InvocationInterlockPlacementPass::placeInstructions(
+    BasicBlock* block) {
+  Status status = Status::SuccessWithoutChange;
 
-  block->ForEachSuccessorLabel([this, block, &modified](uint32_t succ_id) {
-    modified |= placeInstructionsForEdge(
+  block->ForEachSuccessorLabel([this, block, &status](uint32_t succ_id) {
+    if (status == Status::Failure) {
+      return;
+    }
+    Status edge_status = placeInstructionsForEdge(
         block, succ_id, after_begin_, predecessors_after_begin_,
         spv::Op::OpBeginInvocationInterlockEXT, /* reverse_cfg= */ true);
-    modified |= placeInstructionsForEdge(cfg()->block(succ_id), block->id(),
-                                         before_end_, successors_before_end_,
-                                         spv::Op::OpEndInvocationInterlockEXT,
-                                         /* reverse_cfg= */ false);
+    status = CombineStatus(status, edge_status);
+    if (status == Status::Failure) {
+      return;
+    }
+
+    edge_status = placeInstructionsForEdge(cfg()->block(succ_id), block->id(),
+                                           before_end_, successors_before_end_,
+                                           spv::Op::OpEndInvocationInterlockEXT,
+                                           /* reverse_cfg= */ false);
+    status = CombineStatus(status, edge_status);
   });
 
-  return modified;
+  return status;
 }
 
-bool InvocationInterlockPlacementPass::processFragmentShaderEntry(
+Pass::Status InvocationInterlockPlacementPass::processFragmentShaderEntry(
     Function* entry_func) {
   bool modified = false;
 
@@ -417,9 +435,15 @@ bool InvocationInterlockPlacementPass::processFragmentShaderEntry(
 
   for (BasicBlock* block : original_blocks) {
     modified |= removeUnneededInstructions(block);
-    modified |= placeInstructions(block);
+    Status place_status = placeInstructions(block);
+    if (place_status == Status::Failure) {
+      return Status::Failure;
+    }
+    if (place_status == Status::SuccessWithChange) {
+      modified = true;
+    }
   }
-  return modified;
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
 bool InvocationInterlockPlacementPass::isFragmentShaderInterlockEnabled() {
@@ -452,7 +476,7 @@ Pass::Status InvocationInterlockPlacementPass::Process() {
     return Status::SuccessWithoutChange;
   }
 
-  bool modified = false;
+  Status status = Status::SuccessWithoutChange;
 
   std::unordered_set<Function*> entry_points;
   for (Instruction& entry_inst : context()->module()->entry_points()) {
@@ -466,7 +490,9 @@ Pass::Status InvocationInterlockPlacementPass::Process() {
     Function* func = &*fi;
     recordBeginOrEndInFunction(func);
     if (!entry_points.count(func) && extracted_functions_.count(func)) {
-      modified |= removeBeginAndEndInstructionsFromFunction(func);
+      if (removeBeginAndEndInstructionsFromFunction(func)) {
+        status = Status::SuccessWithChange;
+      }
     }
   }
 
@@ -482,11 +508,14 @@ Pass::Status InvocationInterlockPlacementPass::Process() {
       continue;
     }
 
-    modified |= processFragmentShaderEntry(entry_func);
+    Status frag_status = processFragmentShaderEntry(entry_func);
+    if (frag_status == Status::Failure) {
+      return Status::Failure;
+    }
+    status = CombineStatus(status, frag_status);
   }
 
-  return modified ? Pass::Status::SuccessWithChange
-                  : Pass::Status::SuccessWithoutChange;
+  return status;
 }
 
 }  // namespace opt

+ 5 - 5
3rdparty/spirv-tools/source/opt/invocation_interlock_placement_pass.h

@@ -120,14 +120,14 @@ class InvocationInterlockPlacementPass : public Pass {
   // For the edge from block to next_id, places a begin or end instruction on
   // the edge, based on the direction we are walking the CFG, specified in
   // reverse_cfg.
-  bool placeInstructionsForEdge(BasicBlock* block, uint32_t next_id,
-                                BlockSet& inside, BlockSet& previous_inside,
-                                spv::Op opcode, bool reverse_cfg);
+  Status placeInstructionsForEdge(BasicBlock* block, uint32_t next_id,
+                                  BlockSet& inside, BlockSet& previous_inside,
+                                  spv::Op opcode, bool reverse_cfg);
   // Calls placeInstructionsForEdge for each edge in block.
-  bool placeInstructions(BasicBlock* block);
+  Status placeInstructions(BasicBlock* block);
 
   // Processes a single fragment shader entry function.
-  bool processFragmentShaderEntry(Function* entry_func);
+  Status processFragmentShaderEntry(Function* entry_func);
 
   // Returns whether the module has the SPV_EXT_fragment_shader_interlock
   // extension and one of the FragmentShader*InterlockEXT capabilities.

+ 6 - 2
3rdparty/spirv-tools/source/opt/ir_context.cpp

@@ -201,7 +201,9 @@ Instruction* IRContext::KillInst(Instruction* inst) {
     constant_mgr_->RemoveId(inst->result_id());
   }
   if (inst->opcode() == spv::Op::OpCapability ||
-      inst->opcode() == spv::Op::OpExtension) {
+      inst->opcode() == spv::Op::OpConditionalCapabilityINTEL ||
+      inst->opcode() == spv::Op::OpExtension ||
+      inst->opcode() == spv::Op::OpConditionalExtensionINTEL) {
     // We reset the feature manager, instead of updating it, because it is just
     // as much work.  We would have to remove all capabilities implied by this
     // capability that are not also implied by the remaining OpCapability
@@ -382,6 +384,7 @@ bool IRContext::IsConsistent() {
     }
   }
 
+  return true;
   if (AreAnalysesValid(kAnalysisIdToFuncMapping)) {
     for (auto& fn : *module_) {
       if (id_to_func_[fn.result_id()] != &fn) {
@@ -398,8 +401,9 @@ bool IRContext::IsConsistent() {
                 return false;
               }
               return true;
-            }))
+            })) {
           return false;
+        }
       }
     }
   }

+ 4 - 2
3rdparty/spirv-tools/source/opt/ir_loader.cpp

@@ -181,9 +181,11 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) {
   } else {
     if (function_ == nullptr) {  // Outside function definition
       SPIRV_ASSERT(consumer_, block_ == nullptr);
-      if (opcode == spv::Op::OpCapability) {
+      if (opcode == spv::Op::OpCapability ||
+          opcode == spv::Op::OpConditionalCapabilityINTEL) {
         module_->AddCapability(std::move(spv_inst));
-      } else if (opcode == spv::Op::OpExtension) {
+      } else if (opcode == spv::Op::OpExtension ||
+                 opcode == spv::Op::OpConditionalExtensionINTEL) {
         module_->AddExtension(std::move(spv_inst));
       } else if (opcode == spv::Op::OpExtInstImport) {
         module_->AddExtInstImport(std::move(spv_inst));

+ 11 - 3
3rdparty/spirv-tools/source/opt/loop_fission.cpp

@@ -362,14 +362,19 @@ Loop* LoopFissionImpl::SplitLoop() {
   LoopUtils util{context_, loop_};
   LoopUtils::LoopCloningResult clone_results;
   Loop* cloned_loop = util.CloneAndAttachLoopToHeader(&clone_results);
+  if (!cloned_loop) {
+    return nullptr;
+  }
 
   // Update the OpLoopMerge in the cloned loop.
   cloned_loop->UpdateLoopMergeInst();
 
   // Add the loop_ to the module.
-  // TODO(1841): Handle failure to create pre-header.
-  Function::iterator it =
-      util.GetFunction()->FindBlock(loop_->GetOrCreatePreHeaderBlock()->id());
+  BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
+  if (!pre_header) {
+    return nullptr;
+  }
+  Function::iterator it = util.GetFunction()->FindBlock(pre_header->id());
   util.GetFunction()->AddBasicBlocks(clone_results.cloned_bb_.begin(),
                                      clone_results.cloned_bb_.end(), ++it);
   loop_->SetPreHeaderBlock(cloned_loop->GetMergeBlock());
@@ -478,6 +483,9 @@ Pass::Status LoopFissionPass::Process() {
 
         if (impl.CanPerformSplit()) {
           Loop* second_loop = impl.SplitLoop();
+          if (!second_loop) {
+            return Status::Failure;
+          }
           changed = true;
           context()->InvalidateAnalysesExceptFor(
               IRContext::kAnalysisLoopAnalysis);

+ 147 - 61
3rdparty/spirv-tools/source/opt/loop_peeling.cpp

@@ -45,7 +45,7 @@ void GetBlocksInPath(uint32_t block, uint32_t entry,
 
 size_t LoopPeelingPass::code_grow_threshold_ = 1000;
 
-void LoopPeeling::DuplicateAndConnectLoop(
+bool LoopPeeling::DuplicateAndConnectLoop(
     LoopUtils::LoopCloningResult* clone_results) {
   CFG& cfg = *context_->cfg();
   analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
@@ -53,12 +53,17 @@ void LoopPeeling::DuplicateAndConnectLoop(
   assert(CanPeelLoop() && "Cannot peel loop!");
 
   std::vector<BasicBlock*> ordered_loop_blocks;
-  // TODO(1841): Handle failure to create pre-header.
   BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
+  if (!pre_header) {
+    return false;
+  }
 
   loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks);
 
   cloned_loop_ = loop_utils_.CloneLoop(clone_results, ordered_loop_blocks);
+  if (!cloned_loop_) {
+    return false;
+  }
 
   // Add the basic block to the function.
   Function::iterator it =
@@ -146,17 +151,21 @@ void LoopPeeling::DuplicateAndConnectLoop(
 
   // Force the creation of a new preheader for the original loop and set it as
   // the merge block for the cloned loop.
-  // TODO(1841): Handle failure to create pre-header.
-  cloned_loop_->SetMergeBlock(loop_->GetOrCreatePreHeaderBlock());
+  BasicBlock* new_pre_header = loop_->GetOrCreatePreHeaderBlock();
+  if (!new_pre_header) {
+    return false;
+  }
+  cloned_loop_->SetMergeBlock(new_pre_header);
+  return true;
 }
 
-void LoopPeeling::InsertCanonicalInductionVariable(
+bool LoopPeeling::InsertCanonicalInductionVariable(
     LoopUtils::LoopCloningResult* clone_results) {
   if (original_loop_canonical_induction_variable_) {
     canonical_induction_variable_ =
         context_->get_def_use_mgr()->GetDef(clone_results->value_map_.at(
             original_loop_canonical_induction_variable_->result_id()));
-    return;
+    return true;
   }
 
   BasicBlock::iterator insert_point = GetClonedLoop()->GetLatchBlock()->tail();
@@ -168,19 +177,25 @@ void LoopPeeling::InsertCanonicalInductionVariable(
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
   Instruction* uint_1_cst =
       builder.GetIntConstant<uint32_t>(1, int_type_->IsSigned());
+  if (!uint_1_cst) return false;
   // Create the increment.
   // Note that we do "1 + 1" here, one of the operand should the phi
   // value but we don't have it yet. The operand will be set latter.
   Instruction* iv_inc = builder.AddIAdd(
       uint_1_cst->type_id(), uint_1_cst->result_id(), uint_1_cst->result_id());
+  if (!iv_inc) return false;
 
   builder.SetInsertPoint(&*GetClonedLoop()->GetHeaderBlock()->begin());
 
+  Instruction* initial_value =
+      builder.GetIntConstant<uint32_t>(0, int_type_->IsSigned());
+  if (!initial_value) return false;
+
   canonical_induction_variable_ = builder.AddPhi(
       uint_1_cst->type_id(),
-      {builder.GetIntConstant<uint32_t>(0, int_type_->IsSigned())->result_id(),
-       GetClonedLoop()->GetPreHeaderBlock()->id(), iv_inc->result_id(),
-       GetClonedLoop()->GetLatchBlock()->id()});
+      {initial_value->result_id(), GetClonedLoop()->GetPreHeaderBlock()->id(),
+       iv_inc->result_id(), GetClonedLoop()->GetLatchBlock()->id()});
+  if (!canonical_induction_variable_) return false;
   // Connect everything.
   iv_inc->SetInOperand(0, {canonical_induction_variable_->result_id()});
 
@@ -191,6 +206,7 @@ void LoopPeeling::InsertCanonicalInductionVariable(
   if (do_while_form_) {
     canonical_induction_variable_ = iv_inc;
   }
+  return true;
 }
 
 void LoopPeeling::GetIteratorUpdateOperations(
@@ -308,7 +324,7 @@ void LoopPeeling::GetIteratingExitValues() {
   }
 }
 
-void LoopPeeling::FixExitCondition(
+bool LoopPeeling::FixExitCondition(
     const std::function<uint32_t(Instruction*)>& condition_builder) {
   CFG& cfg = *context_->cfg();
 
@@ -329,7 +345,11 @@ void LoopPeeling::FixExitCondition(
     --insert_point;
   }
 
-  exit_condition->SetInOperand(0, {condition_builder(&*insert_point)});
+  uint32_t new_cond_id = condition_builder(&*insert_point);
+  if (new_cond_id == 0) {
+    return false;
+  }
+  exit_condition->SetInOperand(0, {new_cond_id});
 
   uint32_t to_continue_block_idx =
       GetClonedLoop()->IsInsideLoop(exit_condition->GetSingleWordInOperand(1))
@@ -341,6 +361,7 @@ void LoopPeeling::FixExitCondition(
 
   // Update def/use manager.
   context_->get_def_use_mgr()->AnalyzeInstUse(exit_condition);
+  return true;
 }
 
 BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) {
@@ -348,10 +369,13 @@ BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) {
   CFG& cfg = *context_->cfg();
   assert(cfg.preds(bb->id()).size() == 1 && "More than one predecessor");
 
-  // TODO(1841): Handle id overflow.
+  uint32_t new_id = context_->TakeNextId();
+  if (new_id == 0) {
+    return nullptr;
+  }
   std::unique_ptr<BasicBlock> new_bb =
-      MakeUnique<BasicBlock>(std::unique_ptr<Instruction>(new Instruction(
-          context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})));
+      MakeUnique<BasicBlock>(std::unique_ptr<Instruction>(
+          new Instruction(context_, spv::Op::OpLabel, 0, new_id, {})));
   // Update the loop descriptor.
   Loop* in_loop = (*loop_utils_.GetLoopDescriptor())[bb];
   if (in_loop) {
@@ -394,8 +418,10 @@ BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) {
 
 BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition,
                                      BasicBlock* if_merge) {
-  // TODO(1841): Handle failure to create pre-header.
   BasicBlock* if_block = loop->GetOrCreatePreHeaderBlock();
+  if (!if_block) {
+    return nullptr;
+  }
   // Will no longer be a pre-header because of the if.
   loop->SetPreHeaderBlock(nullptr);
   // Kill the branch to the header.
@@ -411,48 +437,63 @@ BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition,
   return if_block;
 }
 
-void LoopPeeling::PeelBefore(uint32_t peel_factor) {
+bool LoopPeeling::PeelBefore(uint32_t peel_factor) {
   assert(CanPeelLoop() && "Cannot peel loop");
   LoopUtils::LoopCloningResult clone_results;
 
   // Clone the loop and insert the cloned one before the loop.
-  DuplicateAndConnectLoop(&clone_results);
+  if (!DuplicateAndConnectLoop(&clone_results)) {
+    return false;
+  }
 
   // Add a canonical induction variable "canonical_induction_variable_".
-  InsertCanonicalInductionVariable(&clone_results);
+  if (!InsertCanonicalInductionVariable(&clone_results)) {
+    return false;
+  }
 
   InstructionBuilder builder(
       context_, &*cloned_loop_->GetPreHeaderBlock()->tail(),
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
   Instruction* factor =
       builder.GetIntConstant(peel_factor, int_type_->IsSigned());
+  if (!factor) return false;
 
   Instruction* has_remaining_iteration = builder.AddLessThan(
       factor->result_id(), loop_iteration_count_->result_id());
+  if (!has_remaining_iteration) return false;
   Instruction* max_iteration = builder.AddSelect(
       factor->type_id(), has_remaining_iteration->result_id(),
       factor->result_id(), loop_iteration_count_->result_id());
+  if (!max_iteration) return false;
 
   // Change the exit condition of the cloned loop to be (exit when become
   // false):
   //  "canonical_induction_variable_" < min("factor", "loop_iteration_count_")
-  FixExitCondition([max_iteration, this](Instruction* insert_before_point) {
-    return InstructionBuilder(context_, insert_before_point,
-                              IRContext::kAnalysisDefUse |
-                                  IRContext::kAnalysisInstrToBlockMapping)
-        .AddLessThan(canonical_induction_variable_->result_id(),
-                     max_iteration->result_id())
-        ->result_id();
-  });
+  if (!FixExitCondition(
+          [max_iteration, this](Instruction* insert_before_point) {
+            Instruction* new_cond =
+                InstructionBuilder(context_, insert_before_point,
+                                   IRContext::kAnalysisDefUse |
+                                       IRContext::kAnalysisInstrToBlockMapping)
+                    .AddLessThan(canonical_induction_variable_->result_id(),
+                                 max_iteration->result_id());
+            return new_cond ? new_cond->result_id() : 0;
+          })) {
+    return false;
+  }
 
   // "Protect" the second loop: the second loop can only be executed if
   // |has_remaining_iteration| is true (i.e. factor < loop_iteration_count_).
   BasicBlock* if_merge_block = loop_->GetMergeBlock();
-  loop_->SetMergeBlock(CreateBlockBefore(loop_->GetMergeBlock()));
+  BasicBlock* new_merge_block = CreateBlockBefore(loop_->GetMergeBlock());
+  if (!new_merge_block) return false;
+  loop_->SetMergeBlock(new_merge_block);
   // Prevent the second loop from being executed if we already executed all the
   // required iterations.
   BasicBlock* if_block =
       ProtectLoop(loop_, has_remaining_iteration, if_merge_block);
+  if (!if_block) return false;
+
   // Patch the phi of the merge block.
   if_merge_block->ForEachPhiInst(
       [&clone_results, if_block, this](Instruction* phi) {
@@ -471,14 +512,17 @@ void LoopPeeling::PeelBefore(uint32_t peel_factor) {
   context_->InvalidateAnalysesExceptFor(
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping |
       IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG);
+  return true;
 }
 
-void LoopPeeling::PeelAfter(uint32_t peel_factor) {
+bool LoopPeeling::PeelAfter(uint32_t peel_factor) {
   assert(CanPeelLoop() && "Cannot peel loop");
   LoopUtils::LoopCloningResult clone_results;
 
   // Clone the loop and insert the cloned one before the loop.
-  DuplicateAndConnectLoop(&clone_results);
+  if (!DuplicateAndConnectLoop(&clone_results)) {
+    return false;
+  }
 
   // Add a canonical induction variable "canonical_induction_variable_".
   InsertCanonicalInductionVariable(&clone_results);
@@ -488,28 +532,33 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
   Instruction* factor =
       builder.GetIntConstant(peel_factor, int_type_->IsSigned());
+  if (!factor) return false;
 
   Instruction* has_remaining_iteration = builder.AddLessThan(
       factor->result_id(), loop_iteration_count_->result_id());
+  if (!has_remaining_iteration) return false;
 
   // Change the exit condition of the cloned loop to be (exit when become
   // false):
   //  "canonical_induction_variable_" + "factor" < "loop_iteration_count_"
-  FixExitCondition([factor, this](Instruction* insert_before_point) {
-    InstructionBuilder cond_builder(
-        context_, insert_before_point,
-        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-    // Build the following check: canonical_induction_variable_ + factor <
-    // iteration_count
-    return cond_builder
-        .AddLessThan(cond_builder
-                         .AddIAdd(canonical_induction_variable_->type_id(),
-                                  canonical_induction_variable_->result_id(),
-                                  factor->result_id())
-                         ->result_id(),
-                     loop_iteration_count_->result_id())
-        ->result_id();
-  });
+  if (!FixExitCondition([factor,
+                         this](Instruction* insert_before_point) -> uint32_t {
+        InstructionBuilder cond_builder(
+            context_, insert_before_point,
+            IRContext::kAnalysisDefUse |
+                IRContext::kAnalysisInstrToBlockMapping);
+        // Build the following check: canonical_induction_variable_ + factor <
+        // iteration_count
+        Instruction* add = cond_builder.AddIAdd(
+            canonical_induction_variable_->type_id(),
+            canonical_induction_variable_->result_id(), factor->result_id());
+        if (!add) return 0;
+        Instruction* new_cond = cond_builder.AddLessThan(
+            add->result_id(), loop_iteration_count_->result_id());
+        return new_cond ? new_cond->result_id() : 0;
+      })) {
+    return false;
+  }
 
   // "Protect" the first loop: the first loop can only be executed if
   // factor < loop_iteration_count_.
@@ -517,11 +566,17 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
   // The original loop's pre-header was the cloned loop merge block.
   GetClonedLoop()->SetMergeBlock(
       CreateBlockBefore(GetOriginalLoop()->GetPreHeaderBlock()));
+  if (!GetClonedLoop()->GetMergeBlock()) {
+    return false;
+  }
   // Use the second loop preheader as if merge block.
 
   // Prevent the first loop if only the peeled loop needs it.
   BasicBlock* if_block = ProtectLoop(cloned_loop_, has_remaining_iteration,
                                      GetOriginalLoop()->GetPreHeaderBlock());
+  if (!if_block) {
+    return false;
+  }
 
   // Patch the phi of the header block.
   // We added an if to enclose the first loop and because the phi node are
@@ -529,8 +584,10 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
   // dominate the preheader.
   // We had to the preheader (our if merge block) the required phi instruction
   // and patch the header phi.
+  bool ok = true;
   GetOriginalLoop()->GetHeaderBlock()->ForEachPhiInst(
-      [&clone_results, if_block, this](Instruction* phi) {
+      [&clone_results, if_block, &ok, this](Instruction* phi) {
+        if (!ok) return;
         analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
 
         auto find_value_idx = [](Instruction* phi_inst, Loop* loop) {
@@ -554,15 +611,21 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
                              find_value_idx(phi, GetOriginalLoop())),
                          GetClonedLoop()->GetMergeBlock()->id(),
                          cloned_preheader_value, if_block->id()});
+        if (!new_phi) {
+          ok = false;
+          return;
+        }
 
         phi->SetInOperand(find_value_idx(phi, GetOriginalLoop()),
                           {new_phi->result_id()});
         def_use_mgr->AnalyzeInstUse(phi);
       });
+  if (!ok) return false;
 
   context_->InvalidateAnalysesExceptFor(
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping |
       IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG);
+  return true;
 }
 
 Pass::Status LoopPeelingPass::Process() {
@@ -571,13 +634,19 @@ Pass::Status LoopPeelingPass::Process() {
 
   // Process each function in the module
   for (Function& f : *module) {
-    modified |= ProcessFunction(&f);
+    Pass::Status status = ProcessFunction(&f);
+    if (status == Status::Failure) {
+      return Status::Failure;
+    }
+    if (status == Status::SuccessWithChange) {
+      modified = true;
+    }
   }
 
   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
-bool LoopPeelingPass::ProcessFunction(Function* f) {
+Pass::Status LoopPeelingPass::ProcessFunction(Function* f) {
   bool modified = false;
   LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);
 
@@ -593,41 +662,54 @@ bool LoopPeelingPass::ProcessFunction(Function* f) {
     CodeMetrics loop_size;
     loop_size.Analyze(*loop);
 
-    auto try_peel = [&loop_size, &modified, this](Loop* loop_to_peel) -> Loop* {
+    auto try_peel = [&loop_size, &modified, this](
+                        Loop* loop_to_peel) -> std::pair<Pass::Status, Loop*> {
       if (!loop_to_peel->IsLCSSA()) {
         LoopUtils(context(), loop_to_peel).MakeLoopClosedSSA();
       }
 
-      bool peeled_loop;
+      Pass::Status status;
       Loop* still_peelable_loop;
-      std::tie(peeled_loop, still_peelable_loop) =
+      std::tie(status, still_peelable_loop) =
           ProcessLoop(loop_to_peel, &loop_size);
 
-      if (peeled_loop) {
+      if (status == Pass::Status::SuccessWithChange) {
         modified = true;
       }
 
-      return still_peelable_loop;
+      return {status, still_peelable_loop};
     };
 
-    Loop* still_peelable_loop = try_peel(loop);
+    Pass::Status status;
+    Loop* still_peelable_loop;
+    std::tie(status, still_peelable_loop) = try_peel(loop);
+
+    if (status == Pass::Status::Failure) {
+      return Pass::Status::Failure;
+    }
+
     // The pass is working out the maximum factor by which a loop can be peeled.
     // If the loop can potentially be peeled again, then there is only one
     // possible direction, so only one call is still needed.
     if (still_peelable_loop) {
-      try_peel(loop);
+      std::tie(status, still_peelable_loop) = try_peel(still_peelable_loop);
+      if (status == Pass::Status::Failure) {
+        return Pass::Status::Failure;
+      }
     }
   }
 
-  return modified;
+  return modified ? Pass::Status::SuccessWithChange
+                  : Pass::Status::SuccessWithoutChange;
 }
 
-std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
-                                                    CodeMetrics* loop_size) {
+std::tuple<Pass::Status, Loop*> LoopPeelingPass::ProcessLoop(
+    Loop* loop, CodeMetrics* loop_size) {
   ScalarEvolutionAnalysis* scev_analysis =
       context()->GetScalarEvolutionAnalysis();
   // Default values for bailing out.
-  std::pair<bool, Loop*> bail_out{false, nullptr};
+  std::tuple<Pass::Status, Loop*> bail_out{Pass::Status::SuccessWithoutChange,
+                                           nullptr};
 
   BasicBlock* exit_block = loop->FindConditionBlock();
   if (!exit_block) {
@@ -744,7 +826,9 @@ std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
   Loop* extra_opportunity = nullptr;
 
   if (direction == PeelDirection::kBefore) {
-    peeler.PeelBefore(factor);
+    if (!peeler.PeelBefore(factor)) {
+      return {Pass::Status::Failure, nullptr};
+    }
     if (stats_) {
       stats_->peeled_loops_.emplace_back(loop, PeelDirection::kBefore, factor);
     }
@@ -753,7 +837,9 @@ std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
       extra_opportunity = peeler.GetOriginalLoop();
     }
   } else {
-    peeler.PeelAfter(factor);
+    if (!peeler.PeelAfter(factor)) {
+      return {Pass::Status::Failure, nullptr};
+    }
     if (stats_) {
       stats_->peeled_loops_.emplace_back(loop, PeelDirection::kAfter, factor);
     }
@@ -763,7 +849,7 @@ std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
     }
   }
 
-  return {true, extra_opportunity};
+  return {Pass::Status::SuccessWithChange, extra_opportunity};
 }
 
 uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstLoopInvariantOperand(

+ 15 - 11
3rdparty/spirv-tools/source/opt/loop_peeling.h

@@ -148,11 +148,11 @@ class LoopPeeling {
 
   // Moves the execution of the |factor| first iterations of the loop into a
   // dedicated loop.
-  void PeelBefore(uint32_t factor);
+  bool PeelBefore(uint32_t factor);
 
   // Moves the execution of the |factor| last iterations of the loop into a
   // dedicated loop.
-  void PeelAfter(uint32_t factor);
+  bool PeelAfter(uint32_t factor);
 
   // Returns the cloned loop.
   Loop* GetClonedLoop() { return cloned_loop_; }
@@ -184,19 +184,19 @@ class LoopPeeling {
   // Duplicate |loop_| and place the new loop before the cloned loop. Iterating
   // values from the cloned loop are then connected to the original loop as
   // initializer.
-  void DuplicateAndConnectLoop(LoopUtils::LoopCloningResult* clone_results);
+  bool DuplicateAndConnectLoop(LoopUtils::LoopCloningResult* clone_results);
 
   // Insert the canonical induction variable into the first loop as a simplified
-  // counter.
-  void InsertCanonicalInductionVariable(
+  // counter. Returns true on success.
+  bool InsertCanonicalInductionVariable(
       LoopUtils::LoopCloningResult* clone_results);
 
   // Fixes the exit condition of the before loop. The function calls
   // |condition_builder| to get the condition to use in the conditional branch
   // of the loop exit. The loop will be exited if the condition evaluate to
   // true. |condition_builder| takes an Instruction* that represent the
-  // insertion point.
-  void FixExitCondition(
+  // insertion point. Returns true on success.
+  bool FixExitCondition(
       const std::function<uint32_t(Instruction*)>& condition_builder);
 
   // Gathers all operations involved in the update of |iterator| into
@@ -321,10 +321,14 @@ class LoopPeelingPass : public Pass {
     ScalarEvolutionAnalysis* scev_analysis_;
     size_t loop_max_iterations_;
   };
-  // Peel profitable loops in |f|.
-  bool ProcessFunction(Function* f);
-  // Peel |loop| if profitable.
-  std::pair<bool, Loop*> ProcessLoop(Loop* loop, CodeMetrics* loop_size);
+  // Peel profitable loops in |f|. Returns Pass::Status::Failure if an error
+  // occurs.
+  Pass::Status ProcessFunction(Function* f);
+  // Peel |loop| if profitable. Returns Pass::Status::Failure if an error
+  // occurs. Returns {Pass::Status::SuccessWithChange, Loop*} if the loop is
+  // peeled and there is another peeling opportunity.
+  std::tuple<Pass::Status, Loop*> ProcessLoop(Loop* loop,
+                                              CodeMetrics* loop_size);
 
   static size_t code_grow_threshold_;
   LoopPeelingStats* stats_;

+ 42 - 14
3rdparty/spirv-tools/source/opt/loop_unswitch_pass.cpp

@@ -92,12 +92,16 @@ class LoopUnswitch {
   // position |ip|. This function preserves the def/use and instr to block
   // managers.
   BasicBlock* CreateBasicBlock(Function::iterator ip) {
+    uint32_t new_label_id = TakeNextId();
+    if (new_label_id == 0) {
+      return nullptr;
+    }
+
     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
 
-    // TODO(1841): Handle id overflow.
     BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
         new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
-            context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})))));
+            context_, spv::Op::OpLabel, 0, new_label_id, {})))));
     bb->SetParent(function_);
     def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
     context_->set_instr_block(bb->GetLabelInst(), bb);
@@ -135,7 +139,7 @@ class LoopUnswitch {
   }
 
   // Unswitches |loop_|.
-  void PerformUnswitch() {
+  bool PerformUnswitch() {
     assert(CanUnswitchLoop() &&
            "Cannot unswitch if there is not constant condition");
     assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
@@ -165,6 +169,9 @@ class LoopUnswitch {
         if_merge_block
             ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
             : nullptr;
+    if (if_merge_block && !loop_merge_block) {
+      return false;
+    }
     if (loop_merge_block) {
       // Add the instruction and update managers.
       InstructionBuilder builder(
@@ -174,17 +181,24 @@ class LoopUnswitch {
       builder.SetInsertPoint(&*loop_merge_block->begin());
       cfg.RegisterBlock(loop_merge_block);
       def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
-      // Update CFG.
+      bool ok = true;
       if_merge_block->ForEachPhiInst(
-          [loop_merge_block, &builder, this](Instruction* phi) {
+          [loop_merge_block, &ok, &builder, this](Instruction* phi) -> bool {
             Instruction* cloned = phi->Clone(context_);
-            cloned->SetResultId(TakeNextId());
+            uint32_t new_id = TakeNextId();
+            if (new_id == 0) {
+              ok = false;
+              return false;
+            }
+            cloned->SetResultId(new_id);
             builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
             phi->SetInOperand(0, {cloned->result_id()});
             phi->SetInOperand(1, {loop_merge_block->id()});
             for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
               phi->RemoveInOperand(j);
+            return true;
           });
+      if (!ok) return false;
       // Copy the predecessor list (will get invalidated otherwise).
       std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
       for (uint32_t pid : preds) {
@@ -227,6 +241,9 @@ class LoopUnswitch {
     // we need to create a dedicated block for the if.
     BasicBlock* loop_pre_header =
         CreateBasicBlock(++FindBasicBlockPosition(if_block));
+    if (!loop_pre_header) {
+      return false;
+    }
     InstructionBuilder(
         context_, loop_pre_header,
         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
@@ -308,6 +325,12 @@ class LoopUnswitch {
       // specific value.
       original_loop_constant_value =
           GetValueForDefaultPathForSwitch(iv_condition);
+      if (!original_loop_constant_value) {
+        return false;
+      }
 
       for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
         constant_branch.emplace_back(
@@ -341,6 +364,9 @@ class LoopUnswitch {
 
       Loop* cloned_loop =
           loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
+      if (!cloned_loop) {
+        return false;
+      }
       specialisation_pair.second = cloned_loop->GetPreHeaderBlock();
 
       ////////////////////////////////////
@@ -416,6 +442,7 @@ class LoopUnswitch {
 
     context_->InvalidateAnalysesExceptFor(
         IRContext::Analysis::kAnalysisLoopAnalysis);
+    return true;
   }
 
  private:
@@ -434,10 +461,7 @@ class LoopUnswitch {
   std::vector<BasicBlock*> ordered_loop_blocks_;
 
   // Returns the next usable id for the context.
-  uint32_t TakeNextId() {
-    // TODO(1841): Handle id overflow.
-    return context_->TakeNextId();
-  }
+  uint32_t TakeNextId() { return context_->TakeNextId(); }
 
   // Simplifies |loop| assuming the instruction |to_version_insn| takes the
   // value |cst_value|. |block_range| is an iterator range returning the loop
@@ -573,13 +597,15 @@ Pass::Status LoopUnswitchPass::Process() {
 
   // Process each function in the module
   for (Function& f : *module) {
-    modified |= ProcessFunction(&f);
+    Pass::Status status = ProcessFunction(&f);
+    if (status == Status::Failure) return Status::Failure;
+    if (status == Status::SuccessWithChange) modified = true;
   }
 
   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
-bool LoopUnswitchPass::ProcessFunction(Function* f) {
+Pass::Status LoopUnswitchPass::ProcessFunction(Function* f) {
   bool modified = false;
   std::unordered_set<Loop*> processed_loop;
 
@@ -599,15 +625,17 @@ bool LoopUnswitchPass::ProcessFunction(Function* f) {
         if (!loop.IsLCSSA()) {
           LoopUtils(context(), &loop).MakeLoopClosedSSA();
         }
+        if (!unswitcher.PerformUnswitch()) {
+          return Status::Failure;
+        }
         modified = true;
         loop_changed = true;
-        unswitcher.PerformUnswitch();
       }
       if (loop_changed) break;
     }
   }
 
-  return modified;
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
 }  // namespace opt

+ 2 - 1
3rdparty/spirv-tools/source/opt/loop_unswitch_pass.h

@@ -34,7 +34,8 @@ class LoopUnswitchPass : public Pass {
   Pass::Status Process() override;
 
  private:
-  bool ProcessFunction(Function* f);
+  // Process the given function.
+  Pass::Status ProcessFunction(Function* f);
 };
 
 }  // namespace opt

+ 29 - 15
3rdparty/spirv-tools/source/opt/loop_utils.cpp

@@ -488,12 +488,18 @@ Loop* LoopUtils::CloneLoop(LoopCloningResult* cloning_result) const {
 
 Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) {
   // Clone the loop.
-  Loop* new_loop = CloneLoop(cloning_result);
+  Loop* cloned_loop = CloneLoop(cloning_result);
+  if (!cloned_loop) {
+    return nullptr;
+  }
 
   // Create a new exit block/label for the new loop.
-  // TODO(1841): Handle id overflow.
-  std::unique_ptr<Instruction> new_label{new Instruction(
-      context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})};
+  uint32_t new_label_id = context_->TakeNextId();
+  if (new_label_id == 0) {
+    return nullptr;
+  }
+  std::unique_ptr<Instruction> new_label{
+      new Instruction(context_, spv::Op::OpLabel, 0, new_label_id, {})};
   std::unique_ptr<BasicBlock> new_exit_bb{new BasicBlock(std::move(new_label))};
   new_exit_bb->SetParent(loop_->GetMergeBlock()->GetParent());
 
@@ -520,7 +526,7 @@ Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) {
   }
 
   const uint32_t old_header = loop_->GetHeaderBlock()->id();
-  const uint32_t new_header = new_loop->GetHeaderBlock()->id();
+  const uint32_t new_header = cloned_loop->GetHeaderBlock()->id();
   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
 
   def_use->ForEachUse(old_header,
@@ -529,22 +535,24 @@ Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) {
                           inst->SetOperand(operand, {new_header});
                       });
 
-  // TODO(1841): Handle failure to create pre-header.
+  BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
+  if (!pre_header) {
+    return nullptr;
+  }
   def_use->ForEachUse(
-      loop_->GetOrCreatePreHeaderBlock()->id(),
+      pre_header->id(),
       [new_merge_block, this](Instruction* inst, uint32_t operand) {
         if (this->loop_->IsInsideLoop(inst))
           inst->SetOperand(operand, {new_merge_block});
-
       });
-  new_loop->SetMergeBlock(new_exit_bb.get());
+  cloned_loop->SetMergeBlock(new_exit_bb.get());
 
-  new_loop->SetPreHeaderBlock(loop_->GetPreHeaderBlock());
+  cloned_loop->SetPreHeaderBlock(loop_->GetPreHeaderBlock());
 
   // Add the new block into the cloned instructions.
   cloning_result->cloned_bb_.push_back(std::move(new_exit_bb));
 
-  return new_loop;
+  return cloned_loop;
 }
 
 Loop* LoopUtils::CloneLoop(
@@ -562,8 +570,11 @@ Loop* LoopUtils::CloneLoop(
     // between old and new ids.
     BasicBlock* new_bb = old_bb->Clone(context_);
     new_bb->SetParent(&function_);
-    // TODO(1841): Handle id overflow.
-    new_bb->GetLabelInst()->SetResultId(context_->TakeNextId());
+    uint32_t new_label_id = context_->TakeNextId();
+    if (new_label_id == 0) {
+      return nullptr;
+    }
+    new_bb->GetLabelInst()->SetResultId(new_label_id);
     def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst());
     context_->set_instr_block(new_bb->GetLabelInst(), new_bb);
     cloning_result->cloned_bb_.emplace_back(new_bb);
@@ -578,8 +589,11 @@ Loop* LoopUtils::CloneLoop(
          new_inst != new_bb->end(); ++new_inst, ++old_inst) {
       cloning_result->ptr_map_[&*new_inst] = &*old_inst;
       if (new_inst->HasResultId()) {
-        // TODO(1841): Handle id overflow.
-        new_inst->SetResultId(context_->TakeNextId());
+        uint32_t new_result_id = context_->TakeNextId();
+        if (new_result_id == 0) {
+          return nullptr;
+        }
+        new_inst->SetResultId(new_result_id);
         cloning_result->value_map_[old_inst->result_id()] =
             new_inst->result_id();
 

+ 2 - 0
3rdparty/spirv-tools/source/opt/loop_utils.h

@@ -114,6 +114,7 @@ class LoopUtils {
   // The function preserves the def/use, cfg and instr to block analyses.
   // The cloned loop nest will be added to the loop descriptor and will have
   // ownership.
+  // Returns the cloned loop, or nullptr if the loop could not be cloned.
   Loop* CloneLoop(LoopCloningResult* cloning_result,
                   const std::vector<BasicBlock*>& ordered_loop_blocks) const;
   // Clone |loop_| and remap its instructions, as above. Overload to compute
@@ -121,6 +122,7 @@ class LoopUtils {
   Loop* CloneLoop(LoopCloningResult* cloning_result) const;
 
   // Clone the |loop_| and make the new loop branch to the second loop on exit.
+  // Returns the cloned loop, or nullptr if the loop could not be cloned.
   Loop* CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result);
 
   // Perform a partial unroll of |loop| by given |factor|. This will copy the

+ 44 - 10
3rdparty/spirv-tools/source/opt/merge_return_pass.cpp

@@ -58,7 +58,9 @@ Pass::Status MergeReturnPass::Process() {
         failed = true;
       }
     } else {
-      MergeReturnBlocks(function, return_blocks);
+      if (!MergeReturnBlocks(function, return_blocks)) {
+        failed = true;
+      }
     }
     return true;
   };
@@ -171,10 +173,14 @@ bool MergeReturnPass::ProcessStructured(
   return true;
 }
 
-void MergeReturnPass::CreateReturnBlock() {
+bool MergeReturnPass::CreateReturnBlock() {
   // Create a label for the new return block
+  uint32_t label_id = TakeNextId();
+  if (label_id == 0) {
+    return false;
+  }
   std::unique_ptr<Instruction> return_label(
-      new Instruction(context(), spv::Op::OpLabel, 0u, TakeNextId(), {}));
+      new Instruction(context(), spv::Op::OpLabel, 0u, label_id, {}));
 
   // Create the new basic block
   std::unique_ptr<BasicBlock> return_block(
@@ -186,14 +192,18 @@ void MergeReturnPass::CreateReturnBlock() {
                              final_return_block_);
   assert(final_return_block_->GetParent() == function_ &&
          "The function should have been set when the block was created.");
+  return true;
 }
 
-void MergeReturnPass::CreateReturn(BasicBlock* block) {
+bool MergeReturnPass::CreateReturn(BasicBlock* block) {
   AddReturnValue();
 
   if (return_value_) {
     // Load and return the final return value
     uint32_t loadId = TakeNextId();
+    if (loadId == 0) {
+      return false;
+    }
     block->AddInstruction(MakeUnique<Instruction>(
         context(), spv::Op::OpLoad, function_->type_id(), loadId,
         std::initializer_list<Operand>{
@@ -216,6 +226,7 @@ void MergeReturnPass::CreateReturn(BasicBlock* block) {
     context()->AnalyzeDefUse(block->terminator());
     context()->set_instr_block(block->terminator(), block);
   }
+  return true;
 }
 
 void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
@@ -663,14 +674,16 @@ std::vector<BasicBlock*> MergeReturnPass::CollectReturnBlocks(
   return return_blocks;
 }
 
-void MergeReturnPass::MergeReturnBlocks(
+bool MergeReturnPass::MergeReturnBlocks(
     Function* function, const std::vector<BasicBlock*>& return_blocks) {
   if (return_blocks.size() <= 1) {
     // No work to do.
-    return;
+    return true;
   }
 
-  CreateReturnBlock();
+  if (!CreateReturnBlock()) {
+    return false;
+  }
   uint32_t return_id = final_return_block_->id();
   auto ret_block_iter = --function->end();
   // Create the PHI for the merged block (if necessary).
@@ -687,6 +700,9 @@ void MergeReturnPass::MergeReturnBlocks(
   if (!phi_ops.empty()) {
     // Need a PHI node to select the correct return value.
     uint32_t phi_result_id = TakeNextId();
+    if (phi_result_id == 0) {
+      return false;
+    }
     uint32_t phi_type_id = function->type_id();
     std::unique_ptr<Instruction> phi_inst(new Instruction(
         context(), spv::Op::OpPhi, phi_type_id, phi_result_id, phi_ops));
@@ -718,6 +734,7 @@ void MergeReturnPass::MergeReturnBlocks(
   }
 
   get_def_use_mgr()->AnalyzeInstDefUse(ret_block_iter->GetLabelInst());
+  return true;
 }
 
 void MergeReturnPass::AddNewPhiNodes() {
@@ -781,8 +798,12 @@ void MergeReturnPass::InsertAfterElement(BasicBlock* element,
 }
 
 bool MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
-  CreateReturnBlock();
-  CreateReturn(final_return_block_);
+  if (!CreateReturnBlock()) {
+    return false;
+  }
+  if (!CreateReturn(final_return_block_)) {
+    return false;
+  }
 
   if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) {
     cfg()->RegisterBlock(final_return_block_);
@@ -828,7 +849,8 @@ BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) {
 
 bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
   // Insert the switch before any code is run.  We have to split the entry
-  // block to make sure the OpVariable instructions remain in the entry block.
+  // block to make sure the OpVariable instructions and DebugFunctionDefinition
+  // instructions remain in the entry block.
   BasicBlock* start_block = &*function_->begin();
   auto split_pos = start_block->begin();
   while (split_pos->opcode() == spv::Op::OpVariable) {
@@ -838,6 +860,18 @@ bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
   BasicBlock* old_block =
       start_block->SplitBasicBlock(context(), TakeNextId(), split_pos);
 
+  // Find DebugFunctionDefinition inst in the old block, and if we can find it,
+  // move it to the entry block. Since DebugFunctionDefinition is not necessary
+  // after OpVariable inst, we have to traverse the whole block to find it.
+  for (auto pos = old_block->begin(); pos != old_block->end(); ++pos) {
+    if (pos->GetShader100DebugOpcode() ==
+        NonSemanticShaderDebugInfo100DebugFunctionDefinition) {
+      start_block->AddInstruction(MakeUnique<Instruction>(*pos));
+      pos.Erase();
+      break;
+    }
+  }
+
   // Add the switch to the end of the entry block.
   InstructionBuilder builder(
       context(), start_block,

+ 7 - 5
3rdparty/spirv-tools/source/opt/merge_return_pass.h

@@ -149,8 +149,9 @@ class MergeReturnPass : public MemPass {
 
   // Creates a new basic block with a single return. If |function| returns a
   // value, a phi node is created to select the correct value to return.
-  // Replaces old returns with an unconditional branch to the new block.
-  void MergeReturnBlocks(Function* function,
+  // Replaces old returns with an unconditional branch to the new block. Returns
+  // true if successful.
+  bool MergeReturnBlocks(Function* function,
                          const std::vector<BasicBlock*>& returnBlocks);
 
   // Generate and push new control flow state if |block| contains a merge.
@@ -231,11 +232,12 @@ class MergeReturnPass : public MemPass {
 
   // Add an |OpReturn| or |OpReturnValue| to the end of |block|.  If an
   // |OpReturnValue| is needed, the return value is loaded from |return_value_|.
-  void CreateReturn(BasicBlock* block);
+  // Returns true if successful.
+  bool CreateReturn(BasicBlock* block);
 
   // Creates a block at the end of the function that will become the single
   // return block at the end of the pass.
-  void CreateReturnBlock();
+  bool CreateReturnBlock();
 
   // Creates a Phi node in |merge_block| for the result of |inst|.
   // Any uses of the result of |inst| that are no longer
@@ -332,4 +334,4 @@ class MergeReturnPass : public MemPass {
 }  // namespace opt
 }  // namespace spvtools
 
-#endif  // SOURCE_OPT_MERGE_RETURN_PASS_H_
+#endif  // SOURCE_OPT_MERGE_RETURN_PASS_H_

+ 7 - 0
3rdparty/spirv-tools/source/opt/optimizer.cpp

@@ -641,6 +641,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
     RegisterPass(CreateSplitCombinedImageSamplerPass());
   } else if (pass_name == "resolve-binding-conflicts") {
     RegisterPass(CreateResolveBindingConflictsPass());
+  } else if (pass_name == "canonicalize-ids") {
+    RegisterPass(CreateCanonicalizeIdsPass());
   } else {
     Errorf(consumer(), nullptr, {},
            "Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -1202,6 +1204,11 @@ Optimizer::PassToken CreateResolveBindingConflictsPass() {
       MakeUnique<opt::ResolveBindingConflictsPass>());
 }
 
+Optimizer::PassToken CreateCanonicalizeIdsPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::CanonicalizeIdsPass>());
+}
+
 }  // namespace spvtools
 
 extern "C" {

+ 1 - 0
3rdparty/spirv-tools/source/opt/passes.h

@@ -21,6 +21,7 @@
 #include "source/opt/amd_ext_to_khr.h"
 #include "source/opt/analyze_live_input_pass.h"
 #include "source/opt/block_merge_pass.h"
+#include "source/opt/canonicalize_ids_pass.h"
 #include "source/opt/ccp_pass.h"
 #include "source/opt/cfg_cleanup_pass.h"
 #include "source/opt/code_sink.h"

+ 52 - 5
3rdparty/spirv-tools/source/opt/remove_duplicates_pass.cpp

@@ -29,6 +29,7 @@ namespace opt {
 
 Pass::Status RemoveDuplicatesPass::Process() {
   bool modified = RemoveDuplicateCapabilities();
+  modified |= RemoveDuplicateExtensions();
   modified |= RemoveDuplicatesExtInstImports();
   modified |= RemoveDuplicateTypes();
   modified |= RemoveDuplicateDecorations();
@@ -36,6 +37,41 @@ Pass::Status RemoveDuplicatesPass::Process() {
   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
+bool RemoveDuplicatesPass::RemoveDuplicateExtensions() const {
+  bool modified = false;
+
+  if (context()->extensions().empty()) {
+    return modified;
+  }
+
+  // set of {condition ID, extension name}
+  // ID 0 means unconditional extension, ie., OpExtension, otherwise the ID is
+  // the condition operand of OpConditionalExtensionINTEL.
+  std::set<std::pair<uint32_t, std::string>> extensions;
+  for (auto* inst = &*context()->extension_begin(); inst;) {
+    uint32_t cond_id = 0;
+    uint32_t i_name = 0;
+    if (inst->opcode() == spv::Op::OpConditionalExtensionINTEL) {
+      cond_id = inst->GetOperand(0).AsId();
+      i_name = 1;
+    }
+
+    auto res =
+        extensions.insert({cond_id, inst->GetOperand(i_name).AsString()});
+
+    if (res.second) {
+      // Never seen before, keep it.
+      inst = inst->NextNode();
+    } else {
+      // It's a duplicate, remove it.
+      inst = context()->KillInst(inst);
+      modified = true;
+    }
+  }
+
+  return modified;
+}
+
 bool RemoveDuplicatesPass::RemoveDuplicateCapabilities() const {
   bool modified = false;
 
@@ -43,16 +79,27 @@ bool RemoveDuplicatesPass::RemoveDuplicateCapabilities() const {
     return modified;
   }
 
-  std::unordered_set<uint32_t> capabilities;
-  for (auto* i = &*context()->capability_begin(); i;) {
-    auto res = capabilities.insert(i->GetSingleWordOperand(0u));
+  // set of {condition ID, capability}
+  // ID 0 means unconditional capability, ie., OpCapability, otherwise the ID is
+  // the condition operand of OpConditionalCapabilityINTEL.
+  std::set<std::pair<uint32_t, uint32_t>> capabilities;
+  for (auto* inst = &*context()->capability_begin(); inst;) {
+    uint32_t cond_id = 0;
+    uint32_t i_cap = 0;
+    if (inst->opcode() == spv::Op::OpConditionalCapabilityINTEL) {
+      cond_id = inst->GetOperand(0).AsId();
+      i_cap = 1;
+    }
+
+    auto res =
+        capabilities.insert({cond_id, inst->GetSingleWordOperand(i_cap)});
 
     if (res.second) {
       // Never seen before, keep it.
-      i = i->NextNode();
+      inst = inst->NextNode();
     } else {
       // It's a duplicate, remove it.
-      i = context()->KillInst(i);
+      inst = context()->KillInst(inst);
       modified = true;
     }
   }

+ 4 - 0
3rdparty/spirv-tools/source/opt/remove_duplicates_pass.h

@@ -37,6 +37,10 @@ class RemoveDuplicatesPass : public Pass {
   Status Process() override;
 
  private:
+  // Remove duplicate extensions from the module
+  //
+  // Returns true if the module was modified, false otherwise.
+  bool RemoveDuplicateExtensions() const;
   // Remove duplicate capabilities from the module
   //
   // Returns true if the module was modified, false otherwise.

+ 7 - 2
3rdparty/spirv-tools/source/opt/remove_unused_interface_variables_pass.cpp

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "remove_unused_interface_variables_pass.h"
+
 #include "source/spirv_constant.h"
 namespace spvtools {
 namespace opt {
@@ -55,7 +56,9 @@ class RemoveUnusedInterfaceVariablesContext {
 
   void CollectUsedVariables() {
     std::queue<uint32_t> roots;
-    roots.push(entry_.GetSingleWordInOperand(1));
+    const int op_i =
+        entry_.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 2 : 1;
+    roots.push(entry_.GetSingleWordInOperand(op_i));
     parent_.context()->ProcessCallTreeFromRoots(pfn_, &roots);
   }
 
@@ -73,7 +76,9 @@ class RemoveUnusedInterfaceVariablesContext {
   }
 
   void Modify() {
-    for (int i = entry_.NumInOperands() - 1; i >= 3; --i)
+    const int min_num_operands =
+        entry_.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 4 : 3;
+    for (int i = entry_.NumInOperands() - 1; i >= min_num_operands; --i)
       entry_.RemoveInOperand(i);
     for (auto id : operands_to_add_) {
       entry_.AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));

+ 11 - 4
3rdparty/spirv-tools/source/opt/scalar_replacement_pass.cpp

@@ -186,7 +186,7 @@ bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
     Instruction* added_dbg_value =
         context()->get_debug_info_mgr()->AddDebugValueForDecl(
             dbg_decl, /*value_id=*/var->result_id(),
-            /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
+            /*insert_before=*/insert_before, /*line=*/dbg_decl);
 
     if (added_dbg_value == nullptr) return false;
     added_dbg_value->AddOperand(
@@ -475,6 +475,7 @@ void ScalarReplacementPass::CreateVariable(
 
   if (id == 0) {
     replacements->push_back(nullptr);
+    return;
   }
 
   std::unique_ptr<Instruction> variable(
@@ -488,7 +489,10 @@ void ScalarReplacementPass::CreateVariable(
   Instruction* inst = &*block->begin();
 
   // If varInst was initialized, make sure to initialize its replacement.
-  GetOrCreateInitialValue(var_inst, index, inst);
+  if (!GetOrCreateInitialValue(var_inst, index, inst)) {
+    replacements->push_back(nullptr);
+    return;
+  }
   get_def_use_mgr()->AnalyzeInstDefUse(inst);
   context()->set_instr_block(inst, block);
 
@@ -509,11 +513,11 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
   return ptr_type_id;
 }
 
-void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
+bool ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
                                                     uint32_t index,
                                                     Instruction* newVar) {
   assert(source->opcode() == spv::Op::OpVariable);
-  if (source->NumInOperands() < 2) return;
+  if (source->NumInOperands() < 2) return true;
 
   uint32_t initId = source->GetSingleWordInOperand(1u);
   uint32_t storageId = GetStorageType(newVar)->result_id();
@@ -525,6 +529,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
     auto iter = type_to_null_.find(storageId);
     if (iter == type_to_null_.end()) {
       newInitId = TakeNextId();
+      if (newInitId == 0) return false;
       type_to_null_[storageId] = newInitId;
       context()->AddGlobalValue(
           MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId,
@@ -537,6 +542,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
   } else if (IsSpecConstantInst(init->opcode())) {
     // Create a new constant extract.
     newInitId = TakeNextId();
+    if (newInitId == 0) return false;
     context()->AddGlobalValue(MakeUnique<Instruction>(
         context(), spv::Op::OpSpecConstantOp, storageId, newInitId,
         std::initializer_list<Operand>{
@@ -561,6 +567,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
   if (newInitId != 0) {
     newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
   }
+  return true;
 }
 
 uint64_t ScalarReplacementPass::GetArrayLength(

+ 3 - 1
3rdparty/spirv-tools/source/opt/scalar_replacement_pass.h

@@ -199,7 +199,9 @@ class ScalarReplacementPass : public MemPass {
   // If there is an initial value for |source| for element |index|, it is
   // appended as an operand on |newVar|. If the initial value is OpUndef, no
   // initial value is added to |newVar|.
-  void GetOrCreateInitialValue(Instruction* source, uint32_t index,
+  //
+  // Returns true if the value was successfully created.
+  bool GetOrCreateInitialValue(Instruction* source, uint32_t index,
                                Instruction* newVar);
 
   // Replaces the load to the entire composite.

+ 20 - 2
3rdparty/spirv-tools/source/opt/split_combined_image_sampler_pass.cpp

@@ -556,10 +556,14 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
       Instruction* sampler;
     };
     std::vector<Replacement> replacements;
+    bool error = false;
 
     Function::RewriteParamFn rewriter =
         [&](std::unique_ptr<Instruction>&& param,
             std::back_insert_iterator<Function::ParamList>& appender) {
+          if (error) {
+            return;
+          }
           if (combined_types_.count(param->type_id()) == 0) {
             appender = std::move(param);
             return;
@@ -569,12 +573,22 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
           auto* combined_inst = param.release();
           auto* combined_type = def_use_mgr_->GetDef(combined_inst->type_id());
           auto [image_type, sampler_type] = SplitType(*combined_type);
+          uint32_t image_param_id = context()->TakeNextId();
+          if (image_param_id == 0) {
+            error = true;
+            return;
+          }
           auto image_param = MakeUnique<Instruction>(
               context(), spv::Op::OpFunctionParameter, image_type->result_id(),
-              context()->TakeNextId(), Instruction::OperandList{});
+              image_param_id, Instruction::OperandList{});
+          uint32_t sampler_param_id = context()->TakeNextId();
+          if (sampler_param_id == 0) {
+            error = true;
+            return;
+          }
           auto sampler_param = MakeUnique<Instruction>(
               context(), spv::Op::OpFunctionParameter,
-              sampler_type->result_id(), context()->TakeNextId(),
+              sampler_type->result_id(), sampler_param_id,
               Instruction::OperandList{});
           replacements.push_back(
               {combined_inst, image_param.get(), sampler_param.get()});
@@ -583,6 +597,10 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
         };
     fn.RewriteParams(rewriter);
 
+    if (error) {
+      return SPV_ERROR_INTERNAL;
+    }
+
     for (auto& r : replacements) {
       modified_ = true;
       def_use_mgr_->AnalyzeInstDefUse(r.image);

+ 9 - 6
3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.cpp

@@ -87,13 +87,15 @@ std::string SSARewriter::PhiCandidate::PrettyPrint(const CFG* cfg) const {
   return str.str();
 }
 
-SSARewriter::PhiCandidate& SSARewriter::CreatePhiCandidate(uint32_t var_id,
+SSARewriter::PhiCandidate* SSARewriter::CreatePhiCandidate(uint32_t var_id,
                                                            BasicBlock* bb) {
-  // TODO(1841): Handle id overflow.
   uint32_t phi_result_id = pass_->context()->TakeNextId();
+  if (phi_result_id == 0) {
+    return nullptr;
+  }
   auto result = phi_candidates_.emplace(
       phi_result_id, PhiCandidate(var_id, phi_result_id, bb));
-  PhiCandidate& phi_candidate = result.first->second;
+  PhiCandidate* phi_candidate = &result.first->second;
   return phi_candidate;
 }
 
@@ -268,11 +270,12 @@ uint32_t SSARewriter::GetReachingDef(uint32_t var_id, BasicBlock* bb) {
     // If there is more than one predecessor, this is a join block which may
     // require a Phi instruction.  This will act as |var_id|'s current
     // definition to break potential cycles.
-    PhiCandidate& phi_candidate = CreatePhiCandidate(var_id, bb);
+    PhiCandidate* phi_candidate = CreatePhiCandidate(var_id, bb);
+    if (phi_candidate == nullptr) return 0;
 
     // Set the value for |bb| to avoid an infinite recursion.
-    WriteVariable(var_id, bb, phi_candidate.result_id());
-    val_id = AddPhiOperands(&phi_candidate);
+    WriteVariable(var_id, bb, phi_candidate->result_id());
+    val_id = AddPhiOperands(phi_candidate);
   }
 
   // If we could not find a store for this variable in the path from the root

+ 1 - 1
3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.h

@@ -232,7 +232,7 @@ class SSARewriter {
   // during rewriting.
   //
   // Once the candidate Phi is created, it returns its ID.
-  PhiCandidate& CreatePhiCandidate(uint32_t var_id, BasicBlock* bb);
+  PhiCandidate* CreatePhiCandidate(uint32_t var_id, BasicBlock* bb);
 
   // Attempts to remove a trivial Phi candidate |phi_cand|. Trivial Phis are
   // those that only reference themselves and one other value |val| any number

+ 14 - 10
3rdparty/spirv-tools/source/opt/strength_reduction_pass.cpp

@@ -53,17 +53,15 @@ bool IsPowerOf2(uint32_t val) {
 
 Pass::Status StrengthReductionPass::Process() {
   // Initialize the member variables on a per module basis.
-  bool modified = false;
   int32_type_id_ = 0;
   uint32_type_id_ = 0;
   std::memset(constant_ids_, 0, sizeof(constant_ids_));
 
   FindIntTypesAndConstants();
-  modified = ScanFunctions();
-  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
+  return ScanFunctions();
 }
 
-bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
+Pass::Status StrengthReductionPass::ReplaceMultiplyByPowerOf2(
     BasicBlock::iterator* inst) {
   assert((*inst)->opcode() == spv::Op::OpIMul &&
          "Only works for multiplication of integers.");
@@ -72,7 +70,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
   // Currently only works on 32-bit integers.
   if ((*inst)->type_id() != int32_type_id_ &&
       (*inst)->type_id() != uint32_type_id_) {
-    return modified;
+    return Status::SuccessWithoutChange;
   }
 
   // Check the operands for a constant that is a power of 2.
@@ -87,9 +85,11 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
         modified = true;
         uint32_t shiftAmount = CountTrailingZeros(constVal);
         uint32_t shiftConstResultId = GetConstantId(shiftAmount);
+        if (shiftConstResultId == 0) return Status::Failure;
 
         // Create the new instruction.
         uint32_t newResultId = TakeNextId();
+        if (newResultId == 0) return Status::Failure;
         std::vector<Operand> newOperands;
         newOperands.push_back((*inst)->GetInOperand(1 - i));
         Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
@@ -117,7 +117,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
     }
   }
 
-  return modified;
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
 void StrengthReductionPass::FindIntTypesAndConstants() {
@@ -152,6 +152,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
 
     // Construct the constant.
     uint32_t resultId = TakeNextId();
+    if (resultId == 0) return 0;
     Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
                      {val});
     std::unique_ptr<Instruction> newConstant(new Instruction(
@@ -169,7 +170,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
   return constant_ids_[val];
 }
 
-bool StrengthReductionPass::ScanFunctions() {
+Pass::Status StrengthReductionPass::ScanFunctions() {
   // I did not use |ForEachInst| in the module because the function that acts on
   // the instruction gets a pointer to the instruction.  We cannot use that to
   // insert a new instruction.  I want an iterator.
@@ -178,16 +179,19 @@ bool StrengthReductionPass::ScanFunctions() {
     for (auto& bb : func) {
       for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
         switch (inst->opcode()) {
-          case spv::Op::OpIMul:
-            if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
+          case spv::Op::OpIMul: {
+            Status s = ReplaceMultiplyByPowerOf2(&inst);
+            if (s == Status::Failure) return Status::Failure;
+            if (s == Status::SuccessWithChange) modified = true;
             break;
+          }
           default:
             break;
         }
       }
     }
   }
-  return modified;
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
 }  // namespace opt

+ 2 - 2
3rdparty/spirv-tools/source/opt/strength_reduction_pass.h

@@ -32,7 +32,7 @@ class StrengthReductionPass : public Pass {
  private:
   // Replaces multiple by power of 2 with an equivalent bit shift.
   // Returns true if something changed.
-  bool ReplaceMultiplyByPowerOf2(BasicBlock::iterator*);
+  Status ReplaceMultiplyByPowerOf2(BasicBlock::iterator*);
 
   // Scan the types and constants in the module looking for the integer
   // types that we are
@@ -47,7 +47,7 @@ class StrengthReductionPass : public Pass {
 
   // Replaces certain instructions in function bodies with presumably cheaper
   // ones. Returns true if something changed.
-  bool ScanFunctions();
+  Status ScanFunctions();
 
   // Type ids for the types of interest, or 0 if they do not exist.
   uint32_t int32_type_id_;

+ 19 - 17
3rdparty/spirv-tools/source/opt/trim_capabilities_pass.cpp

@@ -427,20 +427,20 @@ Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
 // Opcode of interest to determine capabilities requirements.
 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 14> kOpcodeHandlers{{
     // clang-format off
-    {spv::Op::OpImageRead,         Handler_OpImageRead_StorageImageReadWithoutFormat},
-    {spv::Op::OpImageWrite,        Handler_OpImageWrite_StorageImageWriteWithoutFormat},
-    {spv::Op::OpImageSparseRead,   Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
-    {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float16 },
-    {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float64 },
-    {spv::Op::OpTypeImage,         Handler_OpTypeImage_ImageMSArray},
-    {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int16 },
-    {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int64 },
-    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageInputOutput16},
-    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StoragePushConstant16},
-    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
-    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
-    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniformBufferBlock16},
-    {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageBuffer16BitAccess},
+    {spv::Op::OpImageRead,                   Handler_OpImageRead_StorageImageReadWithoutFormat},
+    {spv::Op::OpImageWrite,                  Handler_OpImageWrite_StorageImageWriteWithoutFormat},
+    {spv::Op::OpImageSparseRead,             Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
+    {spv::Op::OpTypeFloat,                   Handler_OpTypeFloat_Float16 },
+    {spv::Op::OpTypeFloat,                   Handler_OpTypeFloat_Float64 },
+    {spv::Op::OpTypeImage,                   Handler_OpTypeImage_ImageMSArray},
+    {spv::Op::OpTypeInt,                     Handler_OpTypeInt_Int16 },
+    {spv::Op::OpTypeInt,                     Handler_OpTypeInt_Int64 },
+    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageInputOutput16},
+    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StoragePushConstant16},
+    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageUniform16},
+    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageUniform16},
+    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageUniformBufferBlock16},
+    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageBuffer16BitAccess},
     // clang-format on
 }};
 
@@ -612,7 +612,9 @@ void TrimCapabilitiesPass::addInstructionRequirements(
     ExtensionSet* extensions) const {
   // Ignoring OpCapability and OpExtension instructions.
   if (instruction->opcode() == spv::Op::OpCapability ||
-      instruction->opcode() == spv::Op::OpExtension) {
+      instruction->opcode() == spv::Op::OpConditionalCapabilityINTEL ||
+      instruction->opcode() == spv::Op::OpExtension ||
+      instruction->opcode() == spv::Op::OpConditionalExtensionINTEL) {
     return;
   }
 
@@ -631,7 +633,7 @@ void TrimCapabilitiesPass::addInstructionRequirements(
   }
 
   // Last case: some complex logic needs to be run to determine capabilities.
-  auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
+  auto [begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
   for (auto it = begin; it != end; it++) {
     const OpcodeHandler handler = it->second;
     auto result = handler(instruction);
@@ -754,7 +756,7 @@ Pass::Status TrimCapabilitiesPass::Process() {
     return Status::SuccessWithoutChange;
   }
 
-  auto[required_capabilities, required_extensions] =
+  auto [required_capabilities, required_extensions] =
       DetermineRequiredCapabilitiesAndExtensions();
 
   Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);

+ 1 - 0
3rdparty/spirv-tools/source/opt/trim_capabilities_pass.h

@@ -82,6 +82,7 @@ class TrimCapabilitiesPass : public Pass {
       spv::Capability::FragmentShaderPixelInterlockEXT,
       spv::Capability::FragmentShaderSampleInterlockEXT,
       spv::Capability::FragmentShaderShadingRateInterlockEXT,
+      spv::Capability::Geometry,
       spv::Capability::GroupNonUniform,
       spv::Capability::GroupNonUniformArithmetic,
       spv::Capability::GroupNonUniformClustered,

+ 90 - 1
3rdparty/spirv-tools/source/opt/type_manager.cpp

@@ -495,6 +495,49 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
               {SPV_OPERAND_TYPE_ID, {coop_vec->components()}}});
       break;
     }
+    case Type::kTensorARM: {
+      auto tensor_type = type->AsTensorARM();
+      uint32_t const element_type =
+          GetTypeInstruction(tensor_type->element_type());
+      if (element_type == 0) {
+        return 0;
+      }
+      if (tensor_type->rank_id() != 0) {
+        if (tensor_type->shape_id() != 0) {
+          typeInst = MakeUnique<Instruction>(
+              context(), spv::Op::OpTypeTensorARM, 0, id,
+              std::initializer_list<Operand>{
+                  {SPV_OPERAND_TYPE_ID, {element_type}},
+                  {SPV_OPERAND_TYPE_ID, {tensor_type->rank_id()}},
+                  {SPV_OPERAND_TYPE_ID, {tensor_type->shape_id()}}});
+        } else {
+          typeInst = MakeUnique<Instruction>(
+              context(), spv::Op::OpTypeTensorARM, 0, id,
+              std::initializer_list<Operand>{
+                  {SPV_OPERAND_TYPE_ID, {element_type}},
+                  {SPV_OPERAND_TYPE_ID, {tensor_type->rank_id()}}});
+        }
+      } else {
+        typeInst =
+            MakeUnique<Instruction>(context(), spv::Op::OpTypeTensorARM, 0, id,
+                                    std::initializer_list<Operand>{
+                                        {SPV_OPERAND_TYPE_ID, {element_type}}});
+      }
+      break;
+    }
+    case Type::kGraphARM: {
+      auto const gty = type->AsGraphARM();
+      std::vector<Operand> ops;
+      ops.push_back(
+          Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {gty->num_inputs()}));
+      for (auto iotype : gty->io_types()) {
+        uint32_t iotype_id = GetTypeInstruction(iotype);
+        ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {iotype_id}));
+      }
+      typeInst = MakeUnique<Instruction>(context(), spv::Op::OpTypeGraphARM, 0,
+                                         id, ops);
+      break;
+    }
     default:
       assert(false && "Unexpected type");
       break;
@@ -754,6 +797,23 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
           cv_type->components());
       break;
     }
+    case Type::kTensorARM: {
+      const TensorARM* tensor_type = type.AsTensorARM();
+      const Type* element_type = tensor_type->element_type();
+      rebuilt_ty = MakeUnique<TensorARM>(
+          RebuildType(GetId(element_type), *element_type),
+          tensor_type->rank_id(), tensor_type->shape_id());
+      break;
+    }
+    case Type::kGraphARM: {
+      const GraphARM* graph_type = type.AsGraphARM();
+      std::vector<const Type*> io_types;
+      for (auto ioty : graph_type->io_types()) {
+        io_types.push_back(RebuildType(GetId(ioty), *ioty));
+      }
+      rebuilt_ty = MakeUnique<GraphARM>(graph_type->num_inputs(), io_types);
+      break;
+    }
     default:
       assert(false && "Unhandled type");
       return nullptr;
@@ -1036,6 +1096,31 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
                               inst.GetSingleWordInOperand(1), perm);
       break;
     }
+    case spv::Op::OpTypeTensorARM: {
+      switch (inst.NumInOperands()) {
+        case 1:
+          type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)));
+          break;
+        case 2:
+          type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)),
+                               inst.GetSingleWordInOperand(1));
+          break;
+        case 3:
+          type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)),
+                               inst.GetSingleWordInOperand(1),
+                               inst.GetSingleWordInOperand(2));
+          break;
+      }
+      break;
+    }
+    case spv::Op::OpTypeGraphARM: {
+      std::vector<const Type*> io_types;
+      for (unsigned i = 1; i < inst.NumInOperands(); i++) {
+        io_types.push_back(GetType(inst.GetSingleWordInOperand(i)));
+      }
+      type = new GraphARM(inst.GetSingleWordInOperand(0), io_types);
+      break;
+    }
     default:
       assert(false && "Type not handled by the type manager.");
       break;
@@ -1067,7 +1152,11 @@ void TypeManager::AttachDecoration(const Instruction& inst, Type* type) {
       const auto count = inst.NumOperands();
       std::vector<uint32_t> data;
       for (uint32_t i = 1; i < count; ++i) {
-        data.push_back(inst.GetSingleWordOperand(i));
+        // LinkageAttributes has a literal string as an operand, which is a
+        // variable-length sequence of words. We cannot assume that all
+        // operands are a single word.
+        const Operand::OperandData& words = inst.GetOperand(i).words;
+        data.insert(data.end(), words.begin(), words.end());
       }
       type->AddDecoration(std::move(data));
     } break;

+ 91 - 0
3rdparty/spirv-tools/source/opt/types.cpp

@@ -135,6 +135,8 @@ std::unique_ptr<Type> Type::Clone() const {
     DeclareKindCase(CooperativeVectorNV);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
+    DeclareKindCase(TensorARM);
+    DeclareKindCase(GraphARM);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -187,6 +189,8 @@ bool Type::operator==(const Type& other) const {
     DeclareKindCase(HitObjectNV);
     DeclareKindCase(TensorLayoutNV);
     DeclareKindCase(TensorViewNV);
+    DeclareKindCase(TensorARM);
+    DeclareKindCase(GraphARM);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -247,6 +251,8 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
     DeclareKindCase(HitObjectNV);
     DeclareKindCase(TensorLayoutNV);
     DeclareKindCase(TensorViewNV);
+    DeclareKindCase(TensorARM);
+    DeclareKindCase(GraphARM);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -899,6 +905,91 @@ bool CooperativeVectorNV::IsSameImpl(const Type* that,
          components_ == mt->components_ && HasSameDecorations(that);
 }
 
+TensorARM::TensorARM(const Type* elty, const uint32_t rank,
+                     const uint32_t shape)
+    : Type(kTensorARM), element_type_(elty), rank_id_(rank), shape_id_(shape) {
+  assert(elty != nullptr);
+  if (shape != 0) {
+    assert(rank != 0);
+  }
+}
+
+std::string TensorARM::str() const {
+  std::ostringstream oss;
+  oss << "tensor<" << element_type_->str() << ", id(" << rank_id_ << "), id("
+      << shape_id_ << ")>";
+  return oss.str();
+}
+
+size_t TensorARM::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+  hash = hash_combine(hash, rank_id_);
+  hash = hash_combine(hash, shape_id_);
+  return element_type_->ComputeHashValue(hash, seen);
+}
+
+bool TensorARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
+  const TensorARM* tt = that->AsTensorARM();
+  if (!tt) return false;
+  return element_type_->IsSameImpl(tt->element_type_, seen) &&
+         rank_id_ == tt->rank_id_ && shape_id_ == tt->shape_id_ &&
+         HasSameDecorations(that);
+}
+
+GraphARM::GraphARM(const uint32_t num_inputs,
+                   const std::vector<const Type*>& io_types)
+    : Type(kGraphARM), num_inputs_(num_inputs), io_types_(io_types) {
+  assert(io_types.size() > 0);
+}
+
+std::string GraphARM::str() const {
+  std::ostringstream oss;
+  oss << "graph<" << num_inputs_;
+  for (auto ioty : io_types_) {
+    oss << "," << ioty->str();
+  }
+  oss << ">";
+  return oss.str();
+}
+
+bool GraphARM::is_shaped() const {
+  // A graph is considered to be shaped if all its interface tensors are shaped
+  for (auto ioty : io_types_) {
+    auto tensor_type = ioty->AsTensorARM();
+    assert(tensor_type);
+    if (!tensor_type->is_shaped()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+size_t GraphARM::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+  hash = hash_combine(hash, num_inputs_);
+  for (auto ioty : io_types_) {
+    hash = ioty->ComputeHashValue(hash, seen);
+  }
+  return hash;
+}
+
+bool GraphARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
+  const GraphARM* og = that->AsGraphARM();
+  if (!og) {
+    return false;
+  }
+  if (num_inputs_ != og->num_inputs_) {
+    return false;
+  }
+  if (io_types_.size() != og->io_types_.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < io_types_.size(); i++) {
+    if (!io_types_[i]->IsSameImpl(og->io_types_[i], seen)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools

+ 56 - 0
3rdparty/spirv-tools/source/opt/types.h

@@ -69,6 +69,8 @@ class RayQueryKHR;
 class HitObjectNV;
 class TensorLayoutNV;
 class TensorViewNV;
+class TensorARM;
+class GraphARM;
 
 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
 // which is used as a way to probe the actual <subclass>.
@@ -114,6 +116,8 @@ class Type {
     kHitObjectNV,
     kTensorLayoutNV,
     kTensorViewNV,
+    kTensorARM,
+    kGraphARM,
     kLast
   };
 
@@ -220,6 +224,8 @@ class Type {
   DeclareCastMethod(HitObjectNV)
   DeclareCastMethod(TensorLayoutNV)
   DeclareCastMethod(TensorViewNV)
+  DeclareCastMethod(TensorARM)
+  DeclareCastMethod(GraphARM)
 #undef DeclareCastMethod
 
 protected:
@@ -774,6 +780,56 @@ class CooperativeVectorNV : public Type {
   const uint32_t components_;
 };
 
+class TensorARM : public Type {
+ public:
+  TensorARM(const Type* elty, const uint32_t rank = 0,
+            const uint32_t shape = 0);
+  TensorARM(const TensorARM&) = default;
+
+  std::string str() const override;
+
+  TensorARM* AsTensorARM() override { return this; }
+  const TensorARM* AsTensorARM() const override { return this; }
+
+  size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+  const Type* element_type() const { return element_type_; }
+  uint32_t rank_id() const { return rank_id_; }
+  uint32_t shape_id() const { return shape_id_; }
+  bool is_ranked() const { return rank_id_ != 0; }
+  bool is_shaped() const { return shape_id_ != 0; }
+
+ private:
+  bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+  const Type* element_type_;
+  const uint32_t rank_id_;
+  const uint32_t shape_id_;
+};
+
+class GraphARM : public Type {
+ public:
+  GraphARM(const uint32_t num_inputs, const std::vector<const Type*>& io_types);
+  GraphARM(const GraphARM&) = default;
+
+  std::string str() const override;
+
+  GraphARM* AsGraphARM() override { return this; }
+  const GraphARM* AsGraphARM() const override { return this; }
+
+  uint32_t num_inputs() const { return num_inputs_; }
+  const std::vector<const Type*>& io_types() const { return io_types_; }
+  bool is_shaped() const;
+
+  size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+ private:
+  bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+  const uint32_t num_inputs_;
+  const std::vector<const Type*> io_types_;
+};
+
 #define DefineParameterlessType(type, name)                                \
   class type : public Type {                                               \
    public:                                                                 \

+ 58 - 12
3rdparty/spirv-tools/source/opt/upgrade_memory_model.cpp

@@ -160,14 +160,38 @@ void UpgradeMemoryModel::UpgradeMemoryAndImages() {
       }
 
       switch (inst->opcode()) {
-        case spv::Op::OpLoad:
+        case spv::Op::OpLoad: {
+          Instruction* src_pointer = context()->get_def_use_mgr()->GetDef(
+              inst->GetSingleWordInOperand(0u));
+          analysis::Type* src_type =
+              context()->get_type_mgr()->GetType(src_pointer->type_id());
+          auto storage_class = src_type->AsPointer()->storage_class();
+          if (storage_class == spv::StorageClass::Function ||
+              storage_class == spv::StorageClass::Private) {
+            // If the pointer is to a Function or Private variable, the
+            // NonPrivatePointer flag is unnecessary.
+            is_coherent = false;
+          }
           UpgradeFlags(inst, 1u, is_coherent, is_volatile, kVisibility,
                        kMemory);
           break;
-        case spv::Op::OpStore:
+        }
+        case spv::Op::OpStore: {
+          Instruction* src_pointer = context()->get_def_use_mgr()->GetDef(
+              inst->GetSingleWordInOperand(0u));
+          analysis::Type* src_type =
+              context()->get_type_mgr()->GetType(src_pointer->type_id());
+          auto storage_class = src_type->AsPointer()->storage_class();
+          if (storage_class == spv::StorageClass::Function ||
+              storage_class == spv::StorageClass::Private) {
+            // If the pointer is to a Function or Private variable, the
+            // NonPrivatePointer flag is unnecessary.
+            is_coherent = false;
+          }
           UpgradeFlags(inst, 2u, is_coherent, is_volatile, kAvailability,
                        kMemory);
           break;
+        }
         case spv::Op::OpCopyMemory:
         case spv::Op::OpCopyMemorySized:
           start_operand = inst->opcode() == spv::Op::OpCopyMemory ? 2u : 3u;
@@ -366,6 +390,21 @@ std::pair<bool, bool> UpgradeMemoryModel::TraceInstruction(
         indices.push_back(inst->GetSingleWordInOperand(i));
       }
       break;
+    case spv::Op::OpLoad:
+      if (context()->get_type_mgr()->GetType(inst->type_id())->AsPointer()) {
+        analysis::Integer int_ty(32, false);
+        uint32_t int_id =
+            context()->get_type_mgr()->GetTypeInstruction(&int_ty);
+        const analysis::Constant* constant =
+            context()->get_constant_mgr()->GetConstant(
+                context()->get_type_mgr()->GetType(int_id), {0u});
+        uint32_t constant_id = context()
+                                   ->get_constant_mgr()
+                                   ->GetDefiningInstruction(constant)
+                                   ->result_id();
+
+        indices.push_back(constant_id);
+      }
     default:
       break;
   }
@@ -661,22 +700,29 @@ void UpgradeMemoryModel::UpgradeBarriers() {
       roots.push(e.GetSingleWordInOperand(1u));
       if (context()->ProcessCallTreeFromRoots(CollectBarriers, &roots)) {
         for (auto barrier : barriers) {
-          // Add OutputMemoryKHR to the semantics of the barriers.
+          // Add OutputMemoryKHR to the semantics of the non-relaxed barriers.
           uint32_t semantics_id = barrier->GetSingleWordInOperand(2u);
           Instruction* semantics_inst =
               context()->get_def_use_mgr()->GetDef(semantics_id);
           analysis::Type* semantics_type =
               context()->get_type_mgr()->GetType(semantics_inst->type_id());
           uint64_t semantics_value = GetIndexValue(semantics_inst);
-          const analysis::Constant* constant =
-              context()->get_constant_mgr()->GetConstant(
-                  semantics_type,
-                  {static_cast<uint32_t>(semantics_value) |
-                   uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR)});
-          barrier->SetInOperand(2u, {context()
-                                         ->get_constant_mgr()
-                                         ->GetDefiningInstruction(constant)
-                                         ->result_id()});
+          const uint64_t memory_order_mask =
+              uint64_t(spv::MemorySemanticsMask::Acquire |
+                       spv::MemorySemanticsMask::Release |
+                       spv::MemorySemanticsMask::AcquireRelease |
+                       spv::MemorySemanticsMask::SequentiallyConsistent);
+          if (semantics_value & memory_order_mask) {
+            const analysis::Constant* constant =
+                context()->get_constant_mgr()->GetConstant(
+                    semantics_type,
+                    {static_cast<uint32_t>(semantics_value) |
+                     uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR)});
+            barrier->SetInOperand(2u, {context()
+                                           ->get_constant_mgr()
+                                           ->GetDefiningInstruction(constant)
+                                           ->result_id()});
+          }
         }
       }
       barriers.clear();

+ 4 - 1
3rdparty/spirv-tools/source/parsed_operand.cpp

@@ -59,7 +59,10 @@ void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst,
             *out << spvtools::utils::FloatProxy<spvtools::utils::Float8_E5M2>(
                 uint8_t(word & 0xFF));
             break;
-          // TODO Bfloat16
+          case SPV_FP_ENCODING_BFLOAT16:
+            *out << spvtools::utils::FloatProxy<spvtools::utils::BFloat16>(
+                uint16_t(word & 0xFFFF));
+            break;
           case SPV_FP_ENCODING_UNKNOWN:
             switch (operand.number_bit_width) {
               case 16:

+ 1 - 1
3rdparty/spirv-tools/source/text_handler.cpp

@@ -336,7 +336,7 @@ spv_result_t AssemblyContext::recordTypeDefinition(
       return diagnostic() << "Invalid OpTypeFloat instruction";
     spv_fp_encoding_t enc = SPV_FP_ENCODING_UNKNOWN;
     if (pInst->words.size() >= 4) {
-      const spvtools::OperandDesc* desc;
+      const spvtools::OperandDesc* desc = nullptr;
       spv_result_t status = spvtools::LookupOperand(SPV_OPERAND_TYPE_FPENCODING,
                                                     pInst->words[3], &desc);
       if (status == SPV_SUCCESS) {

+ 90 - 0
3rdparty/spirv-tools/source/util/hex_float.h

@@ -103,6 +103,34 @@ class Float16 {
   uint16_t val;
 };
 
+class BFloat16 {
+ public:
+  BFloat16(uint16_t v) : val(v) {}
+  BFloat16() = default;
+  BFloat16(const BFloat16& other) { val = other.val; }
+
+  // Exponent mask: 0x7F80, Mantissa mask: 0x007F
+  static bool isNan(const BFloat16& val) {
+    return ((val.val & 0x7F80) == 0x7F80) && ((val.val & 0x007F) != 0);
+  }
+  static bool isInfinity(const BFloat16& val) {
+    return ((val.val & 0x7F80) == 0x7F80) && ((val.val & 0x007F) == 0);
+  }
+
+  uint16_t get_value() const { return val; }
+
+  // Sign bit 0, maximum finite exponent (0xFE), all-ones mantissa: 0x7F7F.
+  static BFloat16 max() { return BFloat16(0x7F7F); }
+  // Sign bit 1, maximum finite exponent (0xFE), all-ones mantissa: 0xFF7F.
+  static BFloat16 lowest() { return BFloat16(0xFF7F); }
+
+ private:
+  // 15: Sign
+  // 14-7: Exponent
+  // 6-0: Mantissa
+  uint16_t val;
+};
+
 // To specialize this type, you must override uint_type to define
 // an unsigned integer that can fit your floating point type.
 // You must also add a isNan function that returns true if
@@ -212,6 +240,24 @@ struct FloatProxyTraits<Float16> {
   static uint32_t width() { return 16u; }
 };
 
+template <>
+struct FloatProxyTraits<BFloat16> {
+  using uint_type = uint16_t;
+  static bool isNan(BFloat16 f) { return BFloat16::isNan(f); }
+  // Returns true if the given value is any kind of infinity.
+  static bool isInfinity(BFloat16 f) { return BFloat16::isInfinity(f); }
+  // Returns the maximum normal value.
+  static BFloat16 max() { return BFloat16::max(); }
+  // Returns the lowest normal value.
+  static BFloat16 lowest() { return BFloat16::lowest(); }
+  // Returns the value as the native floating point format.
+  static BFloat16 getAsFloat(const uint_type& t) { return BFloat16(t); }
+  // Returns the bits from the given floating pointer number.
+  static uint_type getBitsFromFloat(const BFloat16& t) { return t.get_value(); }
+  // Returns the bitwidth.
+  static uint32_t width() { return 16u; }
+};
+
 // Since copying a floating point number (especially if it is NaN)
 // does not guarantee that bits are preserved, this class lets us
 // store the type and use it as a float when necessary.
@@ -403,6 +449,23 @@ struct HexFloatTraits<FloatProxy<Float16>> {
   static const uint_type NaN_pattern = 0x7c00;
 };
 
+// Traits for BFloat16.
+// 1 sign bit, 7 exponent bits, 8 fractional bits.
+template <>
+struct HexFloatTraits<FloatProxy<BFloat16>> {
+  using uint_type = uint16_t;
+  using int_type = int16_t;
+  using underlying_type = FloatProxy<BFloat16>;
+  using underlying_typetraits = FloatProxyTraits<BFloat16>;
+  using native_type = uint16_t;
+  static const uint_type num_used_bits = 16;
+  static const uint_type num_exponent_bits = 8;
+  static const uint_type num_fraction_bits = 7;
+  static const uint_type exponent_bias = 127;
+  static const bool has_infinity = true;
+  static const uint_type NaN_pattern = 0x7F80;
+};
+
 enum class round_direction {
   kToZero,
   kToNearestEven,
@@ -1038,6 +1101,26 @@ ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
   }
   return is;
 }
+
+// Same flow as Float16
+template <>
+inline std::istream&
+ParseNormalFloat<FloatProxy<BFloat16>, HexFloatTraits<FloatProxy<BFloat16>>>(
+    std::istream& is, bool negate_value,
+    HexFloat<FloatProxy<BFloat16>, HexFloatTraits<FloatProxy<BFloat16>>>&
+        value) {
+  HexFloat<FloatProxy<float>> float_val(0.0f);
+  ParseNormalFloat(is, negate_value, float_val);
+
+  float_val.castTo(value, round_direction::kToZero);
+
+  if (BFloat16::isInfinity(value.value().getAsFloat())) {
+    value.set_value(value.isNegative() ? BFloat16::lowest() : BFloat16::max());
+    is.setstate(std::ios_base::failbit);
+  }
+  return is;
+}
+
 // Specialization of ParseNormalFloat for FloatProxy<Float8_E4M3> values.
 // This will parse the float as it were a 32-bit floating point number,
 // and then round it down to fit into a Float8_E4M3 value.
@@ -1468,6 +1551,13 @@ inline std::ostream& operator<<<Float16>(std::ostream& os,
   return os;
 }
 
+template <>
+inline std::ostream& operator<< <BFloat16>(std::ostream& os,
+                                           const FloatProxy<BFloat16>& value) {
+  os << HexFloat<FloatProxy<BFloat16>>(value);
+  return os;
+}
+
 template <>
 inline std::ostream& operator<< <Float8_E4M3>(
     std::ostream& os, const FloatProxy<Float8_E4M3>& value) {

+ 9 - 1
3rdparty/spirv-tools/source/util/parse_number.cpp

@@ -185,7 +185,15 @@ EncodeNumberStatus ParseAndEncodeFloatingPointNumber(
       emit(static_cast<uint32_t>(hVal.value().getAsFloat().get_value()));
       return EncodeNumberStatus::kSuccess;
     } break;
-    case SPV_FP_ENCODING_BFLOAT16:  // FIXME this likely needs separate handling
+    case SPV_FP_ENCODING_BFLOAT16: {
+      HexFloat<FloatProxy<BFloat16>> hVal(0);
+      if (!ParseNumber(text, &hVal)) {
+        ErrorMsgStream(error_msg) << "Invalid bfloat16 literal: " << text;
+        return EncodeNumberStatus::kInvalidText;
+      }
+      emit(static_cast<uint32_t>(hVal.value().getAsFloat().get_value()));
+      return EncodeNumberStatus::kSuccess;
+    } break;
     case SPV_FP_ENCODING_IEEE754_BINARY16: {
       HexFloat<FloatProxy<Float16>> hVal(0);
       if (!ParseNumber(text, &hVal)) {

+ 59 - 23
3rdparty/spirv-tools/source/val/validate.cpp

@@ -64,9 +64,12 @@ void RegisterExtension(ValidationState_t& _,
 spv_result_t ProcessExtensions(void* user_data,
                                const spv_parsed_instruction_t* inst) {
   const spv::Op opcode = static_cast<spv::Op>(inst->opcode);
-  if (opcode == spv::Op::OpCapability) return SPV_SUCCESS;
+  if (opcode == spv::Op::OpCapability ||
+      opcode == spv::Op::OpConditionalCapabilityINTEL)
+    return SPV_SUCCESS;
 
-  if (opcode == spv::Op::OpExtension) {
+  if (opcode == spv::Op::OpExtension ||
+      opcode == spv::Op::OpConditionalExtensionINTEL) {
     ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
     RegisterExtension(_, inst);
     return SPV_SUCCESS;
@@ -115,10 +118,11 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
   _.ComputeFunctionToEntryPointMapping();
   _.ComputeRecursiveEntryPoints();
 
-  if (_.entry_points().empty() && !_.HasCapability(spv::Capability::Linkage)) {
+  if (_.entry_points().empty() && !_.HasCapability(spv::Capability::Linkage) &&
+      !_.HasCapability(spv::Capability::GraphARM)) {
     return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
            << "No OpEntryPoint instruction was found. This is only allowed if "
-              "the Linkage capability is being used.";
+              "the Linkage or GraphARM capability is being used.";
   }
 
   for (const auto& entry_point : _.entry_points()) {
@@ -151,6 +155,16 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateGraphEntryPoints(ValidationState_t& _) {
+  if (_.graph_entry_points().empty() &&
+      _.HasCapability(spv::Capability::GraphARM)) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
+           << "No OpGraphEntryPointARM instruction was found but the GraphARM "
+              "capability is declared.";
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t ValidateBinaryUsingContextAndValidationState(
     const spv_context_t& context, const uint32_t* words, const size_t num_words,
     spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
@@ -217,43 +231,59 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
       // able to, briefly, de-const the instruction.
       Instruction* inst = const_cast<Instruction*>(&instruction);
 
-      if (inst->opcode() == spv::Op::OpEntryPoint) {
-        const auto entry_point = inst->GetOperandAs<uint32_t>(1);
-        const auto execution_model = inst->GetOperandAs<spv::ExecutionModel>(0);
-        const std::string desc_name = inst->GetOperandAs<std::string>(2);
+      if ((inst->opcode() == spv::Op::OpEntryPoint) ||
+          (inst->opcode() == spv::Op::OpConditionalEntryPointINTEL)) {
+        const int i_model = inst->opcode() == spv::Op::OpEntryPoint ? 0 : 1;
+        const int i_point = inst->opcode() == spv::Op::OpEntryPoint ? 1 : 2;
+        const int i_name = inst->opcode() == spv::Op::OpEntryPoint ? 2 : 3;
+        const int min_num_operands =
+            inst->opcode() == spv::Op::OpEntryPoint ? 3 : 4;
+
+        const auto entry_point = inst->GetOperandAs<uint32_t>(i_point);
+        const auto execution_model =
+            inst->GetOperandAs<spv::ExecutionModel>(i_model);
+        const std::string desc_name = inst->GetOperandAs<std::string>(i_name);
 
         ValidationState_t::EntryPointDescription desc;
         desc.name = desc_name;
 
         std::vector<uint32_t> interfaces;
-        for (size_t j = 3; j < inst->operands().size(); ++j)
+        for (size_t j = min_num_operands; j < inst->operands().size(); ++j)
           desc.interfaces.push_back(inst->word(inst->operand(j).offset));
 
         vstate->RegisterEntryPoint(entry_point, execution_model,
                                    std::move(desc));
 
-        if (visited_entry_points.size() > 0) {
-          for (const Instruction* check_inst : visited_entry_points) {
-            const auto check_execution_model =
-                check_inst->GetOperandAs<spv::ExecutionModel>(0);
-            const std::string check_name =
-                check_inst->GetOperandAs<std::string>(2);
-
-            if (desc_name == check_name &&
-                execution_model == check_execution_model) {
-              return vstate->diag(SPV_ERROR_INVALID_DATA, inst)
-                     << "2 Entry points cannot share the same name and "
-                        "ExecutionMode.";
+        if (inst->opcode() == spv::Op::OpEntryPoint) {
+          // conditional entry points are allowed to share the same name and
+          // exec mode
+          if (visited_entry_points.size() > 0) {
+            for (const Instruction* check_inst : visited_entry_points) {
+              const auto check_execution_model =
+                  check_inst->GetOperandAs<spv::ExecutionModel>(i_model);
+              const std::string check_name =
+                  check_inst->GetOperandAs<std::string>(i_name);
+
+              if (desc_name == check_name &&
+                  execution_model == check_execution_model) {
+                return vstate->diag(SPV_ERROR_INVALID_DATA, inst)
+                       << "2 Entry points cannot share the same name and "
+                          "ExecutionMode.";
+              }
             }
           }
+          visited_entry_points.push_back(inst);
         }
-        visited_entry_points.push_back(inst);
 
         has_mask_task_nv |= (execution_model == spv::ExecutionModel::TaskNV ||
                              execution_model == spv::ExecutionModel::MeshNV);
         has_mask_task_ext |= (execution_model == spv::ExecutionModel::TaskEXT ||
                               execution_model == spv::ExecutionModel::MeshEXT);
       }
+      if (inst->opcode() == spv::Op::OpGraphEntryPointARM) {
+        const auto graph = inst->GetOperandAs<uint32_t>(1);
+        vstate->RegisterGraphEntryPoint(graph);
+      }
       if (inst->opcode() == spv::Op::OpFunctionCall) {
         if (!vstate->in_function_body()) {
           return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction)
@@ -299,6 +329,10 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
     return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
            << "Missing OpFunctionEnd at end of module.";
 
+  if (vstate->graph_definition_region() != kGraphDefinitionOutside)
+    return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
+           << "Missing OpGraphEndARM at end of module.";
+
   if (vstate->HasCapability(spv::Capability::BindlessTextureNV) &&
       !vstate->has_samplerimage_variable_address_mode_specified())
     return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
@@ -314,7 +348,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
   if (auto error = ValidateForwardDecls(*vstate)) return error;
 
   // Calculate reachability after all the blocks are parsed, but early that it
-  // can be relied on in subsequent pases.
+  // can be relied on in subsequent passes.
   ReachabilityPass(*vstate);
 
   // ID usage needs be handled in its own iteration of the instructions,
@@ -368,6 +402,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
     if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
     if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
     if (auto error = TensorPass(*vstate, &instruction)) return error;
+    if (auto error = GraphPass(*vstate, &instruction)) return error;
     if (auto error = InvalidTypePass(*vstate, &instruction)) return error;
   }
 
@@ -377,6 +412,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
   if (auto error = ValidateAdjacency(*vstate)) return error;
 
   if (auto error = ValidateEntryPoints(*vstate)) return error;
+  if (auto error = ValidateGraphEntryPoints(*vstate)) return error;
   // CFG checks are performed after the binary has been parsed
   // and the CFGPass has collected information about the control flow
   if (auto error = PerformCfgChecks(*vstate)) return error;

+ 5 - 2
3rdparty/spirv-tools/source/val/validate.h

@@ -195,8 +195,8 @@ spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst);
 /// Validates correctness of debug instructions.
 spv_result_t DebugPass(ValidationState_t& _, const Instruction* inst);
 
-// Validates that capability declarations use operands allowed in the current
-// context.
+/// Validates that capability declarations use operands allowed in the current
+/// context.
 spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst);
 
 /// Validates correctness of primitive instructions.
@@ -226,6 +226,9 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
 /// Validates correctness of tensor instructions.
 spv_result_t TensorPass(ValidationState_t& _, const Instruction* inst);
 
+/// Validates correctness of graph instructions.
+spv_result_t GraphPass(ValidationState_t& _, const Instruction* inst);
+
 /// Validates correctness of certain special type instructions.
 spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst);
 

+ 23 - 2
3rdparty/spirv-tools/source/val/validate_annotation.cpp

@@ -333,6 +333,14 @@ spv_result_t ValidateDecorate(ValidationState_t& _, const Instruction* inst) {
 }
 
 spv_result_t ValidateDecorateId(ValidationState_t& _, const Instruction* inst) {
+  const auto target_id = inst->GetOperandAs<uint32_t>(0);
+  const auto target = _.FindDef(target_id);
+  if (target && spv::Op::OpDecorationGroup == target->opcode()) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpDecorateId Target <id> " << _.getIdName(target_id)
+           << " must not be an OpDecorationGroup instruction.";
+  }
+
   const auto decoration = inst->GetOperandAs<spv::Decoration>(1);
   if (!DecorationTakesIdParameters(decoration)) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -340,6 +348,20 @@ spv_result_t ValidateDecorateId(ValidationState_t& _, const Instruction* inst) {
               "OpDecorateId";
   }
 
+  for (uint32_t i = 2; i < inst->operands().size(); ++i) {
+    const auto param_id = inst->GetOperandAs<uint32_t>(i);
+    const auto param = _.FindDef(param_id);
+
+    // Both target and param are elements of ordered_instructions, so we can
+    // determine their relative positions in the SPIR-V module by comparing
+    // pointers.
+    if (target <= param) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "Parameter <ID> " << _.getIdName(param_id)
+             << " must appear earlier in the binary than the target";
+    }
+  }
+
   // No member decorations take id parameters, so we don't bother checking if
   // we are using a member only decoration here.
 
@@ -388,8 +410,7 @@ spv_result_t ValidateDecorationGroup(ValidationState_t& _,
     if (use->opcode() != spv::Op::OpDecorate &&
         use->opcode() != spv::Op::OpGroupDecorate &&
         use->opcode() != spv::Op::OpGroupMemberDecorate &&
-        use->opcode() != spv::Op::OpName &&
-        use->opcode() != spv::Op::OpDecorateId && !use->IsNonSemantic()) {
+        use->opcode() != spv::Op::OpName && !use->IsNonSemantic()) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Result id of OpDecorationGroup can only "
              << "be targeted by OpName, OpGroupDecorate, "

+ 0 - 21
3rdparty/spirv-tools/source/val/validate_atomics.cpp

@@ -388,27 +388,6 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
         if (auto error = ValidateMemorySemantics(
                 _, inst, unequal_semantics_index, memory_scope))
           return error;
-
-        // Volatile bits must match for equal and unequal semantics. Previous
-        // checks guarantee they are 32-bit constants, but we need to recheck
-        // whether they are evaluatable constants.
-        bool is_int32 = false;
-        bool is_equal_const = false;
-        bool is_unequal_const = false;
-        uint32_t equal_value = 0;
-        uint32_t unequal_value = 0;
-        std::tie(is_int32, is_equal_const, equal_value) = _.EvalInt32IfConst(
-            inst->GetOperandAs<uint32_t>(equal_semantics_index));
-        std::tie(is_int32, is_unequal_const, unequal_value) =
-            _.EvalInt32IfConst(
-                inst->GetOperandAs<uint32_t>(unequal_semantics_index));
-        if (is_equal_const && is_unequal_const &&
-            ((equal_value & uint32_t(spv::MemorySemanticsMask::Volatile)) ^
-             (unequal_value & uint32_t(spv::MemorySemanticsMask::Volatile)))) {
-          return _.diag(SPV_ERROR_INVALID_ID, inst)
-                 << "Volatile mask setting must match for Equal and Unequal "
-                    "memory semantics";
-        }
       }
 
       if (opcode == spv::Op::OpAtomicStore) {

+ 4 - 4
3rdparty/spirv-tools/source/val/validate_barriers.cpp

@@ -45,10 +45,10 @@ spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst) {
                       model != spv::ExecutionModel::MeshNV) {
                     if (message) {
                       *message =
-                          "OpControlBarrier requires one of the following "
-                          "Execution "
-                          "Models: TessellationControl, GLCompute, Kernel, "
-                          "MeshNV or TaskNV";
+                          "In SPIR-V 1.2 or earlier, OpControlBarrier requires "
+                          "one of the following "
+                          "Execution Models: TessellationControl, GLCompute, "
+                          "Kernel, MeshNV or TaskNV";
                     }
                     return false;
                   }

+ 4 - 2
3rdparty/spirv-tools/source/val/validate_bitwise.cpp

@@ -39,9 +39,11 @@ spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
     if (_.GetBitWidth(base_type) != 32 &&
         !_.options()->allow_vulkan_32_bit_bitwise) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << _.VkErrorID(4781)
+             << _.VkErrorID(10824)
              << "Expected 32-bit int type for Base operand: "
-             << spvOpcodeString(opcode);
+             << spvOpcodeString(opcode)
+             << _.MissingFeature("maintenance9 feature",
+                                 "--allow-vulkan-32-bit-bitwise", false);
     }
   }
 

+ 411 - 166
3rdparty/spirv-tools/source/val/validate_builtins.cpp

@@ -17,21 +17,22 @@
 // Validates correctness of built-in variables.
 
 #include <array>
+#include <cstdint>
 #include <functional>
 #include <list>
 #include <map>
 #include <set>
 #include <sstream>
-#include <stack>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "source/opcode.h"
 #include "source/spirv_target_env.h"
-#include "source/util/bitutils.h"
 #include "source/val/instruction.h"
 #include "source/val/validate.h"
 #include "source/val/validation_state.h"
+#include "spirv/unified1/spirv.hpp11"
 
 namespace spvtools {
 namespace val {
@@ -373,6 +374,18 @@ class BuiltInsValidator {
   spv_result_t ValidateMeshShadingEXTBuiltinsAtDefinition(
       const Decoration& decoration, const Instruction& inst);
 
+  // Used as a common method for validating MeshEXT builtins
+  spv_result_t ValidateMeshBuiltinInterfaceRules(
+      const Decoration& decoration, const Instruction& inst,
+      spv::Op scalar_type, const Instruction& referenced_from_inst);
+  spv_result_t ValidatePrimitiveShadingRateInterfaceRules(
+      const Decoration& decoration, const Instruction& inst,
+      const Instruction& referenced_from_inst);
+  // Builtin that needs a check in case it is **not** used with MeshEXT
+  spv_result_t ValidateNonMeshInterfaceRules(
+      const Decoration& decoration, const Instruction& inst,
+      const Instruction& referenced_from_inst);
+
   // The following section contains functions which are called when id defined
   // by |referenced_inst| is
   // 1. referenced by |referenced_from_inst|
@@ -590,8 +603,9 @@ class BuiltInsValidator {
   spv_result_t ValidateBool(
       const Decoration& decoration, const Instruction& inst,
       const std::function<spv_result_t(const std::string& message)>& diag);
-  spv_result_t ValidateBlockBoolOrArrayedBool(
+  spv_result_t ValidateBlockTypeOrArrayedType(
       const Decoration& decoration, const Instruction& inst,
+      bool& present_in_block, spv::Op expected_scalar_type,
       const std::function<spv_result_t(const std::string& message)>& diag);
   spv_result_t ValidateI(
       const Decoration& decoration, const Instruction& inst,
@@ -675,20 +689,50 @@ class BuiltInsValidator {
   // UniformConstant".
   std::string GetStorageClassDesc(const Instruction& inst) const;
 
+  uint64_t GetArrayLength(uint32_t interface_var_id);
+
   // Updates inner working of the class. Is called sequentially for every
   // instruction.
   void Update(const Instruction& inst);
 
-  // Check if "inst" is an interface variable
-  // or type of a interface varibale of any mesh entry point
-  bool isMeshInterfaceVar(const Instruction& inst) {
-    auto getUnderlyingTypeId = [&](const Instruction* ifxVar) {
-      auto pointerTypeInst = _.FindDef(ifxVar->type_id());
-      auto typeInst = _.FindDef(pointerTypeInst->GetOperandAs<uint32_t>(2));
-      while (typeInst->opcode() == spv::Op::OpTypeArray) {
-        typeInst = _.FindDef(typeInst->GetOperandAs<uint32_t>(1));
+  bool IsBulitinInEntryPoint(const Instruction& inst, uint32_t entry_point) {
+    auto get_underlying_type_id = [&](const Instruction* ifx_var) {
+      auto pointer_type_inst = _.FindDef(ifx_var->type_id());
+      auto type_inst = _.FindDef(pointer_type_inst->GetOperandAs<uint32_t>(2));
+      while (type_inst->opcode() == spv::Op::OpTypeArray) {
+        type_inst = _.FindDef(type_inst->GetOperandAs<uint32_t>(1));
       };
-      return typeInst->id();
+      return type_inst->id();
+    };
+
+    for (const auto& desc : _.entry_point_descriptions(entry_point)) {
+      for (auto interface : desc.interfaces) {
+        if (inst.opcode() == spv::Op::OpTypeStruct) {
+          auto varInst = _.FindDef(interface);
+          if (inst.id() == get_underlying_type_id(varInst)) {
+            return true;
+          }
+        } else if (inst.id() == interface) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  // Check if "inst" is an interface variable or type of an interface variable
+  // of any mesh entry point. Populate entry_point_interface_id with all
+  // entry points and interface variables that refer to the "inst"
+  bool IsMeshInterfaceVar(
+      const Instruction& inst,
+      std::map<uint32_t, uint32_t>& entry_point_interface_id) {
+    auto get_underlying_type_id = [&](const Instruction* ifx_var) {
+      auto pointer_type_inst = _.FindDef(ifx_var->type_id());
+      auto type_inst = _.FindDef(pointer_type_inst->GetOperandAs<uint32_t>(2));
+      while (type_inst->opcode() == spv::Op::OpTypeArray) {
+        type_inst = _.FindDef(type_inst->GetOperandAs<uint32_t>(1));
+      };
+      return type_inst->id();
     };
 
     for (const uint32_t entry_point : _.entry_points()) {
@@ -699,15 +743,19 @@ class BuiltInsValidator {
           for (auto interface : desc.interfaces) {
             if (inst.opcode() == spv::Op::OpTypeStruct) {
               auto varInst = _.FindDef(interface);
-              if (inst.id() == getUnderlyingTypeId(varInst)) return true;
+              if (inst.id() == get_underlying_type_id(varInst)) {
+                entry_point_interface_id[entry_point] = interface;
+                break;
+              }
             } else if (inst.id() == interface) {
-              return true;
+              entry_point_interface_id[entry_point] = interface;
+              break;
             }
           }
         }
       }
     }
-    return false;
+    return !entry_point_interface_id.empty();
   }
 
   ValidationState_t& _;
@@ -730,6 +778,10 @@ class BuiltInsValidator {
 
   // Execution models with which the current function can be called.
   std::set<spv::ExecutionModel> execution_models_;
+
+  // For builtins that may only be declared once per entry point, track which
+  // entry points already declare one.
+  std::set<uint32_t> cull_primitive_entry_points_;
 };
 
 void BuiltInsValidator::Update(const Instruction& inst) {
@@ -807,6 +859,29 @@ std::string BuiltInsValidator::GetStorageClassDesc(
   return ss.str();
 }
 
+uint64_t BuiltInsValidator::GetArrayLength(uint32_t interface_var_id) {
+  uint32_t underlying_type;
+  spv::StorageClass storage_class;
+  uint64_t array_len = -1;
+  const Instruction* inst = _.FindDef(interface_var_id);
+  if (inst->opcode() != spv::Op::OpVariable) {
+    return -1;
+  }
+
+  if (!_.GetPointerTypeInfo(inst->type_id(), &underlying_type,
+                            &storage_class)) {
+    return 0;
+  }
+  if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
+    // Get the array length
+    const auto length_id = _.FindDef(underlying_type)->word(3u);
+    if (!_.EvalConstantValUint64(length_id, &array_len)) {
+      return 0;
+    }
+  }
+  return array_len;
+}
+
 spv_result_t BuiltInsValidator::ValidateBool(
     const Decoration& decoration, const Instruction& inst,
     const std::function<spv_result_t(const std::string& message)>& diag) {
@@ -823,25 +898,50 @@ spv_result_t BuiltInsValidator::ValidateBool(
   return SPV_SUCCESS;
 }
 
-spv_result_t BuiltInsValidator::ValidateBlockBoolOrArrayedBool(
-    const Decoration& decoration, const Instruction& inst,
+spv_result_t BuiltInsValidator::ValidateBlockTypeOrArrayedType(
+    const Decoration& decoration, const Instruction& inst, bool& isBlock,
+    spv::Op expected_scalar_type,
     const std::function<spv_result_t(const std::string& message)>& diag) {
   uint32_t underlying_type = 0;
+  int64_t array_len = -1;
+  isBlock = true;
   if (spv_result_t error =
           GetUnderlyingType(_, decoration, inst, &underlying_type)) {
     return error;
   }
   // Strip the array, if present.
   if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
+    // Get the array length
+    const auto length_id = _.FindDef(underlying_type)->word(3u);
+    if (!_.EvalConstantValInt64(length_id, &array_len)) {
+      return diag(GetDefinitionDesc(decoration, inst) +
+                  " Failed to find the array length.");
+    }
     underlying_type = _.FindDef(underlying_type)->word(2u);
+    isBlock = false;
   } else if (!_.HasDecoration(inst.id(), spv::Decoration::Block)) {
     // If not in array, and bool is in a struct, must be in a Block struct
     return diag(GetDefinitionDesc(decoration, inst) +
                 " Scalar boolean must be in a Block.");
   }
 
-  if (!_.IsBoolScalarType(underlying_type)) {
-    return diag(GetDefinitionDesc(decoration, inst) + " is not a bool scalar.");
+  switch (expected_scalar_type) {
+    case spv::Op::OpTypeBool:
+      if (!_.IsBoolScalarType(underlying_type)) {
+        return diag(GetDefinitionDesc(decoration, inst) +
+                    " is not a bool scalar.");
+      }
+      break;
+    case spv::Op::OpTypeInt:
+      if (!_.IsIntScalarType(underlying_type)) {
+        return diag(GetDefinitionDesc(decoration, inst) +
+                    " is not an integer scalar.");
+      }
+      break;
+    default:
+      assert(0 && "Unhandled scalar type");
+      return diag(GetDefinitionDesc(decoration, inst) +
+                  " is not a recognized scalar type.");
   }
 
   return SPV_SUCCESS;
@@ -2188,49 +2288,6 @@ spv_result_t BuiltInsValidator::ValidatePositionAtReference(
 
 spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtDefinition(
     const Decoration& decoration, const Instruction& inst) {
-  if (spvIsVulkanEnv(_.context()->target_env)) {
-    // PrimitiveId can be a per-primitive variable for mesh shader stage.
-    // In such cases variable will have an array of 32-bit integers.
-    if (decoration.struct_member_index() != Decoration::kInvalidMember) {
-      // This must be a 32-bit int scalar.
-      if (spv_result_t error = ValidateI32(
-              decoration, inst,
-              [this, &inst](const std::string& message) -> spv_result_t {
-                return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                       << _.VkErrorID(4337)
-                       << "According to the Vulkan spec BuiltIn PrimitiveId "
-                          "variable needs to be a 32-bit int scalar. "
-                       << message;
-              })) {
-        return error;
-      }
-    } else {
-      if (spv_result_t error = ValidateOptionalArrayedI32(
-              decoration, inst,
-              [this, &inst](const std::string& message) -> spv_result_t {
-                return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                       << _.VkErrorID(4337)
-                       << "According to the Vulkan spec BuiltIn PrimitiveId "
-                          "variable needs to be a 32-bit int scalar. "
-                       << message;
-              })) {
-        return error;
-      }
-    }
-
-    if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
-      if (isMeshInterfaceVar(inst) &&
-          !_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
-        return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-               << _.VkErrorID(7040)
-               << "According to the Vulkan spec the variable decorated with "
-                  "Builtin PrimitiveId within the MeshEXT Execution Model must "
-                  "also be decorated with the PerPrimitiveEXT decoration. ";
-      }
-    }
-  }
-
-  // Seed at reference checks with this built-in.
   return ValidatePrimitiveIdAtReference(decoration, inst, inst, inst);
 }
 
@@ -2297,6 +2354,27 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtReference(
           referenced_from_inst, std::placeholders::_1));
     }
 
+    if (!_.HasCapability(spv::Capability::MeshShadingEXT) &&
+        !_.HasCapability(spv::Capability::MeshShadingNV) &&
+        !_.HasCapability(spv::Capability::Geometry) &&
+        !_.HasCapability(spv::Capability::Tessellation)) {
+      id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+          &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, 4333,
+          "Vulkan spec doesn't allow BuiltIn PrimitiveId to be used for "
+          "variables in the Fragment execution model unless it declares "
+          "Geometry, Tessellation, or MeshShader capabilities.",
+          spv::ExecutionModel::Fragment, decoration, built_in_inst,
+          referenced_from_inst, std::placeholders::_1));
+    }
+
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this, decoration,
+        built_in_inst, spv::Op::OpTypeInt, std::placeholders::_1));
+
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
+        std::bind(&BuiltInsValidator::ValidateNonMeshInterfaceRules, this,
+                  decoration, built_in_inst, std::placeholders::_1));
+
     for (const spv::ExecutionModel execution_model : execution_models_) {
       switch (execution_model) {
         case spv::ExecutionModel::Fragment:
@@ -2593,6 +2671,13 @@ spv_result_t BuiltInsValidator::ValidateTessLevelOuterAtDefinition(
             })) {
       return error;
     }
+
+    if (!_.HasDecoration(inst.id(), spv::Decoration::Patch)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << _.VkErrorID(10880)
+             << "BuiltIn TessLevelOuter variable needs to also have a Patch "
+                "decoration.";
+    }
   }
 
   // Seed at reference checks with this built-in.
@@ -2607,13 +2692,20 @@ spv_result_t BuiltInsValidator::ValidateTessLevelInnerAtDefinition(
             [this, &inst](const std::string& message) -> spv_result_t {
               return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                      << _.VkErrorID(4397)
-                     << "According to the Vulkan spec BuiltIn TessLevelOuter "
+                     << "According to the Vulkan spec BuiltIn TessLevelInner "
                         "variable needs to be a 2-component 32-bit float "
                         "array. "
                      << message;
             })) {
       return error;
     }
+
+    if (!_.HasDecoration(inst.id(), spv::Decoration::Patch)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+             << _.VkErrorID(10880)
+             << "BuiltIn TessLevelInner variable needs to also have a Patch "
+                "decoration.";
+    }
   }
 
   // Seed at reference checks with this built-in.
@@ -2796,67 +2888,180 @@ spv_result_t BuiltInsValidator::ValidateVertexIndexAtReference(
   return SPV_SUCCESS;
 }
 
-spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition(
-    const Decoration& decoration, const Instruction& inst) {
-  if (spvIsVulkanEnv(_.context()->target_env)) {
-    // This can be a per-primitive variable for mesh shader stage.
-    // In such cases variable will have an array of 32-bit integers.
-    if (decoration.struct_member_index() != Decoration::kInvalidMember) {
-      // This must be a 32-bit int scalar.
-      if (spv_result_t error = ValidateI32(
-              decoration, inst,
-              [this, &decoration,
-               &inst](const std::string& message) -> spv_result_t {
-                uint32_t vuid =
-                    (decoration.builtin() == spv::BuiltIn::Layer) ? 4276 : 4408;
+typedef struct {
+  uint32_t array_type;
+  uint32_t array_size;
+  uint32_t block_array_size;
+  uint32_t perprim_deco;
+} MeshBuiltinVUIDs;
+
+spv_result_t BuiltInsValidator::ValidateMeshBuiltinInterfaceRules(
+    const Decoration& decoration, const Instruction& inst, spv::Op scalar_type,
+    const Instruction& referenced_from_inst) {
+  if (function_id_) {
+    if (execution_models_.count(spv::ExecutionModel::MeshEXT)) {
+      bool is_block = false;
+      const spv::BuiltIn builtin = decoration.builtin();
+
+      static const std::unordered_map<spv::BuiltIn, MeshBuiltinVUIDs>
+          mesh_vuid_map = {{
+              {spv::BuiltIn::CullPrimitiveEXT, {7036, 10589, 10590, 7038}},
+              {spv::BuiltIn::PrimitiveId, {10595, 10596, 10597, 7040}},
+              {spv::BuiltIn::Layer, {10592, 10593, 10594, 7039}},
+              {spv::BuiltIn::ViewportIndex, {10601, 10602, 10603, 7060}},
+              {spv::BuiltIn::PrimitiveShadingRateKHR,
+               {10598, 10599, 10600, 7059}},
+          }};
+      const MeshBuiltinVUIDs& vuids = mesh_vuid_map.at(builtin);
+      if (spv_result_t error = ValidateBlockTypeOrArrayedType(
+              decoration, inst, is_block, scalar_type,
+              [this, &inst, &builtin, &scalar_type,
+               &vuids](const std::string& message) -> spv_result_t {
                 return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                       << _.VkErrorID(vuid)
-                       << "According to the Vulkan spec BuiltIn "
+                       << _.VkErrorID(vuids.array_type)
+                       << "According to the Vulkan specspec BuiltIn "
                        << _.grammar().lookupOperandName(
-                              SPV_OPERAND_TYPE_BUILT_IN,
-                              (uint32_t)decoration.builtin())
-                       << "variable needs to be a 32-bit int scalar. "
-                       << message;
+                              SPV_OPERAND_TYPE_BUILT_IN, (uint32_t)builtin)
+                       << " variable needs to be a either a "
+                       << spvOpcodeString(scalar_type)
+                       << " or an "
+                          "array of "
+                       << spvOpcodeString(scalar_type) << ". " << message;
               })) {
         return error;
       }
-    } else {
-      if (spv_result_t error = ValidateOptionalArrayedI32(
+
+      if (!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+               << _.VkErrorID(vuids.perprim_deco)
+               << "According to the Vulkan spec the variable decorated with "
+                  "Builtin "
+               << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                (uint32_t)builtin)
+               << " within the MeshEXT Execution Model must also be "
+               << "decorated with the PerPrimitiveEXT decoration. ";
+      }
+
+      // These builtins are allowed to be arrays with MeshEXT
+      // When an array, we need to make sure the array size lines up
+      std::map<uint32_t, uint32_t> entry_interface_id_map;
+      bool found = IsMeshInterfaceVar(inst, entry_interface_id_map);
+      if (found) {
+        for (const auto& id : entry_interface_id_map) {
+          uint32_t entry_point_id = id.first;
+          uint32_t interface_var_id = id.second;
+
+          const uint64_t interface_size = GetArrayLength(interface_var_id);
+          const uint32_t output_prim_size =
+              _.GetOutputPrimitivesEXT(entry_point_id);
+          if (interface_size != output_prim_size) {
+            return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                   << _.VkErrorID(is_block ? vuids.block_array_size
+                                           : vuids.array_size)
+                   << " The size of the array decorated with "
+                   << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                    (uint32_t)builtin)
+                   << " (" << interface_size
+                   << ") must match the value specified by OutputPrimitivesEXT "
+                      "("
+                   << output_prim_size << "). ";
+          }
+        }
+      }
+    }
+  } else {
+    // Propagate this rule to all dependent ids in the global scope.
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
+        std::bind(&BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this,
+                  decoration, inst, scalar_type, std::placeholders::_1));
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t BuiltInsValidator::ValidatePrimitiveShadingRateInterfaceRules(
+    const Decoration& decoration, const Instruction& inst,
+    const Instruction& referenced_from_inst) {
+  if (function_id_) {
+    if (!execution_models_.count(spv::ExecutionModel::MeshEXT)) {
+      if (spv_result_t error = ValidateI32(
               decoration, inst,
-              [this, &decoration,
-               &inst](const std::string& message) -> spv_result_t {
-                uint32_t vuid =
-                    (decoration.builtin() == spv::BuiltIn::Layer) ? 4276 : 4408;
+              [this, &inst,
+               &decoration](const std::string& message) -> spv_result_t {
                 return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                       << _.VkErrorID(vuid)
+                       << _.VkErrorID(4486)
                        << "According to the Vulkan spec BuiltIn "
                        << _.grammar().lookupOperandName(
                               SPV_OPERAND_TYPE_BUILT_IN,
                               (uint32_t)decoration.builtin())
-                       << "variable needs to be a 32-bit int scalar. "
+                       << " variable needs to be a 32-bit int scalar. "
                        << message;
               })) {
         return error;
       }
     }
+  } else {
+    // Propagate this rule to all dependent ids in the global scope.
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidatePrimitiveShadingRateInterfaceRules, this,
+        decoration, inst, std::placeholders::_1));
+  }
+  return SPV_SUCCESS;
+}
 
-    if (isMeshInterfaceVar(inst) &&
-        _.HasCapability(spv::Capability::MeshShadingEXT) &&
-        !_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
-      const spv::BuiltIn label = spv::BuiltIn(decoration.params()[0]);
-      uint32_t vkerrid = (label == spv::BuiltIn::Layer) ? 7039 : 7060;
-      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-             << _.VkErrorID(vkerrid)
-             << "According to the Vulkan spec the variable decorated with "
-                "Builtin "
-             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
-                                              decoration.params()[0])
-             << " within the MeshEXT Execution Model must also be decorated "
-                "with the PerPrimitiveEXT decoration. ";
+// For Layer, ViewportIndex, and PrimitiveId
+spv_result_t BuiltInsValidator::ValidateNonMeshInterfaceRules(
+    const Decoration& decoration, const Instruction& inst,
+    const Instruction& referenced_from_inst) {
+  if (function_id_) {
+    // This can be a per-primitive variable for NV mesh shader stage.
+    // In such cases variable will have an array of 32-bit integers.
+    if (!execution_models_.count(spv::ExecutionModel::MeshEXT)) {
+      const spv::BuiltIn builtin = decoration.builtin();
+      const uint32_t vuid = (builtin == spv::BuiltIn::Layer)           ? 4276
+                            : (builtin == spv::BuiltIn::ViewportIndex) ? 4408
+                                                                       : 4337;
+      if (decoration.struct_member_index() != Decoration::kInvalidMember) {
+        if (spv_result_t error = ValidateI32(
+                decoration, inst,
+                [this, &vuid, builtin,
+                 &inst](const std::string& message) -> spv_result_t {
+                  return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                         << _.VkErrorID(vuid)
+                         << "According to the Vulkan spec BuiltIn "
+                         << _.grammar().lookupOperandName(
+                                SPV_OPERAND_TYPE_BUILT_IN, (uint32_t)builtin)
+                         << "variable needs to be a 32-bit int scalar. "
+                         << message;
+                })) {
+          return error;
+        }
+      } else if (spv_result_t error = ValidateOptionalArrayedI32(
+                     decoration, inst,
+                     [this, &vuid, builtin,
+                      &inst](const std::string& message) -> spv_result_t {
+                       return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                              << _.VkErrorID(vuid)
+                              << "According to the Vulkan spec BuiltIn "
+                              << _.grammar().lookupOperandName(
+                                     SPV_OPERAND_TYPE_BUILT_IN,
+                                     (uint32_t)builtin)
+                              << "variable needs to be a 32-bit int scalar. "
+                              << message;
+                     })) {
+        return error;
+      }
     }
+  } else {
+    // Propagate this rule to all dependent ids in the global scope.
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
+        std::bind(&BuiltInsValidator::ValidateNonMeshInterfaceRules, this,
+                  decoration, inst, std::placeholders::_1));
   }
+  return SPV_SUCCESS;
+}
 
-  // Seed at reference checks with this built-in.
+spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
   return ValidateLayerOrViewportIndexAtReference(decoration, inst, inst, inst);
 }
 
@@ -2914,6 +3119,14 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtReference(
                     referenced_from_inst, std::placeholders::_1));
     }
 
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this, decoration,
+        built_in_inst, spv::Op::OpTypeInt, std::placeholders::_1));
+
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
+        std::bind(&BuiltInsValidator::ValidateNonMeshInterfaceRules, this,
+                  decoration, built_in_inst, std::placeholders::_1));
+
     for (const spv::ExecutionModel execution_model : execution_models_) {
       switch (execution_model) {
         case spv::ExecutionModel::Geometry:
@@ -3338,12 +3551,47 @@ spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition(
       bool static_x = _.EvalConstantValUint64(inst.word(3), &x_size);
       bool static_y = _.EvalConstantValUint64(inst.word(4), &y_size);
       bool static_z = _.EvalConstantValUint64(inst.word(5), &z_size);
-      if (static_x && static_y && static_z &&
-          ((x_size * y_size * z_size) == 0)) {
-        return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-               << "WorkgroupSize decorations must not have a static "
-                  "product of zero (X = "
-               << x_size << ", Y = " << y_size << ", Z = " << z_size << ").";
+      if (static_x && static_y && static_z) {
+        const uint64_t product_size = x_size * y_size * z_size;
+        if (product_size == 0) {
+          return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                 << "WorkgroupSize decorations must not have a static "
+                    "product of zero (X = "
+                 << x_size << ", Y = " << y_size << ", Z = " << z_size << ").";
+        }
+
+        // If there is a known static workgroup size, all entrypoints with
+        // explicit derivative execution modes can be validated. These are only
+        // found in execution models that support explicit workgroup sizes
+        for (const uint32_t entry_point : _.entry_points()) {
+          const auto* modes = _.GetExecutionModes(entry_point);
+          if (!modes) continue;
+          if (modes->count(spv::ExecutionMode::DerivativeGroupQuadsKHR)) {
+            if (x_size % 2 != 0 || y_size % 2 != 0) {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << _.VkErrorID(10151)
+                     << "WorkgroupSize decorations has a static dimensions of "
+                        "(X = "
+                     << x_size << ", Y = " << y_size << ") but Entry Point id "
+                     << entry_point
+                     << " has an DerivativeGroupQuadsKHR execution mode, so "
+                        "both dimensions must be a multiple of 2";
+            }
+          }
+          if (modes->count(spv::ExecutionMode::DerivativeGroupLinearKHR)) {
+            if (product_size % 4 != 0) {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << _.VkErrorID(10152)
+                     << "WorkgroupSize decorations has a static dimensions of "
+                        "(X = "
+                     << x_size << ", Y = " << y_size << ", Z = " << z_size
+                     << ") but Entry Point id " << entry_point
+                     << " has an DerivativeGroupLinearKHR execution mode, so "
+                        "the product ("
+                     << product_size << ") must be a multiple of 4";
+            }
+          }
+        }
       }
     }
   }
@@ -3986,34 +4234,6 @@ spv_result_t BuiltInsValidator::ValidateNVSMOrARMCoreBuiltinsAtReference(
 
 spv_result_t BuiltInsValidator::ValidatePrimitiveShadingRateAtDefinition(
     const Decoration& decoration, const Instruction& inst) {
-  if (spvIsVulkanEnv(_.context()->target_env)) {
-    if (spv_result_t error = ValidateI32(
-            decoration, inst,
-            [this, &inst,
-             &decoration](const std::string& message) -> spv_result_t {
-              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                     << _.VkErrorID(4486)
-                     << "According to the Vulkan spec BuiltIn "
-                     << _.grammar().lookupOperandName(
-                            SPV_OPERAND_TYPE_BUILT_IN,
-                            (uint32_t)decoration.builtin())
-                     << " variable needs to be a 32-bit int scalar. "
-                     << message;
-            })) {
-      return error;
-    }
-    if (isMeshInterfaceVar(inst) &&
-        _.HasCapability(spv::Capability::MeshShadingEXT) &&
-        !_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
-      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-             << _.VkErrorID(7059)
-             << "The variable decorated with PrimitiveShadingRateKHR "
-                "within the MeshEXT Execution Model must also be "
-                "decorated with the PerPrimitiveEXT decoration";
-    }
-  }
-
-  // Seed at reference checks with this built-in.
   return ValidatePrimitiveShadingRateAtReference(decoration, inst, inst, inst);
 }
 
@@ -4035,6 +4255,14 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveShadingRateAtReference(
              << " " << GetStorageClassDesc(referenced_from_inst);
     }
 
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this, decoration,
+        built_in_inst, spv::Op::OpTypeInt, std::placeholders::_1));
+
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidatePrimitiveShadingRateInterfaceRules, this,
+        decoration, built_in_inst, std::placeholders::_1));
+
     for (const spv::ExecutionModel execution_model : execution_models_) {
       switch (execution_model) {
         case spv::ExecutionModel::Vertex:
@@ -4365,48 +4593,61 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
           return error;
         }
         break;
-      case spv::BuiltIn::CullPrimitiveEXT:
-        if (spv_result_t error = ValidateBlockBoolOrArrayedBool(
-                decoration, inst,
-                [this, &inst, &decoration,
-                 &vuid](const std::string& message) -> spv_result_t {
-                  return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                         << _.VkErrorID(vuid) << "According to the "
-                         << spvLogStringForEnv(_.context()->target_env)
-                         << " spec BuiltIn "
-                         << _.grammar().lookupOperandName(
-                                SPV_OPERAND_TYPE_BUILT_IN,
-                                (uint32_t)decoration.builtin())
-                         << " variable needs to be a either a boolean or an "
-                            "array of booleans."
-                         << message;
-                })) {
+      case spv::BuiltIn::CullPrimitiveEXT: {
+        // We know this is only allowed for the Mesh Execution Model
+        if (spv_result_t error = ValidateMeshBuiltinInterfaceRules(
+                decoration, inst, spv::Op::OpTypeBool, inst)) {
           return error;
         }
-        if (!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
-          return _.diag(SPV_ERROR_INVALID_DATA, &inst)
-                 << _.VkErrorID(7038)
-                 << "The variable decorated with CullPrimitiveEXT within the "
-                    "MeshEXT Execution Model must also be decorated with the "
-                    "PerPrimitiveEXT decoration ";
+
+        for (const uint32_t entry_point : _.entry_points()) {
+          auto* models = _.GetExecutionModels(entry_point);
+          if (models->find(spv::ExecutionModel::MeshEXT) == models->end() &&
+              models->find(spv::ExecutionModel::MeshNV) == models->end()) {
+            continue;
+          }
+
+          if (IsBulitinInEntryPoint(inst, entry_point)) {
+            if (cull_primitive_entry_points_.find(entry_point) !=
+                cull_primitive_entry_points_.end()) {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << _.VkErrorID(10591)
+                     << "There must be only one declaration of the "
+                        "CullPrimitiveEXT associated in entry point's "
+                        "interface. "
+                     << GetIdDesc(*_.FindDef(entry_point));
+            } else {
+              cull_primitive_entry_points_.insert(entry_point);
+            }
+          }
         }
+
         break;
+      }
       default:
         assert(0 && "Unexpected mesh EXT builtin");
     }
     for (const uint32_t entry_point : _.entry_points()) {
+      // execution modes and builtins are both global, so only check these
+      // builtin definitions if we know the entry point is Mesh
+      auto* models = _.GetExecutionModels(entry_point);
+      if (models->find(spv::ExecutionModel::MeshEXT) == models->end() &&
+          models->find(spv::ExecutionModel::MeshNV) == models->end()) {
+        continue;
+      }
+
       const auto* modes = _.GetExecutionModes(entry_point);
-      uint64_t maxOutputPrimitives = _.GetOutputPrimitivesEXT(entry_point);
+      uint64_t max_output_primitives = _.GetOutputPrimitivesEXT(entry_point);
       uint32_t underlying_type = 0;
       if (spv_result_t error =
               GetUnderlyingType(_, decoration, inst, &underlying_type)) {
         return error;
       }
 
-      uint64_t primitiveArrayDim = 0;
+      uint64_t primitive_array_dim = 0;
       if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
         underlying_type = _.FindDef(underlying_type)->word(3u);
-        if (!_.EvalConstantValUint64(underlying_type, &primitiveArrayDim)) {
+        if (!_.EvalConstantValUint64(underlying_type, &primitive_array_dim)) {
           assert(0 && "Array type definition is corrupt");
         }
       }
@@ -4419,7 +4660,8 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
                       "with "
                       "the OutputPoints Execution Mode. ";
           }
-          if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
+          if (primitive_array_dim &&
+              primitive_array_dim != max_output_primitives) {
             return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                    << _.VkErrorID(7046)
                    << "The size of the array decorated with "
@@ -4435,7 +4677,8 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
                       "with "
                       "the OutputLinesEXT Execution Mode. ";
           }
-          if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
+          if (primitive_array_dim &&
+              primitive_array_dim != max_output_primitives) {
             return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                    << _.VkErrorID(7052)
                    << "The size of the array decorated with "
@@ -4451,7 +4694,8 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
                       "with "
                       "the OutputTrianglesEXT Execution Mode. ";
           }
-          if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
+          if (primitive_array_dim &&
+              primitive_array_dim != max_output_primitives) {
             return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                    << _.VkErrorID(7058)
                    << "The size of the array decorated with "
@@ -4692,6 +4936,7 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinitionVulkan(
     case spv::BuiltIn::CullMaskKHR: {
       return ValidateRayTracingBuiltinsAtDefinition(decoration, inst);
     }
+    // These are only for Mesh, not Task execution model
     case spv::BuiltIn::CullPrimitiveEXT:
     case spv::BuiltIn::PrimitivePointIndicesEXT:
     case spv::BuiltIn::PrimitiveLineIndicesEXT:

+ 12 - 5
3rdparty/spirv-tools/source/val/validate_capability.cpp

@@ -345,11 +345,18 @@ bool IsEnabledByCapabilityOpenCL_2_0(ValidationState_t& _,
 // Validates that capability declarations use operands allowed in the current
 // context.
 spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst) {
-  if (inst->opcode() != spv::Op::OpCapability) return SPV_SUCCESS;
-
-  assert(inst->operands().size() == 1);
-
-  const spv_parsed_operand_t& operand = inst->operand(0);
+  if (inst->opcode() != spv::Op::OpCapability &&
+      inst->opcode() != spv::Op::OpConditionalCapabilityINTEL)
+    return SPV_SUCCESS;
+
+  assert(!((inst->opcode() == spv::Op::OpCapability) ^
+           (inst->operands().size() == 1)));
+  assert(!((inst->opcode() == spv::Op::OpConditionalCapabilityINTEL) ^
+           (inst->operands().size() == 2)));
+
+  const uint32_t i_cap =
+      inst->opcode() == spv::Op::OpConditionalCapabilityINTEL ? 1 : 0;
+  const spv_parsed_operand_t& operand = inst->operand(i_cap);
 
   assert(operand.num_words == 1);
   assert(operand.offset < inst->words().size());

+ 465 - 1
3rdparty/spirv-tools/source/val/validate_composites.cpp

@@ -16,6 +16,8 @@
 
 // Validates correctness of composite SPIR-V instructions.
 
+#include <climits>
+
 #include "source/opcode.h"
 #include "source/spirv_target_env.h"
 #include "source/val/instruction.h"
@@ -618,8 +620,464 @@ spv_result_t ValidateCopyLogical(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
-}  // anonymous namespace
+spv_result_t ValidateCompositeConstructCoopMatQCOM(ValidationState_t& _,
+                                                   const Instruction* inst) {
+  // Is the result of coop mat ?
+  const auto result_type_inst = _.FindDef(inst->type_id());
+  if (!result_type_inst ||
+      result_type_inst->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the result type be OpTypeCooperativeMatrixKHR";
+  }
+
+  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
+  const auto source_type_inst = _.FindDef(source->type_id());
+
+  if (!source_type_inst || source_type_inst->opcode() != spv::Op::OpTypeArray) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the input operand be an OpTypeArray.";
+  }
+
+  // Is the scope Subgroup?
+  {
+    unsigned scope = UINT_MAX;
+    unsigned scope_id = result_type_inst->GetOperandAs<unsigned>(2u);
+    bool status = _.GetConstantValueAs<unsigned>(scope_id, scope);
+    bool is_scope_spec_const =
+        spvOpcodeIsSpecConstant(_.FindDef(scope_id)->opcode());
+    if (!is_scope_spec_const &&
+        (!status || scope != static_cast<uint64_t>(spv::Scope::Subgroup))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Opcode " << spvOpcodeString(inst->opcode())
+             << " requires the result type's scope be Subgroup.";
+    }
+  }
+
+  unsigned ar_len = UINT_MAX;
+  unsigned src_arr_len_id = source_type_inst->GetOperandAs<unsigned>(2u);
+  bool ar_len_status = _.GetConstantValueAs<unsigned>(src_arr_len_id, ar_len);
+  bool is_src_arr_len_spec_const =
+      spvOpcodeIsSpecConstant(_.FindDef(src_arr_len_id)->opcode());
+
+  const auto source_elt_type = _.GetComponentType(source_type_inst->id());
+  const auto result_elt_type = result_type_inst->GetOperandAs<uint32_t>(1u);
+
+  if ((source_elt_type != result_elt_type) &&
+      !(_.ContainsSizedIntOrFloatType(source_elt_type, spv::Op::OpTypeInt,
+                                      32) &&
+        _.IsUnsignedIntScalarType(source_elt_type))) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires ether the input element type is equal to the result "
+              "element type or it is the unsigned 32-bit integer.";
+  }
+
+  unsigned res_row_id = result_type_inst->GetOperandAs<unsigned>(3u);
+  unsigned res_col_id = result_type_inst->GetOperandAs<unsigned>(4u);
+  unsigned res_use_id = result_type_inst->GetOperandAs<unsigned>(5u);
+
+  unsigned cm_use = UINT_MAX;
+  bool cm_use_status = _.GetConstantValueAs<unsigned>(res_use_id, cm_use);
+
+  switch (static_cast<spv::CooperativeMatrixUse>(cm_use)) {
+    case spv::CooperativeMatrixUse::MatrixAKHR: {
+      // result coopmat component type check
+      if (!_.IsIntNOrFP32OrFP16<8>(result_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the result element type is one of 8-bit OpTypeInt "
+                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
+               << " when result coopmat's use is MatrixAKHR";
+      }
+
+      // result coopmat column length check
+      unsigned n_cols = UINT_MAX;
+      bool status = _.GetConstantValueAs<unsigned>(res_col_id, n_cols);
+      bool is_res_col_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(res_col_id)->opcode());
+      if (!is_res_col_spec_const &&
+          (!status || (!(_.ContainsSizedIntOrFloatType(result_elt_type,
+                                                       spv::Op::OpTypeInt, 8) &&
+                         n_cols == 32) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             result_elt_type, spv::Op::OpTypeFloat, 16) &&
+                         n_cols == 16) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             result_elt_type, spv::Op::OpTypeFloat, 32) &&
+                         n_cols == 8)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the columns of the result coopmat have the bit "
+                  "length of 256"
+               << " when result coopmat's use is MatrixAKHR";
+      }
+      // source array length check
+      if (!is_src_arr_len_spec_const &&
+          (!ar_len_status ||
+           (!(_.ContainsSizedIntOrFloatType(source_elt_type, spv::Op::OpTypeInt,
+                                            32) &&
+              _.IsUnsignedIntScalarType(source_elt_type) && (ar_len == 8)) &&
+            !(n_cols == ar_len)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source array length be 8 if its elt type is "
+                  "32-bit unsigned OpTypeInt and be the result's number of "
+                  "columns, otherwise"
+               << " when result coopmat's use is MatrixAKHR";
+      }
+      break;
+    }
+    case spv::CooperativeMatrixUse::MatrixBKHR: {
+      // result coopmat component type check
+      if (!_.IsIntNOrFP32OrFP16<8>(result_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the result element type is one of 8-bit OpTypeInt "
+                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
+               << " when result coopmat's use is MatrixBKHR";
+      }
+
+      // result coopmat row length check
+      unsigned n_rows = UINT_MAX;
+      bool status = _.GetConstantValueAs<unsigned>(res_row_id, n_rows);
+      bool is_res_row_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(res_row_id)->opcode());
+      if (!is_res_row_spec_const &&
+          (!status || (!(_.ContainsSizedIntOrFloatType(result_elt_type,
+                                                       spv::Op::OpTypeInt, 8) &&
+                         n_rows == 32) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             result_elt_type, spv::Op::OpTypeFloat, 16) &&
+                         n_rows == 16) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             result_elt_type, spv::Op::OpTypeFloat, 32) &&
+                         n_rows == 8)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the rows of the result operand have the bit "
+                  "length of 256"
+               << " when result coopmat's use is MatrixBKHR";
+      }
+      // source array length check
+      if (!is_src_arr_len_spec_const &&
+          (!ar_len_status ||
+           (!(_.ContainsSizedIntOrFloatType(source_elt_type, spv::Op::OpTypeInt,
+                                            32) &&
+              _.IsUnsignedIntScalarType(source_elt_type) && (ar_len == 8)) &&
+            !(n_rows == ar_len)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source array length be 8 if its elt type is "
+                  "32-bit unsigned OpTypeInt and be the result's number of "
+                  "rows, otherwise"
+               << " when result coopmat's use is MatrixBKHR";
+      }
+      break;
+    }
+    case spv::CooperativeMatrixUse::MatrixAccumulatorKHR: {
+      // result coopmat component type check
+      if (!_.IsIntNOrFP32OrFP16<32>(result_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the result element type is one of 32-bit "
+                  "OpTypeInt signed/unsigned, 16- or 32-bit OpTypeFloat"
+               << " when result coopmat's use is MatrixAccumulatorKHR";
+      }
+
+      // source array length check
+      unsigned n_cols = UINT_MAX;
+      bool status = _.GetConstantValueAs<unsigned>(res_col_id, n_cols);
+      bool is_res_col_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(res_col_id)->opcode());
+      if (!is_res_col_spec_const && !is_src_arr_len_spec_const &&
+          (!status || !ar_len_status ||
+           (!(_.ContainsSizedIntOrFloatType(source_elt_type, spv::Op::OpTypeInt,
+                                            32) &&
+              _.IsUnsignedIntScalarType(source_elt_type) &&
+              (_.ContainsSizedIntOrFloatType(result_elt_type,
+                                             spv::Op::OpTypeFloat, 16)
+                   ? (n_cols / 2 == ar_len)
+                   : n_cols == ar_len)) &&
+            (n_cols != ar_len)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source array length be a half of the number "
+                  "of columns of the resulting cooerative matrix if the "
+                  "matrix's componet type is 16-bit OpTypeFloat and be equal "
+                  "to the number of columns, otherwise,"
+               << " when result coopmat's use is MatrixAccumulatorKHR";
+      }
+      break;
+    }
+    default: {
+      bool is_cm_use_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(res_use_id)->opcode());
+      if (!is_cm_use_spec_const || !cm_use_status) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the the resulting cooerative matrix's use be "
+               << " one of MatrixAKHR (== 0), MatrixBKHR (== 1), and "
+                  "MatrixAccumulatorKHR (== 2)";
+      }
+      break;
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateCompositeExtractCoopMatQCOM(ValidationState_t& _,
+                                                 const Instruction* inst) {
+  const auto result_type_inst = _.FindDef(inst->type_id());
+  if (!result_type_inst || result_type_inst->opcode() != spv::Op::OpTypeArray) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the result type be an OpTypeArray.";
+  }
+
+  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
+  const auto source_type_inst = _.FindDef(source->type_id());
+
+  // Is the source a cooperative matrix?
+  if (!source_type_inst ||
+      source_type_inst->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the source type be OpTypeCooperativeMatrixKHR";
+  }
 
+  // Is the scope Subgroup?
+  {
+    unsigned scope = UINT_MAX;
+    unsigned scope_id = source_type_inst->GetOperandAs<unsigned>(2u);
+    bool status = _.GetConstantValueAs<unsigned>(scope_id, scope);
+    bool is_scope_spec_const =
+        spvOpcodeIsSpecConstant(_.FindDef(scope_id)->opcode());
+    if (!is_scope_spec_const &&
+        (!status || scope != static_cast<uint64_t>(spv::Scope::Subgroup))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Opcode " << spvOpcodeString(inst->opcode())
+             << " requires the source type's scope be Subgroup.";
+    }
+  }
+
+  unsigned ar_len = UINT_MAX;
+  unsigned res_arr_len_id = result_type_inst->GetOperandAs<unsigned>(2u);
+  bool ar_len_status = _.GetConstantValueAs<unsigned>(res_arr_len_id, ar_len);
+  bool is_res_arr_len_spec_const =
+      spvOpcodeIsSpecConstant(_.FindDef(res_arr_len_id)->opcode());
+
+  const auto source_elt_type = _.GetComponentType(source_type_inst->id());
+  const auto result_elt_type = result_type_inst->GetOperandAs<uint32_t>(1u);
+
+  unsigned src_row_id = source_type_inst->GetOperandAs<unsigned>(3u);
+  unsigned src_col_id = source_type_inst->GetOperandAs<unsigned>(4u);
+  unsigned src_use_id = source_type_inst->GetOperandAs<unsigned>(5u);
+
+  unsigned cm_use = UINT_MAX;
+  bool cm_use_status = _.GetConstantValueAs<unsigned>(src_use_id, cm_use);
+
+  switch (static_cast<spv::CooperativeMatrixUse>(cm_use)) {
+    case spv::CooperativeMatrixUse::MatrixAKHR: {
+      // source coopmat component type check
+      if (!_.IsIntNOrFP32OrFP16<8>(source_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source element type be one of 8-bit OpTypeInt "
+                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
+               << " when source coopmat's use is MatrixAKHR";
+      }
+
+      // source coopmat column length check
+      unsigned n_cols = UINT_MAX;
+      bool status = _.GetConstantValueAs<unsigned>(src_col_id, n_cols);
+      bool is_src_col_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(src_col_id)->opcode());
+      if (!is_src_col_spec_const &&
+          (!status || (!(_.ContainsSizedIntOrFloatType(source_elt_type,
+                                                       spv::Op::OpTypeInt, 8) &&
+                         n_cols == 32) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             source_elt_type, spv::Op::OpTypeFloat, 16) &&
+                         n_cols == 16) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             source_elt_type, spv::Op::OpTypeFloat, 32) &&
+                         n_cols == 8)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the columns of the source coopmat have the bit "
+                  "length of 256"
+               << " when source coopmat's use is MatrixAKHR";
+      }
+      // result type check
+      if (!is_res_arr_len_spec_const &&
+          !(source_elt_type == result_elt_type && (n_cols == ar_len)) &&
+          !(_.ContainsSizedIntOrFloatType(result_elt_type, spv::Op::OpTypeInt,
+                                          32) &&
+            _.IsUnsignedIntScalarType(result_elt_type) && (ar_len == 8))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires either the result element type be the same as the "
+                  "source cooperative matrix's component type"
+               << " and its length be the same as the number of columns of the "
+                  "matrix or the result element type be"
+               << " unsigned 32-bit OpTypeInt and the length be 8"
+               << " when source coopmat's use is MatrixAKHR";
+      }
+      break;
+    }
+    case spv::CooperativeMatrixUse::MatrixBKHR: {
+      // source coopmat component type check
+      if (!_.IsIntNOrFP32OrFP16<8>(source_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source element type be one of 8-bit OpTypeInt "
+                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
+               << " when source coopmat's use is MatrixBKHR";
+      }
+
+      // source coopmat row length check
+      unsigned n_rows = UINT_MAX;
+      bool status = _.GetConstantValueAs<unsigned>(src_row_id, n_rows);
+      bool is_src_row_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(src_row_id)->opcode());
+      if (!is_src_row_spec_const &&
+          (!status || (!(_.ContainsSizedIntOrFloatType(source_elt_type,
+                                                       spv::Op::OpTypeInt, 8) &&
+                         n_rows == 32) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             source_elt_type, spv::Op::OpTypeFloat, 16) &&
+                         n_rows == 16) &&
+                       !(_.ContainsSizedIntOrFloatType(
+                             source_elt_type, spv::Op::OpTypeFloat, 32) &&
+                         n_rows == 8)))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the rows of the source coopmat have the bit "
+                  "length of 256"
+               << " when source coopmat's use is MatrixBKHR";
+      }
+      // result type check
+      if (!is_res_arr_len_spec_const &&
+          !(source_elt_type == result_elt_type && (n_rows == ar_len)) &&
+          !(_.ContainsSizedIntOrFloatType(result_elt_type, spv::Op::OpTypeInt,
+                                          32) &&
+            _.IsUnsignedIntScalarType(result_elt_type) && (ar_len == 8))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires either the result element type be the same as the "
+                  "source cooperative matrix's component type"
+               << " and its length be the same as the number of rows of the "
+                  "matrix or the result element type be"
+               << " unsigned 32-bit OpTypeInt and the length be 8"
+               << " when source coopmat's use is MatrixBKHR";
+      }
+      break;
+    }
+    case spv::CooperativeMatrixUse::MatrixAccumulatorKHR: {
+      // source coopmat component type check
+      if (!_.IsIntNOrFP32OrFP16<32>(source_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source element type be one of 32-bit "
+                  "OpTypeInt signed/unsigned, 16- or 32-bit OpTypeFloat"
+               << " when source coopmat's use is MatrixAccumulatorKHR";
+      }
+
+      // result type check
+      unsigned n_cols = UINT_MAX;
+      bool status = _.GetConstantValueAs<unsigned>(src_col_id, n_cols);
+      bool is_src_col_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(src_col_id)->opcode());
+      if (!is_src_col_spec_const && !is_res_arr_len_spec_const &&
+          (!status || !ar_len_status ||
+           (!(source_elt_type == result_elt_type && (n_cols == ar_len)) &&
+            !(_.ContainsSizedIntOrFloatType(result_elt_type, spv::Op::OpTypeInt,
+                                            32) &&
+              _.IsUnsignedIntScalarType(result_elt_type) &&
+              (_.ContainsSizedIntOrFloatType(source_elt_type,
+                                             spv::Op::OpTypeFloat, 16)
+                   ? (n_cols / 2 == ar_len)
+                   : (n_cols == ar_len)))))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires either the result element type be the same as the "
+                  "source cooperative matrix's component type"
+               << " and its length be the same as the number of columns of the "
+                  "matrix or the result element type be"
+               << " unsigned 32-bit OpTypeInt and the length be the number of "
+                  "the columns of the matrix if its component"
+               << " type is 32-bit OpTypeFloat and be a half of the number of "
+                  "the columns of the matrix if its component"
+               << " type is 16-bit OpTypeFloat"
+               << " when source coopmat's use is MatrixAccumulatorKHR";
+      }
+      break;
+    }
+    default: {
+      bool is_cm_use_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(src_use_id)->opcode());
+      if (!is_cm_use_spec_const || !cm_use_status) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source cooperative matrix's use be"
+               << " one of MatrixAKHR (== 0), MatrixBKHR (== 1), and "
+                  "MatrixAccumulatorKHR (== 2)";
+      }
+      break;
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateExtractSubArrayQCOM(ValidationState_t& _,
+                                         const Instruction* inst) {
+  const auto result_type_inst = _.FindDef(inst->type_id());
+  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
+  const auto source_type_inst = _.FindDef(source->type_id());
+
+  // Are the input and the result arrays? NOTE(review): unlike the sibling QCOM validators above, result_type_inst/source_type_inst are dereferenced here without null checks -- confirm FindDef cannot return null at this point.
+  if (result_type_inst->opcode() != spv::Op::OpTypeArray ||
+      source_type_inst->opcode() != spv::Op::OpTypeArray) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires OpTypeArray operands for the input and the result.";
+  }
+
+  const auto source_elt_type = _.GetComponentType(source_type_inst->id());
+  const auto result_elt_type = _.GetComponentType(result_type_inst->id());
+
+  // Do the input and result element types match?
+  if (source_elt_type != result_elt_type) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the input and result element types match.";
+  }
+
+  // Element type must be one of int32_t/uint32_t/float32/float16 (checked via IsIntNOrFP32OrFP16<32>).
+  if (!_.IsIntNOrFP32OrFP16<32>(source_elt_type)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the element type be one of 32-bit OpTypeInt "
+              "(signed/unsigned), 32-bit OpTypeFloat and 16-bit OpTypeFloat";
+  }
+
+  // Start index (operand 3) must itself be typed as a 32-bit integer.
+  const auto start_index = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
+  if (!start_index || !_.ContainsSizedIntOrFloatType(start_index->type_id(),
+                                                     spv::Op::OpTypeInt, 32)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Opcode " << spvOpcodeString(inst->opcode())
+           << " requires the type of the start index operand be 32-bit "
+              "OpTypeInt";
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // anonymous namespace
 // Validates correctness of composite instructions.
 spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
   switch (inst->opcode()) {
@@ -641,6 +1099,12 @@ spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
       return ValidateTranspose(_, inst);
     case spv::Op::OpCopyLogical:
       return ValidateCopyLogical(_, inst);
+    case spv::Op::OpCompositeConstructCoopMatQCOM:
+      return ValidateCompositeConstructCoopMatQCOM(_, inst);
+    case spv::Op::OpCompositeExtractCoopMatQCOM:
+      return ValidateCompositeExtractCoopMatQCOM(_, inst);
+    case spv::Op::OpExtractSubArrayQCOM:
+      return ValidateExtractSubArrayQCOM(_, inst);
     default:
       break;
   }

+ 81 - 4
3rdparty/spirv-tools/source/val/validate_conversion.cpp

@@ -14,6 +14,8 @@
 
 // Validates correctness of conversion instructions.
 
+#include <climits>
+
 #include "source/opcode.h"
 #include "source/spirv_constant.h"
 #include "source/spirv_target_env.h"
@@ -572,26 +574,38 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
         if (result_is_pointer && !input_is_pointer && !input_is_int_scalar &&
             !(input_is_int_vector && input_has_int32))
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                 << "Expected input to be a pointer, int scalar or 32-bit int "
+                 << "In SPIR-V 1.5 or later (or with "
+                    "SPV_KHR_physical_storage_buffer), expected input to be a "
+                    "pointer, "
+                    "int scalar or 32-bit int "
                     "vector if Result Type is pointer: "
                  << spvOpcodeString(opcode);
 
         if (input_is_pointer && !result_is_pointer && !result_is_int_scalar &&
             !(result_is_int_vector && result_has_int32))
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                 << "Pointer can only be converted to another pointer, int "
+                 << "In SPIR-V 1.5 or later (or with "
+                    "SPV_KHR_physical_storage_buffer), pointer can only be "
+                    "converted to "
+                    "another pointer, int "
                     "scalar or 32-bit int vector: "
                  << spvOpcodeString(opcode);
       } else {
         if (result_is_pointer && !input_is_pointer && !input_is_int_scalar)
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                 << "Expected input to be a pointer or int scalar if Result "
+                 << "In SPIR-V 1.4 or earlier (and without "
+                    "SPV_KHR_physical_storage_buffer), expected input to be a "
+                    "pointer "
+                    "or int scalar if Result "
                     "Type is pointer: "
                  << spvOpcodeString(opcode);
 
         if (input_is_pointer && !result_is_pointer && !result_is_int_scalar)
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                 << "Pointer can only be converted to another pointer or int "
+                 << "In SPIR-V 1.4 or earlier (and without "
+                    "SPV_KHR_physical_storage_buffer), pointer can only be "
+                    "converted "
+                    "to another pointer or int "
                     "scalar: "
                  << spvOpcodeString(opcode);
       }
@@ -664,6 +678,69 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
       break;
     }
 
+    case spv::Op::OpBitCastArrayQCOM: {
+      const auto result_type_inst = _.FindDef(inst->type_id());
+      const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
+      const auto source_type_inst = _.FindDef(source->type_id());
+
+      // Are the input and the result arrays?
+      if (result_type_inst->opcode() != spv::Op::OpTypeArray ||
+          source_type_inst->opcode() != spv::Op::OpTypeArray) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires OpTypeArray operands for the input and the "
+                  "result.";
+      }
+
+      const auto source_elt_type = _.GetComponentType(source_type_inst->id());
+      const auto result_elt_type = _.GetComponentType(result_type_inst->id());
+
+      if (!_.IsIntNOrFP32OrFP16<32>(source_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the source element type be one of 32-bit "
+                  "OpTypeInt "
+                  "(signed/unsigned), 32-bit OpTypeFloat and 16-bit "
+                  "OpTypeFloat";
+      }
+
+      if (!_.IsIntNOrFP32OrFP16<32>(result_elt_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires the result element type be one of 32-bit "
+                  "OpTypeInt "
+                  "(signed/unsigned), 32-bit OpTypeFloat and 16-bit "
+                  "OpTypeFloat";
+      }
+
+      unsigned src_arr_len_id = source_type_inst->GetOperandAs<unsigned>(2u);
+      unsigned res_arr_len_id = result_type_inst->GetOperandAs<unsigned>(2u);
+
+      // Are the input and result element types compatible?
+      unsigned src_arr_len = UINT_MAX, res_arr_len = UINT_MAX;
+      bool src_arr_len_status =
+          _.GetConstantValueAs<unsigned>(src_arr_len_id, src_arr_len);
+      bool res_arr_len_status =
+          _.GetConstantValueAs<unsigned>(res_arr_len_id, res_arr_len);
+
+      bool is_src_arr_len_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(src_arr_len_id)->opcode());
+      bool is_res_arr_len_spec_const =
+          spvOpcodeIsSpecConstant(_.FindDef(res_arr_len_id)->opcode());
+
+      unsigned source_bitlen = _.GetBitWidth(source_elt_type) * src_arr_len;
+      unsigned result_bitlen = _.GetBitWidth(result_elt_type) * res_arr_len;
+      if (!is_src_arr_len_spec_const && !is_res_arr_len_spec_const &&
+          (!src_arr_len_status || !res_arr_len_status ||
+           source_bitlen != result_bitlen)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Opcode " << spvOpcodeString(inst->opcode())
+               << " requires source and result types be compatible for "
+                  "conversion.";
+      }
+      break;
+    }
+
     default:
       break;
   }

+ 174 - 119
3rdparty/spirv-tools/source/val/validate_decorations.cpp

@@ -398,11 +398,29 @@ bool IsAlignedTo(uint32_t offset, uint32_t alignment) {
   return 0 == (offset % alignment);
 }
 
+std::string getStorageClassString(spv::StorageClass sc) {
+  switch (sc) {
+    case spv::StorageClass::Uniform:
+      return "Uniform";
+    case spv::StorageClass::UniformConstant:
+      return "UniformConstant";
+    case spv::StorageClass::PushConstant:
+      return "PushConstant";
+    case spv::StorageClass::Workgroup:
+      return "Workgroup";
+    case spv::StorageClass::PhysicalStorageBuffer:
+      return "PhysicalStorageBuffer";
+    default:
+      // StorageBuffer is the only other storage class these layout checks pass in; any other value would be mislabeled here -- callers must not widen the set without extending this switch.
+      return "StorageBuffer";
+  }
+}
+
 // Returns SPV_SUCCESS if the given struct satisfies standard layout rules for
 // Block or BufferBlocks in Vulkan.  Otherwise emits a diagnostic and returns
 // something other than SPV_SUCCESS.  Matrices inherit the specified column
 // or row major-ness.
-spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
+spv_result_t checkLayout(uint32_t struct_id, spv::StorageClass storage_class,
                          const char* decoration_str, bool blockRules,
                          bool scalar_block_layout, uint32_t incoming_offset,
                          MemberConstraints& constraints,
@@ -418,22 +436,48 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
   // is more permissive than relaxed layout.
   const bool relaxed_block_layout = vstate.IsRelaxedBlockLayout();
 
-  auto fail = [&vstate, struct_id, storage_class_str, decoration_str,
-               blockRules, relaxed_block_layout,
+  auto fail = [&vstate, struct_id, storage_class, decoration_str, blockRules,
+               relaxed_block_layout,
                scalar_block_layout](uint32_t member_idx) -> DiagnosticStream {
-    DiagnosticStream ds =
-        std::move(vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id))
-                  << "Structure id " << struct_id << " decorated as "
-                  << decoration_str << " for variable in " << storage_class_str
-                  << " storage class must follow "
-                  << (scalar_block_layout
-                          ? "scalar "
-                          : (relaxed_block_layout ? "relaxed " : "standard "))
-                  << (blockRules ? "uniform buffer" : "storage buffer")
-                  << " layout rules: member " << member_idx << " ");
+    DiagnosticStream ds = std::move(
+        vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id))
+        << "Structure id " << struct_id << " decorated as " << decoration_str
+        << " for variable in " << getStorageClassString(storage_class)
+        << " storage class must follow "
+        << (scalar_block_layout
+                ? "scalar "
+                : (relaxed_block_layout ? "relaxed " : "standard "))
+        << (blockRules ? "uniform buffer" : "storage buffer")
+        << " layout rules: member " << member_idx << " ");
     return ds;
   };
 
+  // People often use spirv-val from Vulkan Validation Layers, it ends up
+  // mapping the various block layout rules from the enabled feature. This
+  // offers a hint to help the user understand possbily why things are not
+  // working when the shader itself "seems" valid, but just was a lack of adding
+  // a supported feature
+  auto extra = [&vstate, scalar_block_layout, storage_class,
+                relaxed_block_layout, blockRules]() {
+    if (!scalar_block_layout) {
+      if (storage_class == spv::StorageClass::Workgroup) {
+        return vstate.MissingFeature(
+            "workgroupMemoryExplicitLayoutScalarBlockLayout feature",
+            "--workgroup-scalar-block-layout", true);
+      } else if (!relaxed_block_layout) {
+        return vstate.MissingFeature("VK_KHR_relaxed_block_layout extension",
+                                     "--relax-block-layout", true);
+      } else if (blockRules) {
+        return vstate.MissingFeature("uniformBufferStandardLayout feature",
+                                     "--uniform-buffer-standard-layout", true);
+      } else {
+        return vstate.MissingFeature("scalarBlockLayout feature",
+                                     "--scalar-block-layout", true);
+      }
+    }
+    return std::string("");
+  };
+
   // If we are checking the layout of untyped pointers or physical storage
   // buffer pointers, we may not actually have a struct here. Instead, pretend
   // we have a struct with a single member at offset 0.
@@ -507,7 +551,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
     const auto size = getSize(id, constraint, constraints, vstate);
     // Check offset.
     if (offset == 0xffffffff)
-      return fail(memberIdx) << "is missing an Offset decoration";
+      return fail(memberIdx) << "is missing an Offset decoration" << extra();
 
     if (opcode == spv::Op::OpTypeRuntimeArray &&
         ordered_member_idx != member_offsets.size() - 1) {
@@ -524,42 +568,44 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
       const auto componentId = inst->words()[2];
       const auto scalar_alignment = getScalarAlignment(componentId, vstate);
       if (!IsAlignedTo(offset, scalar_alignment)) {
-        return fail(memberIdx)
-               << "at offset " << offset
-               << " is not aligned to scalar element size " << scalar_alignment;
+        return fail(memberIdx) << "at offset " << offset
+                               << " is not aligned to scalar element size "
+                               << scalar_alignment << extra();
       }
     } else {
       // Without relaxed block layout, the offset must be divisible by the
       // alignment requirement.
       if (!IsAlignedTo(offset, alignment)) {
-        return fail(memberIdx)
-               << "at offset " << offset << " is not aligned to " << alignment;
+        return fail(memberIdx) << "at offset " << offset
+                               << " is not aligned to " << alignment << extra();
       }
     }
     if (offset < nextValidOffset)
       return fail(memberIdx) << "at offset " << offset
                              << " overlaps previous member ending at offset "
-                             << nextValidOffset - 1;
+                             << nextValidOffset - 1 << extra();
     if (!scalar_block_layout && relaxed_block_layout) {
       // Check improper straddle of vectors.
       if (spv::Op::OpTypeVector == opcode &&
           hasImproperStraddle(id, offset, constraint, constraints, vstate))
         return fail(memberIdx)
-               << "is an improperly straddling vector at offset " << offset;
+               << "is an improperly straddling vector at offset " << offset
+               << extra();
     }
     // Check struct members recursively.
     spv_result_t recursive_status = SPV_SUCCESS;
     if (spv::Op::OpTypeStruct == opcode &&
         SPV_SUCCESS != (recursive_status = checkLayout(
-                            id, storage_class_str, decoration_str, blockRules,
+                            id, storage_class, decoration_str, blockRules,
                             scalar_block_layout, offset, constraints, vstate)))
       return recursive_status;
     // Check matrix stride.
     if (spv::Op::OpTypeMatrix == opcode) {
       const auto stride = constraint.matrix_stride;
       if (!IsAlignedTo(stride, alignment)) {
-        return fail(memberIdx) << "is a matrix with stride " << stride
-                               << " not satisfying alignment to " << alignment;
+        return fail(memberIdx)
+               << "is a matrix with stride " << stride
+               << " not satisfying alignment to " << alignment << extra();
       }
     }
 
@@ -576,12 +622,13 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
         if (spv::Decoration::ArrayStride == decoration.dec_type()) {
           array_stride = decoration.params()[0];
           if (array_stride == 0) {
-            return fail(memberIdx) << "contains an array with stride 0";
+            return fail(memberIdx)
+                   << "contains an array with stride 0" << extra();
           }
           if (!IsAlignedTo(array_stride, array_alignment))
             return fail(memberIdx)
                    << "contains an array with stride " << decoration.params()[0]
-                   << " not satisfying alignment to " << alignment;
+                   << " not satisfying alignment to " << alignment << extra();
         }
       }
 
@@ -608,7 +655,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
 
           if (SPV_SUCCESS !=
               (recursive_status = checkLayout(
-                   typeId, storage_class_str, decoration_str, blockRules,
+                   typeId, storage_class, decoration_str, blockRules,
                    scalar_block_layout, next_offset, constraints, vstate)))
             return recursive_status;
 
@@ -620,7 +667,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
         if (!IsAlignedTo(stride, alignment)) {
           return fail(memberIdx)
                  << "is a matrix with stride " << stride
-                 << " not satisfying alignment to " << alignment;
+                 << " not satisfying alignment to " << alignment << extra();
         }
       }
 
@@ -636,7 +683,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
       if (element_size > array_stride) {
         return fail(memberIdx)
                << "contains an array with stride " << array_stride
-               << ", but with an element size of " << element_size;
+               << ", but with an element size of " << element_size << extra();
       }
     }
     nextValidOffset = offset + size;
@@ -801,32 +848,35 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
           if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
             if (has_task_payload) {
               return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
-                     << "There can be at most one OpVariable with storage "
+                     << "There can be at most one "
+                        "OpVariable with storage "
                         "class TaskPayloadWorkgroupEXT associated with "
                         "an OpEntryPoint";
             }
             has_task_payload = true;
           }
-        }
-        if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
+
           // Starting in 1.4, OpEntryPoint must list all global variables
           // it statically uses and those interfaces must be unique.
           if (storage_class == spv::StorageClass::Function) {
             return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
-                   << "OpEntryPoint interfaces should only list global "
+                   << "In SPIR-V 1.4 or later, OpEntryPoint interfaces should "
+                      "only list global "
                       "variables";
           }
 
           if (!seen_vars.insert(var_instr).second) {
             return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
-                   << "Non-unique OpEntryPoint interface "
+                   << "In SPIR-V 1.4 or later, non-unique OpEntryPoint "
+                      "interface "
                    << vstate.getIdName(interface) << " is disallowed";
           }
         } else {
           if (storage_class != spv::StorageClass::Input &&
               storage_class != spv::StorageClass::Output) {
             return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
-                   << "OpEntryPoint interfaces must be OpVariables with "
+                   << "In SPIR-V 1.3 or earlier, OpEntryPoint interfaces must "
+                      "be OpVariables with "
                       "Storage Class of Input(1) or Output(3). Found Storage "
                       "Class "
                    << uint32_t(storage_class) << " for Entry Point id "
@@ -1129,6 +1179,56 @@ void ComputeMemberConstraintsForArray(MemberConstraints* constraints,
   }
 }
 
+spv_result_t CheckDecorationsOfVariables(ValidationState_t& vstate) {
+  if (!spvIsVulkanEnv(vstate.context()->target_env)) {
+    return SPV_SUCCESS;
+  }
+  for (const auto& inst : vstate.ordered_instructions()) {
+    if ((spv::Op::OpVariable == inst.opcode()) ||
+        (spv::Op::OpUntypedVariableKHR == inst.opcode())) {
+      const auto var_id = inst.id();
+      const auto storageClass = inst.GetOperandAs<spv::StorageClass>(2);
+      const bool uniform = storageClass == spv::StorageClass::Uniform;
+      const bool uniform_constant =
+          storageClass == spv::StorageClass::UniformConstant;
+      const bool storage_buffer =
+          storageClass == spv::StorageClass::StorageBuffer;
+
+      const char* sc_str = uniform            ? "Uniform"
+                           : uniform_constant ? "UniformConstant"
+                                              : "StorageBuffer";
+      // Check variables in the UniformConstant, StorageBuffer, and Uniform
+      // storage classes are decorated with DescriptorSet and Binding
+      // (VUID-06677).
+      if (uniform_constant || storage_buffer || uniform) {
+        // Skip validation if the variable is not used and we're looking
+        // at a module coming from HLSL that has not been legalized yet.
+        if (vstate.options()->before_hlsl_legalization &&
+            vstate.EntryPointReferences(var_id).empty()) {
+          continue;
+        }
+        if (!hasDecoration(var_id, spv::Decoration::DescriptorSet, vstate)) {
+          return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
+                 << vstate.VkErrorID(6677) << sc_str << " id '" << var_id
+                 << "' is missing DescriptorSet decoration.\n"
+                 << "From Vulkan spec:\n"
+                 << "These variables must have DescriptorSet and Binding "
+                    "decorations specified";
+        }
+        if (!hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
+          return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
+                 << vstate.VkErrorID(6677) << sc_str << " id '" << var_id
+                 << "' is missing Binding decoration.\n"
+                 << "From Vulkan spec:\n"
+                 << "These variables must have DescriptorSet and Binding "
+                    "decorations specified";
+        }
+      }
+    }
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
   // Set of entry points that are known to use a push constant.
   std::unordered_set<uint32_t> uses_push_constant;
@@ -1148,8 +1248,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
       const auto storageClassVal = words[3];
       const auto storageClass = spv::StorageClass(storageClassVal);
       const bool uniform = storageClass == spv::StorageClass::Uniform;
-      const bool uniform_constant =
-          storageClass == spv::StorageClass::UniformConstant;
       const bool push_constant =
           storageClass == spv::StorageClass::PushConstant;
       const bool storage_buffer =
@@ -1172,29 +1270,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
             }
           }
         }
-        // Vulkan: Check DescriptorSet and Binding decoration for
-        // UniformConstant which cannot be a struct.
-        if (uniform_constant) {
-          auto entry_points = vstate.EntryPointReferences(var_id);
-          if (!entry_points.empty() &&
-              !hasDecoration(var_id, spv::Decoration::DescriptorSet, vstate)) {
-            return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
-                   << vstate.VkErrorID(6677) << "UniformConstant id '" << var_id
-                   << "' is missing DescriptorSet decoration.\n"
-                   << "From Vulkan spec:\n"
-                   << "These variables must have DescriptorSet and Binding "
-                      "decorations specified";
-          }
-          if (!entry_points.empty() &&
-              !hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
-            return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
-                   << vstate.VkErrorID(6677) << "UniformConstant id '" << var_id
-                   << "' is missing Binding decoration.\n"
-                   << "From Vulkan spec:\n"
-                   << "These variables must have DescriptorSet and Binding "
-                      "decorations specified";
-          }
-        }
       }
 
       if (spvIsOpenGLEnv(vstate.context()->target_env)) {
@@ -1207,8 +1282,8 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
           if (!entry_points.empty() &&
               !hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
             return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
-                   << (uniform ? "Uniform" : "Storage Buffer") << " id '"
-                   << var_id << "' is missing Binding decoration.\n"
+                   << getStorageClassString(storageClass) << " id '" << var_id
+                   << "' is missing Binding decoration.\n"
                    << "From ARB_gl_spirv extension:\n"
                    << "Uniform and shader storage block variables must "
                    << "also be decorated with a *Binding*.";
@@ -1243,12 +1318,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
           ComputeMemberConstraintsForStruct(&constraints, id,
                                             LayoutConstraints(), vstate);
         }
-        // Prepare for messages
-        const char* sc_str =
-            uniform
-                ? "Uniform"
-                : (push_constant ? "PushConstant"
-                                 : (workgroup ? "Workgroup" : "StorageBuffer"));
 
         if (spvIsVulkanEnv(vstate.context()->target_env)) {
           const bool block = hasDecoration(id, spv::Decoration::Block, vstate);
@@ -1286,30 +1355,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
                    << "Such variables must be identified with a Block or "
                       "BufferBlock decoration";
           }
-          // Vulkan: Check DescriptorSet and Binding decoration for
-          // Uniform and StorageBuffer variables.
-          if (uniform || storage_buffer) {
-            auto entry_points = vstate.EntryPointReferences(var_id);
-            if (!entry_points.empty() &&
-                !hasDecoration(var_id, spv::Decoration::DescriptorSet,
-                               vstate)) {
-              return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
-                     << vstate.VkErrorID(6677) << sc_str << " id '" << var_id
-                     << "' is missing DescriptorSet decoration.\n"
-                     << "From Vulkan spec:\n"
-                     << "These variables must have DescriptorSet and Binding "
-                        "decorations specified";
-            }
-            if (!entry_points.empty() &&
-                !hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
-              return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
-                     << vstate.VkErrorID(6677) << sc_str << " id '" << var_id
-                     << "' is missing Binding decoration.\n"
-                     << "From Vulkan spec:\n"
-                     << "These variables must have DescriptorSet and Binding "
-                        "decorations specified";
-            }
-          }
         }
 
         if (id != 0) {
@@ -1386,14 +1431,14 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
               if (spvIsVulkanEnv(vstate.context()->target_env)) {
                 if (blockRules &&
                     (SPV_SUCCESS !=
-                     (recursive_status = checkLayout(id, sc_str, deco_str, true,
-                                                     scalar_block_layout, 0,
-                                                     constraints, vstate)))) {
+                     (recursive_status = checkLayout(
+                          id, storageClass, deco_str, true, scalar_block_layout,
+                          0, constraints, vstate)))) {
                   return recursive_status;
                 } else if (bufferRules &&
                            (SPV_SUCCESS != (recursive_status = checkLayout(
-                                                id, sc_str, deco_str, false,
-                                                scalar_block_layout, 0,
+                                                id, storageClass, deco_str,
+                                                false, scalar_block_layout, 0,
                                                 constraints, vstate)))) {
                   return recursive_status;
                 }
@@ -1413,9 +1458,9 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
         ComputeMemberConstraintsForStruct(&constraints, pointee_type_id,
                                           LayoutConstraints(), vstate);
       }
-      if (auto res = checkLayout(pointee_type_id, "PhysicalStorageBuffer",
-                                 "Block", !buffer, scalar_block_layout, 0,
-                                 constraints, vstate)) {
+      if (auto res = checkLayout(
+              pointee_type_id, spv::StorageClass::PhysicalStorageBuffer,
+              "Block", !buffer, scalar_block_layout, 0, constraints, vstate)) {
         return res;
       }
     } else if (vstate.HasCapability(spv::Capability::UntypedPointersKHR) &&
@@ -1464,14 +1509,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
       const auto sc =
           vstate.FindDef(ptr_ty_id)->GetOperandAs<spv::StorageClass>(1);
 
-      const char* sc_str =
-          sc == spv::StorageClass::Uniform
-              ? "Uniform"
-              : (sc == spv::StorageClass::PushConstant
-                     ? "PushConstant"
-                     : (sc == spv::StorageClass::Workgroup ? "Workgroup"
-                                                           : "StorageBuffer"));
-
       auto data_type = vstate.FindDef(data_type_id);
       scalar_block_layout =
           sc == spv::StorageClass::Workgroup
@@ -1511,7 +1548,7 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
               ? (sc == spv::StorageClass::Uniform ? "BufferBlock" : "Block")
               : "Block";
       if (auto result =
-              checkLayout(data_type_id, sc_str, deco_str, !bufferRules,
+              checkLayout(data_type_id, sc, deco_str, !bufferRules,
                           scalar_block_layout, 0, constraints, vstate)) {
         return result;
       }
@@ -1732,14 +1769,19 @@ spv_result_t CheckFPRoundingModeForShaders(ValidationState_t& vstate,
   return SPV_SUCCESS;
 }
 
-// Returns SPV_SUCCESS if validation rules are satisfied for the NonWritable
+// Returns SPV_SUCCESS if validation rules are satisfied for the NonReadable or
+// NonWritable
 // decoration.  Otherwise emits a diagnostic and returns something other than
 // SPV_SUCCESS.  The |inst| parameter is the object being decorated.  This must
 // be called after TypePass and AnnotateCheckDecorationsOfBuffers are called.
-spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
-                                        const Instruction& inst,
-                                        const Decoration& decoration) {
+spv_result_t CheckNonReadableWritableDecorations(ValidationState_t& vstate,
+                                                 const Instruction& inst,
+                                                 const Decoration& decoration) {
   assert(inst.id() && "Parser ensures the target of the decoration has an ID");
+  const bool is_non_writable =
+      decoration.dec_type() == spv::Decoration::NonWritable;
+  assert(is_non_writable ||
+         decoration.dec_type() == spv::Decoration::NonReadable);
 
   if (decoration.struct_member_index() == Decoration::kInvalidMember) {
     // The target must be a memory object declaration.
@@ -1751,7 +1793,10 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
         opcode != spv::Op::OpFunctionParameter &&
         opcode != spv::Op::OpRawAccessChainNV) {
       return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
-             << "Target of NonWritable decoration must be a memory object "
+             << "Target of "
+             << (is_non_writable ? "NonWritable" : "NonReadable")
+             << " decoration must be a "
+                "memory object "
                 "declaration (a variable or a function parameter)";
     }
     const auto var_storage_class =
@@ -1762,7 +1807,8 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
                   : spv::StorageClass::Max;
     if ((var_storage_class == spv::StorageClass::Function ||
          var_storage_class == spv::StorageClass::Private) &&
-        vstate.features().nonwritable_var_in_function_or_private) {
+        vstate.features().nonwritable_var_in_function_or_private &&
+        is_non_writable) {
       // New permitted feature in SPIR-V 1.4.
     } else if (var_storage_class == spv::StorageClass::TileAttachmentQCOM) {
     } else if (
@@ -1770,12 +1816,18 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
         vstate.IsPointerToUniformBlock(type_id) ||
         vstate.IsPointerToStorageBuffer(type_id) ||
         vstate.IsPointerToStorageImage(type_id) ||
+        vstate.IsPointerToTensor(type_id) ||
         opcode == spv::Op::OpRawAccessChainNV) {
     } else {
       return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
-             << "Target of NonWritable decoration is invalid: must point to a "
-                "storage image, uniform block, "
-             << (vstate.features().nonwritable_var_in_function_or_private
+             << "Target of "
+             << (is_non_writable ? "NonWritable" : "NonReadable")
+             << " decoration is invalid: "
+                "must point to a "
+                "storage image, tensor variable in UniformConstant storage "
+                "class, uniform block, "
+             << (vstate.features().nonwritable_var_in_function_or_private &&
+                         is_non_writable
                      ? "storage buffer, or variable in Private or Function "
                        "storage class"
                      : "or storage buffer");
@@ -2063,8 +2115,10 @@ spv_result_t CheckDecorationsFromDecoration(ValidationState_t& vstate) {
             PASS_OR_BAIL(
                 CheckFPRoundingModeForShaders(vstate, *inst, decoration));
           break;
+        case spv::Decoration::NonReadable:
         case spv::Decoration::NonWritable:
-          PASS_OR_BAIL(CheckNonWritableDecoration(vstate, *inst, decoration));
+          PASS_OR_BAIL(
+              CheckNonReadableWritableDecorations(vstate, *inst, decoration));
           break;
         case spv::Decoration::Uniform:
         case spv::Decoration::UniformId:
@@ -2298,6 +2352,7 @@ spv_result_t ValidateDecorations(ValidationState_t& vstate) {
   if (auto error = CheckImportedVariableInitialization(vstate)) return error;
   if (auto error = CheckDecorationsOfEntryPoints(vstate)) return error;
   if (auto error = CheckDecorationsOfBuffers(vstate)) return error;
+  if (auto error = CheckDecorationsOfVariables(vstate)) return error;
   if (auto error = CheckDecorationsCompatibility(vstate)) return error;
   if (auto error = CheckLinkageAttrOfFunctions(vstate)) return error;
   if (auto error = CheckVulkanMemoryModelDeprecatedDecorations(vstate))

+ 10 - 4
3rdparty/spirv-tools/source/val/validate_extensions.cpp

@@ -1052,7 +1052,9 @@ bool IsDebugVariableWithIntScalarType(ValidationState_t& _,
 spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
   std::string extension = GetExtensionString(&(inst->c_inst()));
   if (_.version() < SPV_SPIRV_VERSION_WORD(1, 3)) {
-    if (extension == ExtensionToString(kSPV_KHR_vulkan_memory_model)) {
+    if (extension == ExtensionToString(kSPV_KHR_vulkan_memory_model) ||
+        extension ==
+            ExtensionToString(kSPV_QCOM_cooperative_matrix_conversion)) {
       return _.diag(SPV_ERROR_WRONG_VERSION, inst)
              << extension << " extension requires SPIR-V version 1.3 or later.";
     }
@@ -1064,7 +1066,9 @@ spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
         extension == ExtensionToString(kSPV_NV_shader_invocation_reorder) ||
         extension ==
             ExtensionToString(kSPV_NV_cluster_acceleration_structure) ||
-        extension == ExtensionToString(kSPV_NV_linear_swept_spheres)) {
+        extension == ExtensionToString(kSPV_NV_linear_swept_spheres) ||
+        extension == ExtensionToString(kSPV_QCOM_image_processing) ||
+        extension == ExtensionToString(kSPV_QCOM_image_processing2)) {
       return _.diag(SPV_ERROR_WRONG_VERSION, inst)
              << extension << " extension requires SPIR-V version 1.4 or later.";
     }
@@ -1081,8 +1085,10 @@ spv_result_t ValidateExtInstImport(ValidationState_t& _,
     const std::string name = inst->GetOperandAs<std::string>(name_id);
     if (name.find("NonSemantic.") == 0) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "NonSemantic extended instruction sets cannot be declared "
-                "without SPV_KHR_non_semantic_info.";
+             << "NonSemantic extended instruction "
+                "sets cannot be declared "
+                "without SPV_KHR_non_semantic_info. (This can also be fixed "
+                "having SPIR-V 1.6 or later)";
     }
   }
 

+ 4 - 6
3rdparty/spirv-tools/source/val/validate_function.cpp

@@ -89,7 +89,10 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
       spv::Op::OpName,
       spv::Op::OpCooperativeMatrixPerElementOpNV,
       spv::Op::OpCooperativeMatrixReduceNV,
-      spv::Op::OpCooperativeMatrixLoadTensorNV};
+      spv::Op::OpCooperativeMatrixLoadTensorNV,
+      spv::Op::OpConditionalEntryPointINTEL,
+  };
+
   for (auto& pair : inst->uses()) {
     const auto* use = pair.first;
     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
@@ -109,11 +112,6 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _,
   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
   size_t param_index = 0;
   size_t inst_num = inst->LineNum() - 1;
-  if (inst_num == 0) {
-    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
-           << "Function parameter cannot be the first instruction.";
-  }
-
   auto func_inst = &_.ordered_instructions()[inst_num];
   while (--inst_num) {
     func_inst = &_.ordered_instructions()[inst_num];

+ 547 - 0
3rdparty/spirv-tools/source/val/validate_graph.cpp

@@ -0,0 +1,547 @@
+// Copyright (c) 2023-2025 Arm Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates correctness of graph instructions.
+
+#include <deque>
+
+#include "source/opcode.h"
+#include "source/val/validate.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+bool IsTensorArray(ValidationState_t& _, uint32_t id) {
+  auto def = _.FindDef(id);
+  if (!def || (def->opcode() != spv::Op::OpTypeArray &&
+               def->opcode() != spv::Op::OpTypeRuntimeArray)) {
+    return false;
+  }
+  auto tdef = _.FindDef(def->word(2));
+  if (!tdef || tdef->opcode() != spv::Op::OpTypeTensorARM) {
+    return false;
+  }
+  return true;
+}
+
+bool IsGraphInterfaceType(ValidationState_t& _, uint32_t id) {
+  return _.IsTensorType(id) || IsTensorArray(_, id);
+}
+
+bool IsGraph(ValidationState_t& _, uint32_t id) {
+  auto def = _.FindDef(id);
+  if (!def || def->opcode() != spv::Op::OpGraphARM) {
+    return false;
+  }
+  return true;
+}
+
+bool IsGraphType(ValidationState_t& _, uint32_t id) {
+  auto def = _.FindDef(id);
+  if (!def || def->opcode() != spv::Op::OpTypeGraphARM) {
+    return false;
+  }
+  return true;
+}
+
+const uint32_t kGraphTypeIOStartWord = 3;
+
+uint32_t GraphTypeInstNumIO(const Instruction* inst) {
+  return static_cast<uint32_t>(inst->words().size()) - kGraphTypeIOStartWord;
+}
+
+uint32_t GraphTypeInstNumInputs(const Instruction* inst) {
+  return inst->word(2);
+}
+
+uint32_t GraphTypeInstNumOutputs(const Instruction* inst) {
+  return GraphTypeInstNumIO(inst) - GraphTypeInstNumInputs(inst);
+}
+
+uint32_t GraphTypeInstGetOutputAtIndex(const Instruction* inst,
+                                       uint64_t index) {
+  return inst->word(kGraphTypeIOStartWord + GraphTypeInstNumInputs(inst) +
+                    static_cast<uint32_t>(index));
+}
+
+uint32_t GraphTypeInstGetInputAtIndex(const Instruction* inst, uint64_t index) {
+  return inst->word(kGraphTypeIOStartWord + static_cast<uint32_t>(index));
+}
+
+spv_result_t ValidateGraphType(ValidationState_t& _, const Instruction* inst) {
+  // Check there are at least NumInputs types
+  uint32_t NumInputs = GraphTypeInstNumInputs(inst);
+  size_t NumIOTypes = GraphTypeInstNumIO(inst);
+  if (NumIOTypes < NumInputs) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << NumIOTypes << " I/O types were provided but the graph has "
+           << NumInputs << " inputs.";
+  }
+
+  // Check there is at least one output
+  if (NumIOTypes == NumInputs) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "A graph type must have at least one output.";
+  }
+
+  // Check all I/O types are graph interface type
+  for (unsigned i = kGraphTypeIOStartWord; i < inst->words().size(); i++) {
+    auto tid = inst->word(i);
+    if (!IsGraphInterfaceType(_, tid)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "I/O type " << _.getIdName(tid)
+             << " is not a Graph Interface Type.";
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGraphConstant(ValidationState_t& _,
+                                   const Instruction* inst) {
+  // Check Result Type
+  if (!_.IsTensorType(inst->type_id())) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode())
+           << " must have a Result Type that is a tensor type.";
+  }
+
+  // Check the instruction is not preceded by another OpGraphConstantARM with
+  // the same ID
+  const uint32_t cst_id = inst->word(3);
+  size_t inst_num = inst->LineNum() - 1;
+  while (--inst_num) {
+    auto prev_inst = &_.ordered_instructions()[inst_num];
+    if (prev_inst->opcode() == spv::Op::OpGraphConstantARM) {
+      const uint32_t prev_cst_id = prev_inst->word(3);
+      if (prev_cst_id == cst_id) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "No two OpGraphConstantARM instructions may have the same "
+                  "GraphConstantID";
+      }
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGraphEntryPoint(ValidationState_t& _,
+                                     const Instruction* inst) {
+  // Graph must be an OpGraphARM
+  uint32_t graph = inst->GetOperandAs<uint32_t>(0);
+  auto graph_inst = _.FindDef(graph);
+  if (!IsGraph(_, graph)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode())
+           << " Graph must be a OpGraphARM but found "
+           << spvOpcodeString(graph_inst->opcode()) << ".";
+  }
+
+  // Check number of Interface IDs matches number of I/Os of graph
+  auto graph_type_inst = _.FindDef(graph_inst->type_id());
+  size_t graph_type_num_io = GraphTypeInstNumIO(graph_type_inst);
+  size_t graph_entry_point_num_interface_id = inst->operands().size() - 2;
+  if (graph_type_inst->opcode() != spv::Op::OpTypeGraphARM) {
+    // This is invalid but we want ValidateGraph to report a clear error
+    // so stop validating the graph entry point instruction
+    return SPV_SUCCESS;
+  }
+  if (graph_type_num_io != graph_entry_point_num_interface_id) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode()) << " Interface list contains "
+           << graph_entry_point_num_interface_id << " IDs but Graph's type "
+           << _.getIdName(graph_inst->type_id()) << " has " << graph_type_num_io
+           << " inputs and outputs.";
+  }
+
+  // Check Interface IDs
+  for (uint32_t i = 2; i < inst->operands().size(); i++) {
+    uint32_t interface_id = inst->GetOperandAs<uint32_t>(i);
+    auto interface_inst = _.FindDef(interface_id);
+
+    // Check interface IDs come from OpVariable
+    if ((interface_inst->opcode() != spv::Op::OpVariable) ||
+        (interface_inst->GetOperandAs<spv::StorageClass>(2) !=
+         spv::StorageClass::UniformConstant)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, interface_inst)
+             << spvOpcodeString(inst->opcode()) << " Interface ID "
+             << _.getIdName(interface_id)
+             << " must come from OpVariable with UniformConstant Storage "
+                "Class.";
+    }
+
+    // Check type of interface variable matches type of the corresponding graph
+    // I/O
+    uint32_t corresponding_graph_io_type =
+        graph_type_inst->GetOperandAs<uint32_t>(i);
+
+    uint32_t interface_ptr_type = interface_inst->type_id();
+    auto interface_ptr_inst = _.FindDef(interface_ptr_type);
+    auto interface_pointee_type = interface_ptr_inst->GetOperandAs<uint32_t>(2);
+    if (interface_pointee_type != corresponding_graph_io_type) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << spvOpcodeString(inst->opcode()) << " Interface ID type "
+             << _.getIdName(interface_pointee_type)
+             << " must match the type of the corresponding graph I/O "
+             << _.getIdName(corresponding_graph_io_type);
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGraph(ValidationState_t& _, const Instruction* inst) {
+  // Result Type must be an OpTypeGraphARM
+  if (!IsGraphType(_, inst->type_id())) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode())
+           << " Result Type must be an OpTypeGraphARM.";
+  }
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGraphInput(ValidationState_t& _, const Instruction* inst) {
+  // Check type of InputIndex
+  auto input_index_inst = _.FindDef(inst->GetOperandAs<uint32_t>(2));
+  if (!input_index_inst ||
+      !_.IsIntScalarType(input_index_inst->type_id(), 32)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode())
+           << " InputIndex must be a 32-bit integer.";
+  }
+
+  bool has_element_index = inst->operands().size() > 3;
+
+  // Check type of ElementIndex
+  if (has_element_index) {
+    auto element_index_inst = _.FindDef(inst->GetOperandAs<uint32_t>(3));
+    if (!element_index_inst ||
+        !_.IsIntScalarType(element_index_inst->type_id(), 32)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << spvOpcodeString(inst->opcode())
+             << " ElementIndex must be a 32-bit integer.";
+    }
+  }
+
+  // Find graph definition
+  size_t inst_num = inst->LineNum() - 1;
+  auto graph_inst = &_.ordered_instructions()[inst_num];
+  while (--inst_num) {
+    graph_inst = &_.ordered_instructions()[inst_num];
+    if (graph_inst->opcode() == spv::Op::OpGraphARM) {
+      break;
+    }
+  }
+
+  // Can the InputIndex be evaluated?
+  // If not, there's nothing more we can validate here.
+  uint64_t input_index;
+  if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(2), &input_index)) {
+    return SPV_SUCCESS;
+  }
+
+  auto const graph_type_inst = _.FindDef(graph_inst->type_id());
+  size_t graph_type_num_inputs = graph_type_inst->GetOperandAs<uint32_t>(1);
+
+  // Check InputIndex is in range
+  if (input_index >= graph_type_num_inputs) {
+    std::string disassembly = _.Disassemble(*inst);
+    return _.diag(SPV_ERROR_INVALID_DATA, nullptr)
+           << "Type " << _.getIdName(graph_type_inst->id()) << " for graph "
+           << _.getIdName(graph_inst->id()) << " has " << graph_type_num_inputs
+           << " inputs but found an OpGraphInputARM instruction with an "
+              "InputIndex that is "
+           << input_index << ": " << disassembly;
+  }
+
+  uint32_t graph_type_input_type =
+      GraphTypeInstGetInputAtIndex(graph_type_inst, input_index);
+
+  if (has_element_index) {
+    // Check ElementIndex is allowed
+    if (!IsTensorArray(_, graph_type_input_type)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "OpGraphInputARM ElementIndex not allowed when the graph input "
+                "selected by "
+             << "InputIndex is not an OpTypeArray or OpTypeRuntimeArray";
+    }
+
+    // Check ElementIndex is in range if it can be evaluated and the input is a
+    // fixed-sized array whose Length can be evaluated
+    uint64_t element_index;
+    if (_.IsArrayType(graph_type_input_type) &&
+        _.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(3),
+                                &element_index)) {
+      uint64_t array_length;
+      auto graph_type_input_type_inst = _.FindDef(graph_type_input_type);
+      if (_.EvalConstantValUint64(
+              graph_type_input_type_inst->GetOperandAs<uint32_t>(2),
+              &array_length)) {
+        if (element_index >= array_length) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "OpGraphInputARM ElementIndex out of range. The type of "
+                    "the graph input being accessed "
+                 << _.getIdName(graph_type_input_type) << " is an array of "
+                 << array_length << " elements but " << "ElementIndex is "
+                 << element_index;
+        }
+      }
+    }
+  }
+
+  // Check result type matches with graph type
+  if (has_element_index) {
+    uint32_t expected_type = _.GetComponentType(graph_type_input_type);
+    if (inst->type_id() != expected_type) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Result Type " << _.getIdName(inst->type_id())
+             << " of graph input instruction " << _.getIdName(inst->id())
+             << " does not match the component type "
+             << _.getIdName(expected_type) << " of input " << input_index
+             << " in the graph type.";
+    }
+  } else {
+    if (inst->type_id() != graph_type_input_type) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Result Type " << _.getIdName(inst->type_id())
+             << " of graph input instruction " << _.getIdName(inst->id())
+             << " does not match the type "
+             << _.getIdName(graph_type_input_type) << " of input "
+             << input_index << " in the graph type.";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateGraphSetOutput(ValidationState_t& _,
+                                    const Instruction* inst) {
+  // Check type of OutputIndex
+  auto output_index_inst = _.FindDef(inst->GetOperandAs<uint32_t>(1));
+  if (!output_index_inst ||
+      !_.IsIntScalarType(output_index_inst->type_id(), 32)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode())
+           << " OutputIndex must be a 32-bit integer.";
+  }
+
+  bool has_element_index = inst->operands().size() > 2;
+
+  // Check type of ElementIndex
+  if (has_element_index) {
+    auto element_index_inst = _.FindDef(inst->GetOperandAs<uint32_t>(2));
+    if (!element_index_inst ||
+        !_.IsIntScalarType(element_index_inst->type_id(), 32)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << spvOpcodeString(inst->opcode())
+             << " ElementIndex must be a 32-bit integer.";
+    }
+  }
+
+  // Find graph definition
+  size_t inst_num = inst->LineNum() - 1;
+  auto graph_inst = &_.ordered_instructions()[inst_num];
+  while (--inst_num) {
+    graph_inst = &_.ordered_instructions()[inst_num];
+    if (graph_inst->opcode() == spv::Op::OpGraphARM) {
+      break;
+    }
+  }
+
+  // Can the OutputIndex be evaluated?
+  // If not, there's nothing more we can validate here.
+  uint64_t output_index;
+  if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(1),
+                               &output_index)) {
+    return SPV_SUCCESS;
+  }
+
+  // Check that the OutputIndex is valid with respect to the graph type
+  auto graph_type_inst = _.FindDef(graph_inst->type_id());
+  size_t graph_type_num_outputs = GraphTypeInstNumOutputs(graph_type_inst);
+
+  if (output_index >= graph_type_num_outputs) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(inst->opcode()) << " setting OutputIndex "
+           << output_index << " but graph only has " << graph_type_num_outputs
+           << " outputs.";
+  }
+
+  uint32_t graph_type_output_type =
+      GraphTypeInstGetOutputAtIndex(graph_type_inst, output_index);
+
+  if (has_element_index) {
+    // Check ElementIndex is allowed
+    if (!IsTensorArray(_, graph_type_output_type)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "OpGraphSetOutputARM ElementIndex not allowed when the graph "
+                "output selected by "
+             << "OutputIndex is not an OpTypeArray or OpTypeRuntimeArray";
+    }
+
+    // Check ElementIndex is in range if it can be evaluated and the output is a
+    // fixed-size array whose Length can be evaluated
+    uint64_t element_index;
+    if (_.IsArrayType(graph_type_output_type) &&
+        _.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(2),
+                                &element_index)) {
+      uint64_t array_length;
+      auto graph_type_output_type_inst = _.FindDef(graph_type_output_type);
+      if (_.EvalConstantValUint64(
+              graph_type_output_type_inst->GetOperandAs<uint32_t>(2),
+              &array_length)) {
+        if (element_index >= array_length) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "OpGraphSetOutputARM ElementIndex out of range. The type "
+                    "of the graph output being accessed "
+                 << _.getIdName(graph_type_output_type) << " is an array of "
+                 << array_length << " elements but " << "ElementIndex is "
+                 << element_index;
+        }
+      }
+    }
+  }
+
+  // Check Value's type matches with graph type
+  uint32_t value = inst->GetOperandAs<uint32_t>(0);
+  uint32_t value_type = _.FindDef(value)->type_id();
+  if (has_element_index) {
+    uint32_t expected_type = _.GetComponentType(graph_type_output_type);
+    if (value_type != expected_type) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "The type " << _.getIdName(value_type)
+             << " of Value provided to the graph output instruction "
+             << _.getIdName(value) << " does not match the component type "
+             << _.getIdName(expected_type) << " of output " << output_index
+             << " in the graph type.";
+    }
+  } else {
+    if (value_type != graph_type_output_type) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "The type " << _.getIdName(value_type)
+             << " of Value provided to the graph output instruction "
+             << _.getIdName(value) << " does not match the type "
+             << _.getIdName(graph_type_output_type) << " of output "
+             << output_index << " in the graph type.";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+bool InputOutputInstructionsHaveDuplicateIndices(
+    ValidationState_t& _, std::deque<const Instruction*>& inout_insts,
+    const Instruction** first_dup) {
+  std::set<std::pair<uint64_t, uint64_t>> inout_element_indices;
+  for (auto const inst : inout_insts) {
+    const bool is_input = inst->opcode() == spv::Op::OpGraphInputARM;
+    bool has_element_index = inst->operands().size() > (is_input ? 3 : 2);
+    uint64_t inout_index;
+    if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(is_input ? 2 : 1),
+                                 &inout_index)) {
+      continue;
+    }
+    uint64_t element_index = -1;  // -1 means no ElementIndex
+    if (has_element_index) {
+      if (!_.EvalConstantValUint64(
+              inst->GetOperandAs<uint32_t>(is_input ? 3 : 2), &element_index)) {
+        continue;
+      }
+    }
+    auto inout_element_pair = std::make_pair(inout_index, element_index);
+    auto inout_noelement_pair = std::make_pair(inout_index, -1);
+    if (inout_element_indices.count(inout_element_pair) ||
+        inout_element_indices.count(inout_noelement_pair)) {
+      *first_dup = inst;
+      return true;
+    }
+    inout_element_indices.insert(inout_element_pair);
+  }
+  return false;
+}
+
+spv_result_t ValidateGraphEnd(ValidationState_t& _, const Instruction* inst) {
+  size_t end_inst_num = inst->LineNum() - 1;
+
+  // Gather OpGraphInputARM and OpGraphSetOutputARM instructions
+  std::deque<const Instruction*> graph_inputs, graph_outputs;
+  size_t in_inst_num = end_inst_num;
+  auto graph_inst = &_.ordered_instructions()[in_inst_num];
+  while (--in_inst_num) {
+    graph_inst = &_.ordered_instructions()[in_inst_num];
+    if (graph_inst->opcode() == spv::Op::OpGraphInputARM) {
+      graph_inputs.push_front(graph_inst);
+      continue;
+    }
+    if (graph_inst->opcode() == spv::Op::OpGraphSetOutputARM) {
+      graph_outputs.push_front(graph_inst);
+      continue;
+    }
+    if (graph_inst->opcode() == spv::Op::OpGraphARM) {
+      break;
+    }
+  }
+
+  const Instruction* first_dup;
+
+  // Check that there are no duplicate InputIndex and ElementIndex values
+  if (InputOutputInstructionsHaveDuplicateIndices(_, graph_inputs,
+                                                  &first_dup)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, first_dup)
+           << "Two OpGraphInputARM instructions with the same InputIndex "
+              "must not be part of the same "
+           << "graph definition unless ElementIndex is present in both with "
+              "different values.";
+  }
+
+  // Check that there are no duplicate OutputIndex and ElementIndex values
+  if (InputOutputInstructionsHaveDuplicateIndices(_, graph_outputs,
+                                                  &first_dup)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, first_dup)
+           << "Two OpGraphSetOutputARM instructions with the same "
+              "OutputIndex must not be part of the same "
+           << "graph definition unless ElementIndex is present in both with "
+              "different values.";
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+
+// Validates correctness of graph instructions.
+spv_result_t GraphPass(ValidationState_t& _, const Instruction* inst) {
+  switch (inst->opcode()) {
+    case spv::Op::OpTypeGraphARM:
+      return ValidateGraphType(_, inst);
+    case spv::Op::OpGraphConstantARM:
+      return ValidateGraphConstant(_, inst);
+    case spv::Op::OpGraphEntryPointARM:
+      return ValidateGraphEntryPoint(_, inst);
+    case spv::Op::OpGraphARM:
+      return ValidateGraph(_, inst);
+    case spv::Op::OpGraphInputARM:
+      return ValidateGraphInput(_, inst);
+    case spv::Op::OpGraphSetOutputARM:
+      return ValidateGraphSetOutput(_, inst);
+    case spv::Op::OpGraphEndARM:
+      return ValidateGraphEnd(_, inst);
+    default:
+      break;
+  }
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools

+ 55 - 34
3rdparty/spirv-tools/source/val/validate_id.cpp

@@ -115,6 +115,57 @@ spv_result_t CheckIdDefinitionDominateUse(ValidationState_t& _) {
   return SPV_SUCCESS;
 }
 
+bool InstructionCanHaveTypeOperand(const Instruction* inst) {
+  static std::unordered_set<spv::Op> instruction_allow_set{
+      spv::Op::OpSizeOf,
+      spv::Op::OpCooperativeMatrixLengthNV,
+      spv::Op::OpCooperativeMatrixLengthKHR,
+      spv::Op::OpUntypedArrayLengthKHR,
+      spv::Op::OpFunction,
+      spv::Op::OpAsmINTEL,
+  };
+  const auto opcode = inst->opcode();
+  bool type_instruction = spvOpcodeGeneratesType(opcode);
+  bool debug_instruction = spvOpcodeIsDebug(opcode) || inst->IsDebugInfo();
+  bool coop_matrix_spec_constant_op_length =
+      (opcode == spv::Op::OpSpecConstantOp) &&
+      (spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthNV ||
+       spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthKHR);
+  return type_instruction || debug_instruction || inst->IsNonSemantic() ||
+         spvOpcodeIsDecoration(opcode) || instruction_allow_set.count(opcode) ||
+         spvOpcodeGeneratesUntypedPointer(opcode) ||
+         coop_matrix_spec_constant_op_length;
+}
+
+bool InstructionRequiresTypeOperand(const Instruction* inst) {
+  static std::unordered_set<spv::Op> instruction_deny_set{
+      spv::Op::OpExtInst,
+      spv::Op::OpExtInstWithForwardRefsKHR,
+      spv::Op::OpExtInstImport,
+      spv::Op::OpSelectionMerge,
+      spv::Op::OpLoopMerge,
+      spv::Op::OpFunction,
+      spv::Op::OpSizeOf,
+      spv::Op::OpCooperativeMatrixLengthNV,
+      spv::Op::OpCooperativeMatrixLengthKHR,
+      spv::Op::OpPhi,
+      spv::Op::OpUntypedArrayLengthKHR,
+      spv::Op::OpAsmINTEL,
+  };
+  const auto opcode = inst->opcode();
+  bool debug_instruction = spvOpcodeIsDebug(opcode) || inst->IsDebugInfo();
+  bool coop_matrix_spec_constant_op_length =
+      opcode == spv::Op::OpSpecConstantOp &&
+      (spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthNV ||
+       spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthKHR);
+
+  return !debug_instruction && !inst->IsNonSemantic() &&
+         !spvOpcodeIsDecoration(opcode) && !spvOpcodeIsBranch(opcode) &&
+         !instruction_deny_set.count(opcode) &&
+         !spvOpcodeGeneratesUntypedPointer(opcode) &&
+         !coop_matrix_spec_constant_op_length;
+}
+
 // Performs SSA validation on the IDs of an instruction. The
 // can_have_forward_declared_ids  functor should return true if the
 // instruction operand's ID can be forward referenced.
@@ -158,44 +209,14 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
       case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
       case SPV_OPERAND_TYPE_SCOPE_ID:
         if (const auto def = _.FindDef(operand_word)) {
-          const auto opcode = inst->opcode();
           if (spvOpcodeGeneratesType(def->opcode()) &&
-              !spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) &&
-              !inst->IsDebugInfo() && !inst->IsNonSemantic() &&
-              !spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
-              opcode != spv::Op::OpSizeOf &&
-              opcode != spv::Op::OpCooperativeMatrixLengthNV &&
-              opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
-              !spvOpcodeGeneratesUntypedPointer(opcode) &&
-              opcode != spv::Op::OpUntypedArrayLengthKHR &&
-              !(opcode == spv::Op::OpSpecConstantOp &&
-                (spv::Op(inst->word(3)) ==
-                     spv::Op::OpCooperativeMatrixLengthNV ||
-                 spv::Op(inst->word(3)) ==
-                     spv::Op::OpCooperativeMatrixLengthKHR))) {
+              !InstructionCanHaveTypeOperand(inst)) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " cannot be a type";
-          } else if (def->type_id() == 0 && !spvOpcodeGeneratesType(opcode) &&
-                     !spvOpcodeIsDebug(opcode) && !inst->IsDebugInfo() &&
-                     !inst->IsNonSemantic() && !spvOpcodeIsDecoration(opcode) &&
-                     !spvOpcodeIsBranch(opcode) && opcode != spv::Op::OpPhi &&
-                     opcode != spv::Op::OpExtInst &&
-                     opcode != spv::Op::OpExtInstWithForwardRefsKHR &&
-                     opcode != spv::Op::OpExtInstImport &&
-                     opcode != spv::Op::OpSelectionMerge &&
-                     opcode != spv::Op::OpLoopMerge &&
-                     opcode != spv::Op::OpFunction &&
-                     opcode != spv::Op::OpSizeOf &&
-                     opcode != spv::Op::OpCooperativeMatrixLengthNV &&
-                     opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
-                     !spvOpcodeGeneratesUntypedPointer(opcode) &&
-                     opcode != spv::Op::OpUntypedArrayLengthKHR &&
-                     !(opcode == spv::Op::OpSpecConstantOp &&
-                       (spv::Op(inst->word(3)) ==
-                            spv::Op::OpCooperativeMatrixLengthNV ||
-                        spv::Op(inst->word(3)) ==
-                            spv::Op::OpCooperativeMatrixLengthKHR))) {
+          } else if (def->type_id() == 0 &&
+                     !spvOpcodeGeneratesType(def->opcode()) &&
+                     InstructionRequiresTypeOperand(inst)) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " requires a type";

+ 3 - 1
3rdparty/spirv-tools/source/val/validate_image.cpp

@@ -464,7 +464,9 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << _.VkErrorID(10213)
                << "Image Operand Offset can only be used with "
-                  "OpImage*Gather operations";
+                  "OpImage*Gather operations."
+               << _.MissingFeature("maintenance8 feature",
+                                   "--allow-offset-texture-operand", false);
       }
     }
   }

+ 6 - 2
3rdparty/spirv-tools/source/val/validate_instruction.cpp

@@ -195,7 +195,8 @@ spv_result_t CheckRequiredCapabilities(ValidationState_t& state,
     // registers a capability with the module *before* checking capabilities.
     // So in the case of an OpCapability instruction, don't bother checking
     // enablement by another capability.
-    if (inst->opcode() != spv::Op::OpCapability) {
+    if (inst->opcode() != spv::Op::OpCapability &&
+        inst->opcode() != spv::Op::OpConditionalCapabilityINTEL) {
       const bool enabled_by_cap =
           state.HasAnyOfCapabilities(enabling_capabilities);
       if (!enabling_capabilities.empty() && !enabled_by_cap) {
@@ -461,10 +462,13 @@ spv_result_t CheckIfKnownExtension(ValidationState_t& _,
 
 spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
   const spv::Op opcode = inst->opcode();
-  if (opcode == spv::Op::OpExtension) {
+  if (opcode == spv::Op::OpExtension ||
+      opcode == spv::Op::OpConditionalExtensionINTEL) {
     CheckIfKnownExtension(_, inst);
   } else if (opcode == spv::Op::OpCapability) {
     _.RegisterCapability(inst->GetOperandAs<spv::Capability>(0));
+  } else if (opcode == spv::Op::OpConditionalCapabilityINTEL) {
+    _.RegisterCapability(inst->GetOperandAs<spv::Capability>(1));
   } else if (opcode == spv::Op::OpMemoryModel) {
     if (_.has_memory_model_specified()) {
       return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)

+ 84 - 72
3rdparty/spirv-tools/source/val/validate_interfaces.cpp

@@ -166,20 +166,17 @@ spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
       }
       break;
     case spv::Op::OpTypeMatrix:
-      // Matrices consume locations equivalent to arrays.
-      if (auto error = NumConsumedLocations(
-              _, _.FindDef(type->GetOperandAs<uint32_t>(1)), num_locations)) {
-        return error;
-      }
+      // Matrices consume locations equal to those of the underlying vector
+      // type, once per column.
+      NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
+                           num_locations);
       *num_locations *= type->GetOperandAs<uint32_t>(2);
       break;
     case spv::Op::OpTypeArray: {
       // Arrays consume locations equal to the underlying type times the number
       // of elements in the vector.
-      if (auto error = NumConsumedLocations(
-              _, _.FindDef(type->GetOperandAs<uint32_t>(1)), num_locations)) {
-        return error;
-      }
+      NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
+                           num_locations);
       bool is_int = false;
       bool is_const = false;
       uint32_t value = 0;
@@ -249,31 +246,10 @@ uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
           NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
       num_components *= type->GetOperandAs<uint32_t>(2);
       break;
-    case spv::Op::OpTypeMatrix:
-      // Matrices consume all components of the location.
-      // Round up to next multiple of 4.
-      num_components =
-          NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
-      num_components *= type->GetOperandAs<uint32_t>(2);
-      num_components = ((num_components + 3) / 4) * 4;
-      break;
-    case spv::Op::OpTypeArray: {
-      // Arrays consume all components of the location.
-      // Round up to next multiple of 4.
-      num_components =
-          NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
-
-      bool is_int = false;
-      bool is_const = false;
-      uint32_t value = 0;
-      // Attempt to evaluate the number of array elements.
-      std::tie(is_int, is_const, value) =
-          _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
-      if (is_int && is_const) num_components *= value;
-
-      num_components = ((num_components + 3) / 4) * 4;
-      return num_components;
-    }
+    case spv::Op::OpTypeArray:
+      // Skip the array.
+      return NumConsumedComponents(_,
+                                   _.FindDef(type->GetOperandAs<uint32_t>(1)));
     case spv::Op::OpTypePointer:
       if (_.addressing_model() ==
               spv::AddressingModel::PhysicalStorageBuffer64 &&
@@ -356,10 +332,9 @@ spv_result_t GetLocationsForVariable(
     }
   }
 
-  // Vulkan 15.1.3 (Interface Matching): Tessellation control and mesh
-  // per-vertex outputs and tessellation control, evaluation and geometry
-  // per-vertex inputs have a layer of arraying that is not included in
-  // interface matching.
+  // Vulkan 14.1.3: Tessellation control and mesh per-vertex outputs and
+  // tessellation control, evaluation and geometry per-vertex inputs have a
+  // layer of arraying that is not included in interface matching.
   bool is_arrayed = false;
   switch (entry_point->GetOperandAs<spv::ExecutionModel>(0)) {
     case spv::ExecutionModel::TessellationControl:
@@ -413,33 +388,51 @@ spv_result_t GetLocationsForVariable(
 
   const std::string storage_class = is_output ? "output" : "input";
   if (has_location) {
+    auto sub_type = type;
+    bool is_int = false;
+    bool is_const = false;
+    uint32_t array_size = 1;
+    // If the variable is still arrayed, mark the locations/components per
+    // index.
+    if (type->opcode() == spv::Op::OpTypeArray) {
+      // Determine the array size if possible and get the element type.
+      std::tie(is_int, is_const, array_size) =
+          _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
+      if (!is_int || !is_const) array_size = 1;
+      auto sub_type_id = type->GetOperandAs<uint32_t>(1);
+      sub_type = _.FindDef(sub_type_id);
+    }
+
     uint32_t num_locations = 0;
-    if (auto error = NumConsumedLocations(_, type, &num_locations))
+    if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
       return error;
-    uint32_t num_components = NumConsumedComponents(_, type);
+    uint32_t num_components = NumConsumedComponents(_, sub_type);
 
-    uint32_t start = location * 4;
-    uint32_t end = (location + num_locations) * 4;
-    if (num_components % 4 != 0) {
-      start += component;
-      end = start + num_components;
-    }
+    for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
+      uint32_t array_location = location + (num_locations * array_idx);
+      uint32_t start = array_location * 4;
+      if (kMaxLocations <= start) {
+        // Too many locations, give up.
+        break;
+      }
 
-    if (kMaxLocations <= start) {
-      // Too many locations, give up.
-      return SPV_SUCCESS;
-    }
+      uint32_t end = (array_location + num_locations) * 4;
+      if (num_components != 0) {
+        start += component;
+        end = array_location * 4 + component + num_components;
+      }
 
-    auto locs = locations;
-    if (has_index && index == 1) locs = output_index1_locations;
+      auto locs = locations;
+      if (has_index && index == 1) locs = output_index1_locations;
 
-    for (uint32_t i = start; i < end; ++i) {
-      if (!locs->insert(i).second) {
-        return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
-               << (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
-               << "Entry-point has conflicting " << storage_class
-               << " location assignment at location " << i / 4 << ", component "
-               << i % 4;
+      for (uint32_t i = start; i < end; ++i) {
+        if (!locs->insert(i).second) {
+          return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
+                 << (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
+                 << "Entry-point has conflicting " << storage_class
+                 << " location assignment at location " << i / 4
+                 << ", component " << i % 4;
+        }
       }
     }
   } else {
@@ -498,19 +491,38 @@ spv_result_t GetLocationsForVariable(
         continue;
       }
 
-      uint32_t end = (location + num_locations) * 4;
-      if (num_components % 4 != 0) {
-        start += component;
-        end = location * 4 + component + num_components;
-      }
-
-      for (uint32_t l = start; l < end; ++l) {
-        if (!locations->insert(l).second) {
-          return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
-                 << (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
-                 << "Entry-point has conflicting " << storage_class
-                 << " location assignment at location " << l / 4
-                 << ", component " << l % 4;
+      if (member->opcode() == spv::Op::OpTypeArray && num_components >= 1 &&
+          num_components < 4) {
+        // When an array has an element that takes less than a location in
+        // size, calculate the used locations in a strided manner.
+        for (uint32_t l = location; l < num_locations + location; ++l) {
+          for (uint32_t c = component; c < component + num_components; ++c) {
+            uint32_t check = 4 * l + c;
+            if (!locations->insert(check).second) {
+              return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
+                     << (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
+                     << "Entry-point has conflicting " << storage_class
+                     << " location assignment at location " << l
+                     << ", component " << c;
+            }
+          }
+        }
+      } else {
+        // TODO: There is a hole here if the member is an array of 3- or
+        // 4-element vectors of 64-bit types.
+        uint32_t end = (location + num_locations) * 4;
+        if (num_components != 0) {
+          start += component;
+          end = location * 4 + component + num_components;
+        }
+        for (uint32_t l = start; l < end; ++l) {
+          if (!locations->insert(l).second) {
+            return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
+                   << (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
+                   << "Entry-point has conflicting " << storage_class
+                   << " location assignment at location " << l / 4
+                   << ", component " << l % 4;
+          }
         }
       }
     }

+ 8 - 11
3rdparty/spirv-tools/source/val/validate_invalid_type.cpp

@@ -69,12 +69,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpGroupNonUniformFMul:
     case spv::Op::OpGroupNonUniformFMin: {
       const uint32_t result_type = inst->type_id();
-      if (_.IsBfloat16ScalarType(result_type) ||
-          _.IsBfloat16VectorType(result_type)) {
+      if (_.IsBfloat16Type(result_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
       }
-      if (_.IsFP8ScalarOrVectorType(result_type)) {
+      if (_.IsFP8Type(result_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode)
                << " doesn't support FP8 E4M3/E5M2 types.";
@@ -103,12 +102,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
     case spv::Op::OpIsNormal:
     case spv::Op::OpSignBitSet: {
       const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
-      if (_.IsBfloat16ScalarType(operand_type) ||
-          _.IsBfloat16VectorType(operand_type)) {
+      if (_.IsBfloat16Type(operand_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
       }
-      if (_.IsFP8ScalarOrVectorType(operand_type)) {
+      if (_.IsFP8Type(operand_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode)
                << " doesn't support FP8 E4M3/E5M2 types.";
@@ -118,12 +116,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
 
     case spv::Op::OpGroupNonUniformAllEqual: {
       const auto value_type = _.GetOperandTypeId(inst, 3);
-      if (_.IsBfloat16ScalarType(value_type) ||
-          _.IsBfloat16VectorType(value_type)) {
+      if (_.IsBfloat16Type(value_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
       }
-      if (_.IsFP8ScalarOrVectorType(value_type)) {
+      if (_.IsFP8Type(value_type)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << spvOpcodeString(opcode)
                << " doesn't support FP8 E4M3/E5M2 types.";
@@ -140,12 +137,12 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
       uint32_t res_component_type = 0;
       if (_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
                               &res_col_type, &res_component_type)) {
-        if (_.IsBfloat16ScalarType(res_component_type)) {
+        if (_.IsBfloat16Type(res_component_type)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << spvOpcodeString(opcode)
                  << " doesn't support BFloat16 type.";
         }
-        if (_.IsFP8ScalarOrVectorType(res_component_type)) {
+        if (_.IsFP8Type(res_component_type)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << spvOpcodeString(opcode)
                  << " doesn't support FP8 E4M3/E5M2 types.";

+ 79 - 3
3rdparty/spirv-tools/source/val/validate_layout.cpp

@@ -342,13 +342,84 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _,
         break;
     }
   } else {
-    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
-           << spvOpcodeString(opcode)
-           << " cannot appear in a function declaration";
+    _.ProgressToNextLayoutSectionOrder();
+    // All function sections have been processed. Recursively call
+    // ModuleLayoutPass to process the next section of the module
+    return ModuleLayoutPass(_, inst);
   }
   return SPV_SUCCESS;
 }
 
+spv_result_t GraphScopedInstructions(ValidationState_t& _,
+                                     const Instruction* inst, spv::Op opcode) {
+  if (_.IsOpcodeInCurrentLayoutSection(opcode)) {
+    switch (opcode) {
+      case spv::Op::OpGraphARM: {
+        if (_.graph_definition_region() > kGraphDefinitionOutside) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << "Cannot define a graph in a graph";
+        }
+        _.SetGraphDefinitionRegion(kGraphDefinitionBegin);
+      } break;
+      case spv::Op::OpGraphInputARM: {
+        if ((_.graph_definition_region() != kGraphDefinitionBegin) &&
+            (_.graph_definition_region() != kGraphDefinitionInputs)) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << "OpGraphInputARM"
+                 << " must immediately follow an OpGraphARM or OpGraphInputARM "
+                    "instruction.";
+        }
+        _.SetGraphDefinitionRegion(kGraphDefinitionInputs);
+      } break;
+      case spv::Op::OpGraphSetOutputARM: {
+        if ((_.graph_definition_region() != kGraphDefinitionBegin) &&
+            (_.graph_definition_region() != kGraphDefinitionInputs) &&
+            (_.graph_definition_region() != kGraphDefinitionBody) &&
+            (_.graph_definition_region() != kGraphDefinitionOutputs)) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << "Op" << spvOpcodeString(opcode)
+                 << " must immediately precede an OpGraphEndARM or "
+                    "OpGraphSetOutputARM instruction.";
+        }
+        _.SetGraphDefinitionRegion(kGraphDefinitionOutputs);
+      } break;
+      case spv::Op::OpGraphEndARM: {
+        if (_.graph_definition_region() != kGraphDefinitionOutputs) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << spvOpcodeString(opcode)
+                 << " must be preceded by at least one OpGraphSetOutputARM "
+                    "instruction";
+        }
+        _.SetGraphDefinitionRegion(kGraphDefinitionOutside);
+      } break;
+      case spv::Op::OpGraphEntryPointARM:
+        if (_.graph_definition_region() != kGraphDefinitionOutside) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << spvOpcodeString(opcode)
+                 << " cannot appear in the definition of a graph";
+        }
+        break;
+      default:
+        if (_.graph_definition_region() == kGraphDefinitionOutside) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << "Op" << spvOpcodeString(opcode)
+                 << " must appear in a graph body";
+        }
+        if (_.graph_definition_region() == kGraphDefinitionOutputs) {
+          return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+                 << spvOpcodeString(opcode)
+                 << " cannot appear after a graph output instruction";
+        }
+        _.SetGraphDefinitionRegion(kGraphDefinitionBody);
+        break;
+    }
+  } else {
+    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+           << "Op" << spvOpcodeString(opcode)
+           << " cannot appear in the graph definitions section";
+  }
+  return SPV_SUCCESS;
+}
 }  // namespace
 
 // TODO(umar): Check linkage capabilities for function declarations
@@ -379,6 +450,11 @@ spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst) {
         return error;
       }
       break;
+    case kLayoutGraphDefinitions:
+      if (auto error = GraphScopedInstructions(_, inst, opcode)) {
+        return error;
+      }
+      break;
   }
   return SPV_SUCCESS;
 }

+ 120 - 22
3rdparty/spirv-tools/source/val/validate_memory.cpp

@@ -196,10 +196,10 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
   return false;
 }
 
-std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
-    ValidationState_t& _, const Instruction* inst) {
-  spv::StorageClass dst_sc = spv::StorageClass::Max;
-  spv::StorageClass src_sc = spv::StorageClass::Max;
+std::pair<Instruction*, Instruction*> GetPointerTypes(ValidationState_t& _,
+                                                      const Instruction* inst) {
+  Instruction* dst_pointer_type = nullptr;
+  Instruction* src_pointer_type = nullptr;
   switch (inst->opcode()) {
     case spv::Op::OpCooperativeMatrixLoadNV:
     case spv::Op::OpCooperativeMatrixLoadTensorNV:
@@ -207,8 +207,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
     case spv::Op::OpCooperativeVectorLoadNV:
     case spv::Op::OpLoad: {
       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
-      auto load_pointer_type = _.FindDef(load_pointer->type_id());
-      dst_sc = load_pointer_type->GetOperandAs<spv::StorageClass>(1);
+      dst_pointer_type = _.FindDef(load_pointer->type_id());
       break;
     }
     case spv::Op::OpCooperativeMatrixStoreNV:
@@ -217,25 +216,23 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
     case spv::Op::OpCooperativeVectorStoreNV:
     case spv::Op::OpStore: {
       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
-      auto store_pointer_type = _.FindDef(store_pointer->type_id());
-      dst_sc = store_pointer_type->GetOperandAs<spv::StorageClass>(1);
+      dst_pointer_type = _.FindDef(store_pointer->type_id());
       break;
     }
+    // Spec: "Matching Storage Class is not required"
     case spv::Op::OpCopyMemory:
     case spv::Op::OpCopyMemorySized: {
-      auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
-      auto dst_type = _.FindDef(dst->type_id());
-      dst_sc = dst_type->GetOperandAs<spv::StorageClass>(1);
-      auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
-      auto src_type = _.FindDef(src->type_id());
-      src_sc = src_type->GetOperandAs<spv::StorageClass>(1);
+      auto dst_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
+      dst_pointer_type = _.FindDef(dst_pointer->type_id());
+      auto src_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(1));
+      src_pointer_type = _.FindDef(src_pointer->type_id());
       break;
     }
     default:
       break;
   }
 
-  return std::make_pair(dst_sc, src_sc);
+  return std::make_pair(dst_pointer_type, src_pointer_type);
 }
 
 // Returns the number of instruction words taken up by a memory access
@@ -288,8 +285,17 @@ bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
 
 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
                                uint32_t index) {
-  spv::StorageClass dst_sc, src_sc;
-  std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
+  Instruction* dst_pointer_type = nullptr;
+  Instruction* src_pointer_type = nullptr;  // only used for OpCopyMemory
+  std::tie(dst_pointer_type, src_pointer_type) = GetPointerTypes(_, inst);
+
+  const spv::StorageClass dst_sc =
+      dst_pointer_type ? dst_pointer_type->GetOperandAs<spv::StorageClass>(1)
+                       : spv::StorageClass::Max;
+  const spv::StorageClass src_sc =
+      src_pointer_type ? src_pointer_type->GetOperandAs<spv::StorageClass>(1)
+                       : spv::StorageClass::Max;
+
   if (inst->operands().size() <= index) {
     // Cases where lack of some operand is invalid
     if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
@@ -390,6 +396,23 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
              << "Memory accesses Aligned operand value " << aligned_value
              << " is not a power of two.";
     }
+
+    uint32_t largest_scalar = 0;
+    if (dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
+      largest_scalar =
+          _.GetLargestScalarType(dst_pointer_type->GetOperandAs<uint32_t>(2));
+    }
+    if (src_sc == spv::StorageClass::PhysicalStorageBuffer) {
+      largest_scalar = std::max(
+          largest_scalar,
+          _.GetLargestScalarType(src_pointer_type->GetOperandAs<uint32_t>(2)));
+    }
+    if (aligned_value < largest_scalar) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << _.VkErrorID(6314) << "Memory accesses Aligned operand value "
+             << aligned_value << " is too small, the largest scalar type is "
+             << largest_scalar << " bytes.";
+    }
   }
 
   return SPV_SUCCESS;
@@ -435,6 +458,7 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
       }
       if (spvIsVulkanEnv(_.context()->target_env)) {
         return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << _.VkErrorID(11167)
                << "Vulkan requires that data type be specified";
       }
     }
@@ -1555,6 +1579,60 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Base type must be a non-pointer type";
     }
+
+    const auto ContainsBlock = [&_](const Instruction* type_inst) {
+      if (type_inst->opcode() == spv::Op::OpTypeStruct) {
+        if (_.HasDecoration(type_inst->id(), spv::Decoration::Block) ||
+            _.HasDecoration(type_inst->id(), spv::Decoration::BufferBlock)) {
+          return true;
+        }
+      }
+      return false;
+    };
+
+    // Block (and BufferBlock) arrays cannot be reinterpreted via untyped access
+    // chains.
+    const bool base_type_block_array =
+        base_type->opcode() == spv::Op::OpTypeArray &&
+        _.ContainsType(base_type->id(), ContainsBlock,
+                       /* traverse_all_types = */ false);
+
+    const auto base_index = untyped_pointer ? 3 : 2;
+    const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
+    auto base = _.FindDef(base_id);
+    // Strictly speaking this misses trivial access chains and function
+    // parameter chasing, but that would be a significant complication in the
+    // traversal.
+    while (base->opcode() == spv::Op::OpCopyObject) {
+      base = _.FindDef(base->GetOperandAs<uint32_t>(2));
+    }
+    const Instruction* base_data_type = nullptr;
+    if (base->opcode() == spv::Op::OpVariable) {
+      const auto ptr_type = _.FindDef(base->type_id());
+      base_data_type = _.FindDef(ptr_type->GetOperandAs<uint32_t>(2));
+    } else if (base->opcode() == spv::Op::OpUntypedVariableKHR) {
+      if (base->operands().size() > 3) {
+        base_data_type = _.FindDef(base->GetOperandAs<uint32_t>(3));
+      }
+    }
+
+    if (base_data_type) {
+      const bool base_block_array =
+          base_data_type->opcode() == spv::Op::OpTypeArray &&
+          _.ContainsType(base_data_type->id(), ContainsBlock,
+                         /* traverse_all_types = */ false);
+
+      if (base_type_block_array != base_block_array) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "Both Base Type and Base must be Block or BufferBlock arrays "
+                  "or neither can be";
+      } else if (base_type_block_array && base_block_array &&
+                 base_type->id() != base_data_type->id()) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << "If Base or Base Type is a Block or BufferBlock array, the "
+                  "other must also be the same array";
+      }
+    }
   }
 
   // Base must be a pointer, pointing to the base of a composite object.
@@ -1845,14 +1923,34 @@ spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
 
   const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
 
-  const auto base_id = inst->GetOperandAs<uint32_t>(2);
-  const auto base = _.FindDef(base_id);
-  const auto base_type = untyped_pointer
-                             ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
-                             : _.FindDef(base->type_id());
+  const auto base_idx = untyped_pointer ? 3 : 2;
+  const auto base = _.FindDef(inst->GetOperandAs<uint32_t>(base_idx));
+  const auto base_type = _.FindDef(base->type_id());
   const auto base_type_storage_class =
       base_type->GetOperandAs<spv::StorageClass>(1);
 
+  const auto element_idx = untyped_pointer ? 4 : 3;
+  const auto element = _.FindDef(inst->GetOperandAs<uint32_t>(element_idx));
+  const auto element_type = _.FindDef(element->type_id());
+  if (!element_type || element_type->opcode() != spv::Op::OpTypeInt) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Element must be an integer";
+  }
+  uint64_t element_val = 0;
+  if (_.EvalConstantValUint64(element->id(), &element_val)) {
+    if (element_val != 0) {
+      const auto interp_type =
+          untyped_pointer ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
+                          : _.FindDef(base_type->GetOperandAs<uint32_t>(2));
+      if (interp_type->opcode() == spv::Op::OpTypeStruct &&
+          (_.HasDecoration(interp_type->id(), spv::Decoration::Block) ||
+           _.HasDecoration(interp_type->id(), spv::Decoration::BufferBlock))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Element must be 0 if the interpretation type is a Block- or "
+                  "BufferBlock-decorated structure";
+      }
+    }
+  }
+
   if (_.HasCapability(spv::Capability::Shader) &&
       (base_type_storage_class == spv::StorageClass::Uniform ||
        base_type_storage_class == spv::StorageClass::StorageBuffer ||

+ 175 - 147
3rdparty/spirv-tools/source/val/validate_memory_semantics.cpp

@@ -32,6 +32,9 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
   uint32_t value = 0;
   std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id);
 
+  const bool is_vulkan = spvIsVulkanEnv(_.context()->target_env) ||
+                         _.memory_model() == spv::MemoryModel::VulkanKHR;
+
   if (!is_int32) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << spvOpcodeString(opcode)
@@ -56,6 +59,21 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
     return SPV_SUCCESS;
   }
 
+  if (value & uint32_t(spv::MemorySemanticsMask::UniformMemory) &&
+      !_.HasCapability(spv::Capability::Shader)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(opcode)
+           << ": Memory Semantics UniformMemory requires capability Shader";
+  }
+
+  if (value & uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR) &&
+      !_.HasCapability(spv::Capability::VulkanMemoryModel)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << spvOpcodeString(opcode)
+           << ": Memory Semantics OutputMemoryKHR requires capability "
+           << "VulkanMemoryModelKHR";
+  }
+
   const size_t num_memory_order_set_bits = spvtools::utils::CountSetBits(
       value & uint32_t(spv::MemorySemanticsMask::Acquire |
                        spv::MemorySemanticsMask::Release |
@@ -64,197 +82,207 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
 
   if (num_memory_order_set_bits > 1) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": Memory Semantics can have at most one of the following "
-              "bits set: Acquire, Release, AcquireRelease or "
-              "SequentiallyConsistent";
+           << _.VkErrorID(10865) << spvOpcodeString(opcode)
+           << ": Memory Semantics must have at most one non-relaxed "
+              "memory order bit set";
   }
 
-  if (_.memory_model() == spv::MemoryModel::VulkanKHR &&
-      value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent)) {
+  if (is_vulkan &&
+      (value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent))) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "SequentiallyConsistent memory "
-              "semantics cannot be used with "
-              "the VulkanKHR memory model.";
+           << _.VkErrorID(10866) << spvOpcodeString(opcode)
+           << ": Memory Semantics with SequentiallyConsistent memory order "
+              "must not be used in the Vulkan API";
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR) &&
-      !_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
+  if ((opcode == spv::Op::OpAtomicStore ||
+       opcode == spv::Op::OpAtomicFlagClear) &&
+      (value & uint32_t(spv::MemorySemanticsMask::Acquire) ||
+       value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": Memory Semantics MakeAvailableKHR requires capability "
-           << "VulkanMemoryModelKHR";
+           << _.VkErrorID(10867) << spvOpcodeString(opcode)
+           << ": MemorySemantics must not use Acquire or AcquireRelease "
+              "memory order with "
+           << spvOpcodeString(opcode);
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR) &&
-      !_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
+  if (opcode == spv::Op::OpAtomicLoad &&
+      (value & uint32_t(spv::MemorySemanticsMask::Release) ||
+       value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": Memory Semantics MakeVisibleKHR requires capability "
-           << "VulkanMemoryModelKHR";
+           << _.VkErrorID(10868) << spvOpcodeString(opcode)
+           << ": MemorySemantics must not use Release or AcquireRelease "
+              "memory order with "
+           << spvOpcodeString(opcode);
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR) &&
-      !_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
+  // In OpenCL, a relaxed fence has no effect but is not explicitly forbidden
+  if (is_vulkan && opcode == spv::Op::OpMemoryBarrier &&
+      !num_memory_order_set_bits) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": Memory Semantics OutputMemoryKHR requires capability "
-           << "VulkanMemoryModelKHR";
+           << _.VkErrorID(10869) << spvOpcodeString(opcode)
+           << ": MemorySemantics must not use Relaxed memory order with "
+           << spvOpcodeString(opcode);
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::Volatile)) {
-    if (!_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
+  if (is_vulkan) {
+    const bool includes_storage_class =
+        value & uint32_t(spv::MemorySemanticsMask::UniformMemory |
+                         spv::MemorySemanticsMask::WorkgroupMemory |
+                         spv::MemorySemanticsMask::ImageMemory |
+                         spv::MemorySemanticsMask::OutputMemoryKHR);
+
+    if (num_memory_order_set_bits && !includes_storage_class) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << spvOpcodeString(opcode)
-             << ": Memory Semantics Volatile requires capability "
-                "VulkanMemoryModelKHR";
+             << _.VkErrorID(10870) << spvOpcodeString(opcode)
+             << ": Memory Semantics with a non-relaxed memory order (Acquire, "
+                "Release, or AcquireRelease) must have at least one "
+                "Vulkan-supported storage class semantics bit set "
+                "(UniformMemory, WorkgroupMemory, ImageMemory, or "
+                "OutputMemory)";
     }
 
-    if (!spvOpcodeIsAtomicOp(inst->opcode())) {
+    if (!num_memory_order_set_bits && includes_storage_class) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Memory Semantics Volatile can only be used with atomic "
-                "instructions";
+             << _.VkErrorID(10871) << spvOpcodeString(opcode)
+             << ": Memory Semantics with at least one Vulkan-supported "
+                "storage class semantics bit set (UniformMemory, "
+                "WorkgroupMemory, ImageMemory, or OutputMemory) must use "
+                "a non-relaxed memory order (Acquire, Release, or "
+                "AcquireRelease)";
     }
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::UniformMemory) &&
-      !_.HasCapability(spv::Capability::Shader)) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": Memory Semantics UniformMemory requires capability Shader";
-  }
-
-  // Checking for spv::Capability::AtomicStorage is intentionally not done here.
-  // See https://github.com/KhronosGroup/glslang/issues/1618 for the reasoning
-  // why.
-
-  if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR |
-                       spv::MemorySemanticsMask::MakeVisibleKHR)) {
-    const bool includes_storage_class =
-        value & uint32_t(spv::MemorySemanticsMask::UniformMemory |
-                         spv::MemorySemanticsMask::SubgroupMemory |
-                         spv::MemorySemanticsMask::WorkgroupMemory |
-                         spv::MemorySemanticsMask::CrossWorkgroupMemory |
-                         spv::MemorySemanticsMask::AtomicCounterMemory |
-                         spv::MemorySemanticsMask::ImageMemory |
-                         spv::MemorySemanticsMask::OutputMemoryKHR);
-
-    if (!includes_storage_class) {
+  if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR)) {
+    if (!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << spvOpcodeString(opcode)
-             << ": expected Memory Semantics to include a storage class";
+             << ": Memory Semantics MakeAvailableKHR requires capability "
+             << "VulkanMemoryModelKHR";
+    }
+    if (!(value & uint32_t(spv::MemorySemanticsMask::Release |
+                           spv::MemorySemanticsMask::AcquireRelease))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << _.VkErrorID(10872) << spvOpcodeString(opcode)
+             << ": Memory Semantics with MakeAvailable bit set must use "
+                "Release or AcquireRelease memory order";
     }
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR) &&
-      !(value & uint32_t(spv::MemorySemanticsMask::Acquire |
-                         spv::MemorySemanticsMask::AcquireRelease))) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": MakeVisibleKHR Memory Semantics also requires either Acquire "
-              "or AcquireRelease Memory Semantics";
+  if (value & uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR)) {
+    if (!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << spvOpcodeString(opcode)
+             << ": Memory Semantics MakeVisibleKHR requires capability "
+             << "VulkanMemoryModelKHR";
+    }
+    if (!(value & uint32_t(spv::MemorySemanticsMask::Acquire |
+                           spv::MemorySemanticsMask::AcquireRelease))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << _.VkErrorID(10873) << spvOpcodeString(opcode)
+             << ": Memory Semantics with MakeVisible bit set must use Acquire "
+                "or AcquireRelease memory order";
+    }
   }
 
-  if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR) &&
-      !(value & uint32_t(spv::MemorySemanticsMask::Release |
-                         spv::MemorySemanticsMask::AcquireRelease))) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": MakeAvailableKHR Memory Semantics also requires either "
-              "Release or AcquireRelease Memory Semantics";
+  if (value & uint32_t(spv::MemorySemanticsMask::Volatile)) {
+    if (!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << spvOpcodeString(opcode)
+             << ": Memory Semantics Volatile requires capability "
+                "VulkanMemoryModelKHR";
+    }
+    if (!spvOpcodeIsAtomicOp(inst->opcode())) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << _.VkErrorID(10874) << spvOpcodeString(opcode)
+             << ": Memory Semantics with Volatile bit set must not be used "
+                "with barrier instructions";
+    }
   }
 
-  if (spvIsVulkanEnv(_.context()->target_env)) {
-    const bool includes_storage_class =
-        value & uint32_t(spv::MemorySemanticsMask::UniformMemory |
-                         spv::MemorySemanticsMask::WorkgroupMemory |
-                         spv::MemorySemanticsMask::ImageMemory |
-                         spv::MemorySemanticsMask::OutputMemoryKHR);
-
-    if (opcode == spv::Op::OpMemoryBarrier && !num_memory_order_set_bits) {
+  if ((opcode == spv::Op::OpAtomicCompareExchange ||
+       opcode == spv::Op::OpAtomicCompareExchangeWeak) &&
+      operand_index == 5) {
+    if (value & uint32_t(spv::MemorySemanticsMask::Release) ||
+        value & uint32_t(spv::MemorySemanticsMask::AcquireRelease)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << _.VkErrorID(4732) << spvOpcodeString(opcode)
-             << ": Vulkan specification requires Memory Semantics to have "
-                "one of the following bits set: Acquire, Release, "
-                "AcquireRelease or SequentiallyConsistent";
-    } else if (opcode != spv::Op::OpMemoryBarrier &&
-               num_memory_order_set_bits) {
-      // should leave only atomics and control barriers for Vulkan env
-      bool memory_is_int32 = false, memory_is_const_int32 = false;
-      uint32_t memory_value = 0;
-      std::tie(memory_is_int32, memory_is_const_int32, memory_value) =
-          _.EvalInt32IfConst(memory_scope);
-      if (memory_is_int32 &&
-          spv::Scope(memory_value) == spv::Scope::Invocation) {
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << _.VkErrorID(4641) << spvOpcodeString(opcode)
-               << ": Vulkan specification requires Memory Semantics to be None "
-                  "if used with Invocation Memory Scope";
-      }
+             << _.VkErrorID(10875) << spvOpcodeString(opcode)
+             << " Unequal Memory Semantics must not use Release or "
+                "AcquireRelease memory order";
     }
 
-    if (opcode == spv::Op::OpMemoryBarrier && !includes_storage_class) {
+    bool is_equal_int32 = false;
+    bool is_equal_const = false;
+    uint32_t equal_value = 0;
+    std::tie(is_equal_int32, is_equal_const, equal_value) =
+        _.EvalInt32IfConst(inst->GetOperandAs<uint32_t>(4));
+
+    const auto equal_mask_seq_cst =
+        uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent);
+    const auto equal_mask_acquire = uint32_t(
+        // Allow EqualMemorySemantics Release with UnequalMemorySemantics
+        // Acquire, since the C standard doesn't clearly forbid it.
+        spv::MemorySemanticsMask::SequentiallyConsistent |
+        spv::MemorySemanticsMask::AcquireRelease |
+        spv::MemorySemanticsMask::Release | spv::MemorySemanticsMask::Acquire);
+
+    if (((value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent)) &&
+         !(equal_value & equal_mask_seq_cst)) ||
+        ((value & uint32_t(spv::MemorySemanticsMask::Acquire)) &&
+         !(equal_value & equal_mask_acquire))) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << _.VkErrorID(4733) << spvOpcodeString(opcode)
-             << ": expected Memory Semantics to include a Vulkan-supported "
-                "storage class";
+             << _.VkErrorID(10876) << spvOpcodeString(opcode)
+             << " Unequal Memory Semantics must not use a stronger memory "
+                "order than the corresponding Equal Memory Semantics";
     }
 
-    if (opcode == spv::Op::OpControlBarrier && value) {
-      if (!num_memory_order_set_bits) {
+    if (is_vulkan) {
+      auto storage_class_semantics_mask =
+          uint32_t(spv::MemorySemanticsMask::UniformMemory |
+                   spv::MemorySemanticsMask::WorkgroupMemory |
+                   spv::MemorySemanticsMask::ImageMemory |
+                   spv::MemorySemanticsMask::OutputMemoryKHR);
+
+      if (value & ~equal_value & storage_class_semantics_mask) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << _.VkErrorID(10609) << spvOpcodeString(opcode)
-               << ": Vulkan specification requires non-zero Memory Semantics "
-                  "to have one of the following bits set: Acquire, Release, "
-                  "AcquireRelease or SequentiallyConsistent";
+               << _.VkErrorID(10877) << spvOpcodeString(opcode)
+               << " Unequal Memory Semantics must not have any "
+                  "Vulkan-supported storage class semantics bit set "
+                  "(UniformMemory, WorkgroupMemory, ImageMemory, or "
+                  "OutputMemory) unless this bit is also set in the "
+                  "corresponding Equal Memory Semantics";
       }
-      if (!includes_storage_class) {
+
+      if (value & ~equal_value &
+          uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << _.VkErrorID(4650) << spvOpcodeString(opcode)
-               << ": expected Memory Semantics to include a Vulkan-supported "
-                  "storage class if Memory Semantics is not None";
+               << _.VkErrorID(10878) << spvOpcodeString(opcode)
+               << " Unequal Memory Semantics must not have MakeVisible bit set "
+                  "unless this bit is also set in the corresponding Equal "
+                  "Memory Semantics";
       }
-    }
-  }
-
-  if (opcode == spv::Op::OpAtomicFlagClear &&
-      (value & uint32_t(spv::MemorySemanticsMask::Acquire) ||
-       value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Memory Semantics Acquire and AcquireRelease cannot be used "
-              "with "
-           << spvOpcodeString(opcode);
-  }
 
-  if (opcode == spv::Op::OpAtomicCompareExchange && operand_index == 5 &&
-      (value & uint32_t(spv::MemorySemanticsMask::Release) ||
-       value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << spvOpcodeString(opcode)
-           << ": Memory Semantics Release and AcquireRelease cannot be "
-              "used "
-              "for operand Unequal";
-  }
-
-  if (spvIsVulkanEnv(_.context()->target_env)) {
-    if (opcode == spv::Op::OpAtomicLoad &&
-        (value & uint32_t(spv::MemorySemanticsMask::Release) ||
-         value & uint32_t(spv::MemorySemanticsMask::AcquireRelease) ||
-         value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent))) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << _.VkErrorID(4731)
-             << "Vulkan spec disallows OpAtomicLoad with Memory Semantics "
-                "Release, AcquireRelease and SequentiallyConsistent";
+      if ((equal_value & uint32_t(spv::MemorySemanticsMask::Volatile)) ^
+          (value & uint32_t(spv::MemorySemanticsMask::Volatile))) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << _.VkErrorID(10879) << spvOpcodeString(opcode)
+               << " Unequal Memory Semantics must have Volatile bit set if and "
+                  "only if this bit is also set in the corresponding Equal "
+                  "Memory Semantics";
+      }
     }
+  }
 
-    if (opcode == spv::Op::OpAtomicStore &&
-        (value & uint32_t(spv::MemorySemanticsMask::Acquire) ||
-         value & uint32_t(spv::MemorySemanticsMask::AcquireRelease) ||
-         value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent))) {
+  if (is_vulkan && num_memory_order_set_bits) {
+    bool memory_is_int32 = false, memory_is_const_int32 = false;
+    uint32_t memory_value = 0;
+    std::tie(memory_is_int32, memory_is_const_int32, memory_value) =
+        _.EvalInt32IfConst(memory_scope);
+    if (memory_is_int32 && spv::Scope(memory_value) == spv::Scope::Invocation) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << _.VkErrorID(4730)
-             << "Vulkan spec disallows OpAtomicStore with Memory Semantics "
-                "Acquire, AcquireRelease and SequentiallyConsistent";
+             << _.VkErrorID(4641) << spvOpcodeString(opcode)
+             << ": Vulkan specification requires Memory Semantics to be "
+                "Relaxed if used with Invocation Memory Scope";
     }
   }
 

+ 123 - 72
3rdparty/spirv-tools/source/val/validate_mode_setting.cpp

@@ -59,20 +59,22 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
   }
 
   const auto* execution_modes = _.GetExecutionModes(entry_point_id);
+  auto has_mode = [&execution_modes](spv::ExecutionMode mode) {
+    return execution_modes && execution_modes->count(mode);
+  };
+
   if (_.HasCapability(spv::Capability::Shader)) {
     switch (execution_model) {
       case spv::ExecutionModel::Fragment:
-        if (execution_modes &&
-            execution_modes->count(spv::ExecutionMode::OriginUpperLeft) &&
-            execution_modes->count(spv::ExecutionMode::OriginLowerLeft)) {
+        if (has_mode(spv::ExecutionMode::OriginUpperLeft) &&
+            has_mode(spv::ExecutionMode::OriginLowerLeft)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Fragment execution model entry points can only specify "
                     "one of OriginUpperLeft or OriginLowerLeft execution "
                     "modes.";
         }
-        if (!execution_modes ||
-            (!execution_modes->count(spv::ExecutionMode::OriginUpperLeft) &&
-             !execution_modes->count(spv::ExecutionMode::OriginLowerLeft))) {
+        if (!has_mode(spv::ExecutionMode::OriginUpperLeft) &&
+            !has_mode(spv::ExecutionMode::OriginLowerLeft)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Fragment execution model entry points require either an "
                     "OriginUpperLeft or OriginLowerLeft execution mode.";
@@ -285,36 +287,31 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
     }
   }
 
+  bool has_workgroup_size = false;
+  bool has_local_size_id = false;
+  for (auto& i : _.ordered_instructions()) {
+    if (i.opcode() == spv::Op::OpFunction) break;
+    if (i.opcode() == spv::Op::OpDecorate && i.operands().size() > 2) {
+      if (i.GetOperandAs<spv::Decoration>(1) == spv::Decoration::BuiltIn &&
+          i.GetOperandAs<spv::BuiltIn>(2) == spv::BuiltIn::WorkgroupSize) {
+        has_workgroup_size = true;
+      }
+    }
+    if (i.opcode() == spv::Op::OpExecutionModeId) {
+      if (i.GetOperandAs<spv::ExecutionMode>(1) ==
+          spv::ExecutionMode::LocalSizeId) {
+        has_local_size_id = true;
+      }
+    }
+  }
+
   if (spvIsVulkanEnv(_.context()->target_env)) {
     switch (execution_model) {
       case spv::ExecutionModel::GLCompute:
-        if (!execution_modes ||
-            !execution_modes->count(spv::ExecutionMode::LocalSize)) {
-          bool ok = false;
-          for (auto& i : _.ordered_instructions()) {
-            if (i.opcode() == spv::Op::OpDecorate) {
-              if (i.operands().size() > 2) {
-                if (i.GetOperandAs<spv::Decoration>(1) ==
-                        spv::Decoration::BuiltIn &&
-                    i.GetOperandAs<spv::BuiltIn>(2) ==
-                        spv::BuiltIn::WorkgroupSize) {
-                  ok = true;
-                  break;
-                }
-              }
-            }
-            if (i.opcode() == spv::Op::OpExecutionModeId) {
-              const auto mode = i.GetOperandAs<spv::ExecutionMode>(1);
-              if (mode == spv::ExecutionMode::LocalSizeId) {
-                ok = true;
-                break;
-              }
-            }
-          }
+        if (!has_mode(spv::ExecutionMode::LocalSize)) {
+          bool ok = has_workgroup_size || has_local_size_id;
           if (!ok && _.HasCapability(spv::Capability::TileShadingQCOM)) {
-            ok =
-                execution_modes &&
-                execution_modes->count(spv::ExecutionMode::TileShadingRateQCOM);
+            ok = has_mode(spv::ExecutionMode::TileShadingRateQCOM);
           }
           if (!ok) {
             return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -332,25 +329,20 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
         }
 
         if (_.HasCapability(spv::Capability::TileShadingQCOM)) {
-          if (execution_modes) {
-            if (execution_modes->count(
-                    spv::ExecutionMode::TileShadingRateQCOM) &&
-                (execution_modes->count(spv::ExecutionMode::LocalSize) ||
-                 execution_modes->count(spv::ExecutionMode::LocalSizeId))) {
-              return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                     << "If the TileShadingRateQCOM execution mode is used, "
-                     << "LocalSize and LocalSizeId must not be specified.";
-            }
-            if (execution_modes->count(
-                    spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
-              return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                     << "The NonCoherentTileAttachmentQCOM execution mode must "
-                        "not be used in any stage other than fragment.";
-            }
+          if (has_mode(spv::ExecutionMode::TileShadingRateQCOM) &&
+              (has_mode(spv::ExecutionMode::LocalSize) ||
+               has_mode(spv::ExecutionMode::LocalSizeId))) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "If the TileShadingRateQCOM execution mode is used, "
+                   << "LocalSize and LocalSizeId must not be specified.";
+          }
+          if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "The NonCoherentTileAttachmentQCOM execution mode must "
+                      "not be used in any stage other than fragment.";
           }
         } else {
-          if (execution_modes &&
-              execution_modes->count(spv::ExecutionMode::TileShadingRateQCOM)) {
+          if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) {
             return _.diag(SPV_ERROR_INVALID_DATA, inst)
                    << "If the TileShadingRateQCOM execution mode is used, the "
                       "TileShadingQCOM capability must be enabled.";
@@ -358,16 +350,13 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
         }
         break;
       default:
-        if (execution_modes &&
-            execution_modes->count(spv::ExecutionMode::TileShadingRateQCOM)) {
+        if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "The TileShadingRateQCOM execution mode must not be used "
                     "in any stage other than compute.";
         }
         if (execution_model != spv::ExecutionModel::Fragment) {
-          if (execution_modes &&
-              execution_modes->count(
-                  spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
+          if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
             return _.diag(SPV_ERROR_INVALID_DATA, inst)
                    << "The NonCoherentTileAttachmentQCOM execution mode must "
                       "not be used in any stage other than fragment.";
@@ -378,9 +367,7 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
                       "any stage other than compute or fragment.";
           }
         } else {
-          if (execution_modes &&
-              execution_modes->count(
-                  spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
+          if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
             if (!_.HasCapability(spv::Capability::TileShadingQCOM)) {
               return _.diag(SPV_ERROR_INVALID_DATA, inst)
                      << "If the NonCoherentTileAttachmentReadQCOM execution "
@@ -393,7 +380,9 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
     }
   }
 
-  if (_.EntryPointHasLocalSizeOrId(entry_point_id)) {
+  // WorkgroupSize decoration takes precedence over any LocalSize or LocalSizeId
+  // execution mode, so the values can be ignored
+  if (_.EntryPointHasLocalSizeOrId(entry_point_id) && !has_workgroup_size) {
     const Instruction* local_size_inst =
         _.EntryPointLocalSizeOrId(entry_point_id);
     if (local_size_inst) {
@@ -402,7 +391,8 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
       const uint32_t operand_y = local_size_inst->GetOperandAs<uint32_t>(3);
       const uint32_t operand_z = local_size_inst->GetOperandAs<uint32_t>(4);
       if (mode == spv::ExecutionMode::LocalSize) {
-        if ((operand_x * operand_y * operand_z) == 0) {
+        const uint64_t product_size = operand_x * operand_y * operand_z;
+        if (product_size == 0) {
           return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
                  << "Local Size execution mode must not have a product of zero "
                     "(X "
@@ -410,6 +400,32 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
                  << operand_x << ", Y = " << operand_y << ", Z = " << operand_z
                  << ").";
         }
+        if (has_mode(spv::ExecutionMode::DerivativeGroupQuadsKHR)) {
+          if (operand_x % 2 != 0 || operand_y % 2 != 0) {
+            return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
+                   << _.VkErrorID(10151)
+                   << "Local Size execution mode dimensions is "
+                      "(X = "
+                   << operand_x << ", Y = " << operand_y
+                   << ") but Entry Point id " << entry_point_id
+                   << " also has an DerivativeGroupQuadsKHR execution mode, so "
+                      "both dimensions must be a multiple of 2";
+          }
+        }
+        if (has_mode(spv::ExecutionMode::DerivativeGroupLinearKHR)) {
+          if (product_size % 4 != 0) {
+            return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
+                   << _.VkErrorID(10152)
+                   << "Local Size execution mode dimensions is (X = "
+                   << operand_x << ", Y = " << operand_y
+                   << ", Z = " << operand_z << ") but Entry Point id "
+                   << entry_point_id
+                   << " also has an DerivativeGroupLinearKHR execution mode, "
+                      "so "
+                      "the product ("
+                   << product_size << ") must be a multiple of 4";
+          }
+        }
       } else if (mode == spv::ExecutionMode::LocalSizeId) {
         // can only validate product if static and not spec constant
         // (This is done for us in EvalConstantValUint64)
@@ -417,13 +433,42 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
         bool static_x = _.EvalConstantValUint64(operand_x, &x_size);
         bool static_y = _.EvalConstantValUint64(operand_y, &y_size);
         bool static_z = _.EvalConstantValUint64(operand_z, &z_size);
-        if (static_x && static_y && static_z &&
-            ((x_size * y_size * z_size) == 0)) {
-          return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
-                 << "Local Size Id execution mode must not have a product of "
-                    "zero "
-                    "(X = "
-                 << x_size << ", Y = " << y_size << ", Z = " << z_size << ").";
+        if (static_x && static_y && static_z) {
+          const uint64_t product_size = x_size * y_size * z_size;
+          if (product_size == 0) {
+            return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
+                   << "LocalSizeId execution mode must not have a product of "
+                      "zero "
+                      "(X = "
+                   << x_size << ", Y = " << y_size << ", Z = " << z_size
+                   << ").";
+          }
+          if (has_mode(spv::ExecutionMode::DerivativeGroupQuadsKHR)) {
+            if (x_size % 2 != 0 || y_size % 2 != 0) {
+              return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
+                     << _.VkErrorID(10151)
+                     << "LocalSizeId execution mode dimensions is "
+                        "(X = "
+                     << x_size << ", Y = " << y_size << ") but Entry Point id "
+                     << entry_point_id
+                     << " also has an DerivativeGroupQuadsKHR execution mode, "
+                        "so "
+                        "both dimensions must be a multiple of 2";
+            }
+          }
+          if (has_mode(spv::ExecutionMode::DerivativeGroupLinearKHR)) {
+            if (product_size % 4 != 0) {
+              return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
+                     << _.VkErrorID(10152)
+                     << "LocalSizeId execution mode dimensions is (X = "
+                     << x_size << ", Y = " << y_size << ", Z = " << z_size
+                     << ") but Entry Point id " << entry_point_id
+                     << " also has an DerivativeGroupLinearKHR execution mode, "
+                        "so "
+                        "the product ("
+                     << product_size << ") must be a multiple of 4";
+            }
+          }
         }
       }
     }
@@ -557,6 +602,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
               "Operands that are not id operands.";
   }
 
+  const bool is_vulkan_env = (spvIsVulkanEnv(_.context()->target_env));
   const auto* models = _.GetExecutionModels(entry_point_id);
   switch (mode) {
     case spv::ExecutionMode::Invocations:
@@ -667,7 +713,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
                     "tessellation execution model.";
         }
       }
-      if (spvIsVulkanEnv(_.context()->target_env)) {
+      if (is_vulkan_env) {
         if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
             inst->GetOperandAs<uint32_t>(2) == 0) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -690,8 +736,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
                   "execution "
                   "model.";
       }
-      if (mode == spv::ExecutionMode::OutputPrimitivesEXT &&
-          spvIsVulkanEnv(_.context()->target_env)) {
+      if (mode == spv::ExecutionMode::OutputPrimitivesEXT && is_vulkan_env) {
         if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
             inst->GetOperandAs<uint32_t>(2) == 0) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -761,9 +806,15 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
       break;
     case spv::ExecutionMode::LocalSize:
     case spv::ExecutionMode::LocalSizeId:
-      if (mode == spv::ExecutionMode::LocalSizeId && !_.IsLocalSizeIdAllowed())
+      if (mode == spv::ExecutionMode::LocalSizeId &&
+          !_.IsLocalSizeIdAllowed()) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "LocalSizeId mode is not allowed by the current environment.";
+               << "LocalSizeId mode is not allowed by the current environment."
+               << (is_vulkan_env
+                       ? _.MissingFeature("maintenance4 feature",
+                                          "--allow-localsizeid", false)
+                       : "");
+      }
 
       if (!std::all_of(
               models->begin(), models->end(),
@@ -812,7 +863,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
     }
   }
 
-  if (spvIsVulkanEnv(_.context()->target_env)) {
+  if (is_vulkan_env) {
     if (mode == spv::ExecutionMode::OriginLowerLeft) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << _.VkErrorID(4653)

+ 1 - 1
3rdparty/spirv-tools/source/val/validate_non_uniform.cpp

@@ -130,7 +130,7 @@ spv_result_t ValidateGroupNonUniformBroadcastShuffle(ValidationState_t& _,
     if (!spvOpcodeIsConstant(id_op)) {
       std::string operand = GetOperandName(inst->opcode());
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Before SPIR-V 1.5, " << operand
+             << "In SPIR-V 1.4 or earlier, " << operand
              << " must be a constant instruction";
     }
   }

+ 1 - 1
3rdparty/spirv-tools/source/val/validate_scopes.cpp

@@ -94,7 +94,7 @@ spv_result_t ValidateExecutionScope(ValidationState_t& _,
 
   // Vulkan specific rules
   if (spvIsVulkanEnv(_.context()->target_env)) {
-    // Vulkan 1.1 specific rules
+    // Subgroups were not added until 1.1
     if (_.context()->target_env != SPV_ENV_VULKAN_1_0) {
       // Scope for Non Uniform Group Operations must be limited to Subgroup
       if ((spvOpcodeIsNonUniformGroupOperation(opcode) &&

+ 2 - 4
3rdparty/spirv-tools/source/val/validate_tensor.cpp

@@ -83,8 +83,7 @@ spv_result_t ValidateTensorRead(ValidationState_t& _, const Instruction* inst) {
   auto op_coord = inst->word(4);
   auto inst_coord = _.FindDef(op_coord);
   auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
-  if (tensor_rank == 0 ||
-      !_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
+  if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Coordinates to be an array whose Element Type is an "
               "integer type and whose Length is equal to the Rank of Tensor.";
@@ -143,8 +142,7 @@ spv_result_t ValidateTensorWrite(ValidationState_t& _,
   auto op_coord = inst->word(2);
   auto inst_coord = _.FindDef(op_coord);
   auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
-  if (tensor_rank == 0 ||
-      !_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
+  if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Coordinates to be an array whose Element Type is an "
               "integer type and whose Length is equal to the Rank of Tensor.";

+ 8 - 5
3rdparty/spirv-tools/source/val/validate_type.cpp

@@ -140,7 +140,7 @@ spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "8-bit floating point type requires an encoding.";
     }
-    const spvtools::OperandDesc* desc;
+    const spvtools::OperandDesc* desc = nullptr;
     const std::set<spv::FPEncoding> known_encodings{
         spv::FPEncoding::Float8E4M3EXT, spv::FPEncoding::Float8E5M2EXT};
     spv_result_t status = spvtools::LookupOperand(SPV_OPERAND_TYPE_FPENCODING,
@@ -433,10 +433,9 @@ spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
              << "Structure <id> " << _.getIdName(member_type_id)
              << " contains members with BuiltIn decoration. Therefore this "
              << "structure may not be contained as a member of another "
-             << "structure "
-             << "type. Structure <id> " << _.getIdName(struct_id)
-             << " contains structure <id> " << _.getIdName(member_type_id)
-             << ".";
+             << "structure " << "type. Structure <id> "
+             << _.getIdName(struct_id) << " contains structure <id> "
+             << _.getIdName(member_type_id) << ".";
     }
 
     if (spvIsVulkanEnv(_.context()->target_env) &&
@@ -562,6 +561,9 @@ spv_result_t ValidateTypePointer(ValidationState_t& _,
       // a storage image.
       if (sampled == 2) _.RegisterPointerToStorageImage(inst->id());
     }
+    if (type->opcode() == spv::Op::OpTypeTensorARM) {
+      _.RegisterPointerToTensor(inst->id());
+    }
   }
 
   if (!_.IsValidStorageClass(storage_class)) {
@@ -614,6 +616,7 @@ spv_result_t ValidateTypeFunction(ValidationState_t& _,
   for (auto& pair : inst->uses()) {
     const auto* use = pair.first;
     if (use->opcode() != spv::Op::OpFunction &&
+        use->opcode() != spv::Op::OpAsmINTEL &&
         !spvOpcodeIsDebug(use->opcode()) && !use->IsNonSemantic() &&
         !spvOpcodeIsDecoration(use->opcode())) {
       return _.diag(SPV_ERROR_INVALID_ID, use)

+ 216 - 58
3rdparty/spirv-tools/source/val/validation_state.cpp

@@ -42,14 +42,17 @@ ModuleLayoutSection InstructionLayoutSection(
 
   switch (op) {
     case spv::Op::OpCapability:
+    case spv::Op::OpConditionalCapabilityINTEL:
       return kLayoutCapabilities;
     case spv::Op::OpExtension:
+    case spv::Op::OpConditionalExtensionINTEL:
       return kLayoutExtensions;
     case spv::Op::OpExtInstImport:
       return kLayoutExtInstImport;
     case spv::Op::OpMemoryModel:
       return kLayoutMemoryModel;
     case spv::Op::OpEntryPoint:
+    case spv::Op::OpConditionalEntryPointINTEL:
       return kLayoutEntryPoint;
     case spv::Op::OpExecutionMode:
     case spv::Op::OpExecutionModeId:
@@ -85,6 +88,9 @@ ModuleLayoutSection InstructionLayoutSection(
       // spv::Op::OpExtInst is only allowed in types section for certain
       // extended instruction sets. This will be checked separately.
       if (current_section == kLayoutTypes) return kLayoutTypes;
+      // SpvOpExtInst is allowed in graph definitions.
+      if (current_section == kLayoutGraphDefinitions)
+        return kLayoutGraphDefinitions;
       return kLayoutFunctionDefinitions;
     case spv::Op::OpLine:
     case spv::Op::OpNoLine:
@@ -99,6 +105,16 @@ ModuleLayoutSection InstructionLayoutSection(
       return kLayoutFunctionDefinitions;
     case spv::Op::OpSamplerImageAddressingModeNV:
       return kLayoutSamplerImageAddressMode;
+    case spv::Op::OpGraphEntryPointARM:
+    case spv::Op::OpGraphARM:
+    case spv::Op::OpGraphInputARM:
+    case spv::Op::OpGraphSetOutputARM:
+    case spv::Op::OpGraphEndARM:
+      return kLayoutGraphDefinitions;
+    case spv::Op::OpCompositeExtract:
+      if (current_section == kLayoutGraphDefinitions)
+        return kLayoutGraphDefinitions;
+      return kLayoutFunctionDefinitions;
     default:
       break;
   }
@@ -174,6 +190,7 @@ ValidationState_t::ValidationState_t(const spv_const_context ctx,
       pointer_size_and_alignment_(0),
       sampler_image_addressing_mode_(0),
       in_function_(false),
+      graph_definition_region_(kGraphDefinitionOutside),
       num_of_warnings_(0),
       max_num_of_warnings_(max_warnings) {
   assert(opt && "Validator options may not be Null.");
@@ -362,6 +379,10 @@ bool ValidationState_t::in_block() const {
          module_functions_.back().current_block() != nullptr;
 }
 
+GraphDefinitionRegion ValidationState_t::graph_definition_region() const {
+  return graph_definition_region_;
+}
+
 void ValidationState_t::RegisterCapability(spv::Capability cap) {
   // Avoid redundant work.  Otherwise the recursion could induce work
   // quadrdatic in the capability dependency depth. (Ok, not much, but
@@ -532,6 +553,13 @@ spv_result_t ValidationState_t::RegisterFunctionEnd() {
   return SPV_SUCCESS;
 }
 
+void ValidationState_t::SetGraphDefinitionRegion(GraphDefinitionRegion region) {
+  assert((region == kGraphDefinitionOutside &&
+          graph_definition_region_ == kGraphDefinitionOutputs) ||
+         region >= graph_definition_region_);
+  graph_definition_region_ = region;
+}
+
 Instruction* ValidationState_t::AddOrderedInstruction(
     const spv_parsed_instruction_t* inst) {
   ordered_instructions_.emplace_back(inst);
@@ -875,9 +903,12 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
     case spv::Op::OpTypeFloat:
     case spv::Op::OpTypeInt:
     case spv::Op::OpTypeBool:
+    case spv::Op::OpTypePointer:
+    case spv::Op::OpTypeUntypedPointerKHR:
       return id;
 
     case spv::Op::OpTypeArray:
+    case spv::Op::OpTypeRuntimeArray:
       return inst->word(2);
 
     case spv::Op::OpTypeVector:
@@ -939,11 +970,20 @@ uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
   const Instruction* inst = FindDef(component_type_id);
   assert(inst);
 
-  if (inst->opcode() == spv::Op::OpTypeFloat ||
-      inst->opcode() == spv::Op::OpTypeInt)
-    return inst->word(2);
-
-  if (inst->opcode() == spv::Op::OpTypeBool) return 1;
+  switch (inst->opcode()) {
+    case spv::Op::OpTypeFloat:
+    case spv::Op::OpTypeInt:
+      return inst->word(2);
+    case spv::Op::OpTypeBool:
+      return 1;
+    case spv::Op::OpTypePointer:
+    case spv::Op::OpTypeUntypedPointerKHR:
+      assert(inst->GetOperandAs<spv::StorageClass>(1) ==
+             spv::StorageClass::PhysicalStorageBuffer);
+      return 64;  // all pointers to another PSB are 64-bit
+    default:
+      break;
+  }
 
   assert(0);
   return 0;
@@ -958,6 +998,23 @@ bool ValidationState_t::IsScalarType(uint32_t id) const {
   return IsIntScalarType(id) || IsFloatScalarType(id) || IsBoolScalarType(id);
 }
 
+bool ValidationState_t::IsArrayType(uint32_t id, uint64_t length) const {
+  const Instruction* inst = FindDef(id);
+  if (!inst || inst->opcode() != spv::Op::OpTypeArray) {
+    return false;
+  }
+  if (length != 0) {
+    const auto len_id = inst->GetOperandAs<uint32_t>(2);
+    const auto len = FindDef(len_id);
+    uint64_t len_value = 0;
+    if (!len || !spvOpcodeIsConstant(len->opcode()) ||
+        (EvalConstantValUint64(len_id, &len_value) && (length != len_value))) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool ValidationState_t::IsBfloat16ScalarType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
@@ -984,6 +1041,24 @@ bool ValidationState_t::IsBfloat16VectorType(uint32_t id) const {
   return false;
 }
 
+bool ValidationState_t::IsBfloat16CoopMatType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  if (!inst) {
+    return false;
+  }
+
+  if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+    return IsBfloat16ScalarType(inst->word(2));
+  }
+
+  return false;
+}
+
+bool ValidationState_t::IsBfloat16Type(uint32_t id) const {
+  return IsBfloat16ScalarType(id) || IsBfloat16VectorType(id) ||
+         IsBfloat16CoopMatType(id);
+}
+
 bool ValidationState_t::IsFP8ScalarType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
@@ -1011,28 +1086,32 @@ bool ValidationState_t::IsFP8VectorType(uint32_t id) const {
   return false;
 }
 
-bool ValidationState_t::IsFP8ScalarOrVectorType(uint32_t id) const {
-  return IsFP8ScalarType(id) || IsFP8VectorType(id);
-}
-
-bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
-  const Instruction* inst = FindDef(id);
-  return inst && inst->opcode() == spv::Op::OpTypeFloat;
-}
-
-bool ValidationState_t::IsFloatArrayType(uint32_t id) const {
+bool ValidationState_t::IsFP8CoopMatType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   if (!inst) {
     return false;
   }
 
-  if (inst->opcode() == spv::Op::OpTypeArray) {
-    return IsFloatScalarType(GetComponentType(id));
+  if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+    return IsFP8ScalarType(inst->word(2));
   }
 
   return false;
 }
 
+bool ValidationState_t::IsFP8Type(uint32_t id) const {
+  return IsFP8ScalarType(id) || IsFP8VectorType(id) || IsFP8CoopMatType(id);
+}
+
+bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  return inst && inst->opcode() == spv::Op::OpTypeFloat;
+}
+
+bool ValidationState_t::IsFloatArrayType(uint32_t id) const {
+  return IsArrayType(id) && IsFloatScalarType(GetComponentType(id));
+}
+
 bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   if (!inst) {
@@ -1077,36 +1156,27 @@ bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
   return false;
 }
 
-bool ValidationState_t::IsIntScalarType(uint32_t id) const {
-  const Instruction* inst = FindDef(id);
-  return inst && inst->opcode() == spv::Op::OpTypeInt;
-}
-
-bool ValidationState_t::IsIntArrayType(uint32_t id, uint64_t length) const {
+bool ValidationState_t::IsIntScalarType(uint32_t id, uint32_t width) const {
   const Instruction* inst = FindDef(id);
-  if (!inst) {
+  bool is_int = inst && inst->opcode() == spv::Op::OpTypeInt;
+  if (!is_int) {
     return false;
   }
-
-  if (inst->opcode() != spv::Op::OpTypeArray) {
-    return false;
-  }
-
-  if (!IsIntScalarType(GetComponentType(id))) {
+  if ((width != 0) && (width != inst->word(2))) {
     return false;
   }
+  return true;
+}
 
-  if (length != 0) {
-    const auto len_id = inst->GetOperandAs<uint32_t>(2);
-    const auto len = FindDef(len_id);
-    uint64_t len_value = 0;
-    if (!len || !spvOpcodeIsConstant(len->opcode()) ||
-        (EvalConstantValUint64(len_id, &len_value) && (length != len_value))) {
-      return false;
-    }
-  }
+bool ValidationState_t::IsIntScalarTypeWithSignedness(
+    uint32_t id, uint32_t signedness) const {
+  const Instruction* inst = FindDef(id);
+  return inst && inst->opcode() == spv::Op::OpTypeInt &&
+         inst->word(3) == signedness;
+}
 
-  return true;
+bool ValidationState_t::IsIntArrayType(uint32_t id, uint64_t length) const {
+  return IsArrayType(id, length) && IsIntScalarType(GetComponentType(id));
 }
 
 bool ValidationState_t::IsIntVectorType(uint32_t id) const {
@@ -1140,8 +1210,7 @@ bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
 }
 
 bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
-  const Instruction* inst = FindDef(id);
-  return inst && inst->opcode() == spv::Op::OpTypeInt && inst->word(3) == 0;
+  return IsIntScalarTypeWithSignedness(id, 0);
 }
 
 bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
@@ -1312,6 +1381,28 @@ bool ValidationState_t::GetPointerTypeInfo(
   return true;
 }
 
+uint32_t ValidationState_t::GetLargestScalarType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+
+  switch (inst->opcode()) {
+    case spv::Op::OpTypeStruct: {
+      uint32_t size = 0;
+      for (uint32_t i = 1; i < inst->operands().size(); ++i) {
+        const uint32_t member_size =
+            GetLargestScalarType(inst->GetOperandAs<uint32_t>(i));
+        size = std::max(size, member_size);
+      }
+      return size;
+    }
+    case spv::Op::OpTypeArray:
+      return GetLargestScalarType(inst->GetOperandAs<uint32_t>(1));
+    case spv::Op::OpTypeVector:
+      return GetLargestScalarType(inst->GetOperandAs<uint32_t>(1));
+    default:
+      return GetBitWidth(id) / 8;
+  }
+}
+
 bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   return inst && inst->opcode() == spv::Op::OpTypeAccelerationStructureKHR;
@@ -1411,6 +1502,11 @@ bool ValidationState_t::IsUnsignedIntCooperativeVectorNVType(
   return IsUnsignedIntScalarType(FindDef(id)->word(2));
 }
 
+bool ValidationState_t::IsTensorType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  return inst && inst->opcode() == spv::Op::OpTypeTensorARM;
+}
+
 spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
     const Instruction* inst, uint32_t result_type_id, uint32_t m2,
     bool is_conversion, bool swap_row_col) {
@@ -1445,8 +1541,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
 
   if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
     return diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Expected scopes of Matrix and Result Type to be "
-           << "identical";
+           << "Expected scopes of Matrix and Result Type to be " << "identical";
   }
 
   std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
@@ -1949,6 +2044,14 @@ bool ValidationState_t::IsValidStorageClass(
   return true;
 }
 
+std::string ValidationState_t::MissingFeature(const std::string& feature,
+                                              const std::string& cmdline,
+                                              bool hint) const {
+  return "\nThis is " + (hint ? std::string("may be ") : "") +
+         "allowed if you enable the " + feature + " (or use the " + cmdline +
+         " command line flag)";
+}
+
 #define VUID_WRAP(vuid) "[" #vuid "] "
 
 // Currently no 2 VUID share the same id, so no need for |reference|
@@ -2211,6 +2314,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-Position-Position-04321);
     case 4330:
       return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04330);
+    case 4333:
+      return VUID_WRAP(VUID-PrimitiveId-Fragment-04333);
     case 4334:
       return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04334);
     case 4336:
@@ -2399,10 +2504,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-None-04644);
     case 4645:
       return VUID_WRAP(VUID-StandaloneSpirv-None-04645);
-    case 10609:
-      return VUID_WRAP(VUID-StandaloneSpirv-OpControlBarrier-10609);
-    case 4650:
-      return VUID_WRAP(VUID-StandaloneSpirv-OpControlBarrier-04650);
     case 4651:
       return VUID_WRAP(VUID-StandaloneSpirv-OpVariable-04651);
     case 4652:
@@ -2469,14 +2570,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-PhysicalStorageBuffer64-04710);
     case 4711:
       return VUID_WRAP(VUID-StandaloneSpirv-OpTypeForwardPointer-04711);
-    case 4730:
-      return VUID_WRAP(VUID-StandaloneSpirv-OpAtomicStore-04730);
-    case 4731:
-      return VUID_WRAP(VUID-StandaloneSpirv-OpAtomicLoad-04731);
-    case 4732:
-      return VUID_WRAP(VUID-StandaloneSpirv-OpMemoryBarrier-04732);
-    case 4733:
-      return VUID_WRAP(VUID-StandaloneSpirv-OpMemoryBarrier-04733);
     case 4734:
       return VUID_WRAP(VUID-StandaloneSpirv-OpVariable-04734);
     case 4744:
@@ -2485,8 +2578,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-OpImage-04777);
     case 4780:
       return VUID_WRAP(VUID-StandaloneSpirv-Result-04780);
-    case 4781:
-      return VUID_WRAP(VUID-StandaloneSpirv-Base-04781);
     case 4915:
       return VUID_WRAP(VUID-StandaloneSpirv-Location-04915);
     case 4916:
@@ -2511,6 +2602,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-Flat-06202);
     case 6214:
       return VUID_WRAP(VUID-StandaloneSpirv-OpTypeImage-06214);
+    case 6314:
+      return VUID_WRAP(VUID-StandaloneSpirv-PhysicalStorageBuffer64-06314);
     case 6491:
       return VUID_WRAP(VUID-StandaloneSpirv-DescriptorSet-06491);
     case 6671:
@@ -2621,6 +2714,10 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-OpEntryPoint-09658);
     case 9659:
       return VUID_WRAP(VUID-StandaloneSpirv-OpEntryPoint-09659);
+    case 10151:
+      return VUID_WRAP(VUID-StandaloneSpirv-DerivativeGroupQuadsKHR-10151);
+    case 10152:
+      return VUID_WRAP(VUID-StandaloneSpirv-DerivativeGroupLinearKHR-10152);
     case 10213:
      // This used to be a standalone, but maintenance8 will set allow_offset_texture_operand now
       return VUID_WRAP(VUID-RuntimeSpirv-Offset-10213);
@@ -2628,10 +2725,71 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-OpTypeFloat-10370);
     case 10583:
       return VUID_WRAP(VUID-StandaloneSpirv-Component-10583);
+    case 10589:
+      return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-10589);
+    case 10590:
+      return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-10590);
+    case 10591:
+      return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-10591);
+    case 10592:
+      return VUID_WRAP(VUID-Layer-Layer-10592);
+    case 10593:
+      return VUID_WRAP(VUID-Layer-Layer-10593);
+    case 10594:
+      return VUID_WRAP(VUID-Layer-Layer-10594);
+    case 10598:
+      return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-10598);
+    case 10599:
+      return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-10599);
+    case 10600:
+      return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-10600);
+    case 10601:
+      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-10601);
+    case 10602:
+      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-10602);
+    case 10603:
+      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-10603);
     case 10684:
       return VUID_WRAP(VUID-StandaloneSpirv-None-10684);
     case 10685:
       return VUID_WRAP(VUID-StandaloneSpirv-None-10685);
+    case 10824:
+      // This used to be a standalone, but maintenance9 will set allow_vulkan_32_bit_bitwise now
+      return VUID_WRAP(VUID-RuntimeSpirv-None-10824);
+    case 10865:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10865);
+    case 10866:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10866);
+    case 10867:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10867);
+    case 10868:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10868);
+    case 10869:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10869);
+    case 10870:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10870);
+    case 10871:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10871);
+    case 10872:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10872);
+    case 10873:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10873);
+    case 10874:
+      return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10874);
+    case 10875:
+      return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10875);
+    case 10876:
+      return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10876);
+    case 10877:
+      return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10877);
+    case 10878:
+      return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10878);
+    case 10879:
+      return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10879);
+    case 10880:
+      return VUID_WRAP(VUID-StandaloneSpirv-TessLevelInner-10880);
+    case 11167:
+      return VUID_WRAP(VUID-StandaloneSpirv-OpUntypedVariableKHR-11167);
     default:
       return "";  // unknown id
   }

+ 104 - 6
3rdparty/spirv-tools/source/val/validation_state.h

@@ -50,6 +50,7 @@ enum ModuleLayoutSection {
   kLayoutExtInstImport,            /// < Section 2.4 #3
   kLayoutMemoryModel,              /// < Section 2.4 #4
   kLayoutSamplerImageAddressMode,  /// < Section 2.4 #5
+                                   /// (SPV_NV_bindless_texture)
   kLayoutEntryPoint,               /// < Section 2.4 #6
   kLayoutExecutionMode,            /// < Section 2.4 #7
   kLayoutDebug1,                   /// < Section 2.4 #8 > 1
@@ -58,7 +59,18 @@ enum ModuleLayoutSection {
   kLayoutAnnotations,              /// < Section 2.4 #9
   kLayoutTypes,                    /// < Section 2.4 #10
   kLayoutFunctionDeclarations,     /// < Section 2.4 #11
-  kLayoutFunctionDefinitions       /// < Section 2.4 #12
+  kLayoutFunctionDefinitions,      /// < Section 2.4 #12
+  kLayoutGraphDefinitions          /// < Section 2.4 #13 (SPV_ARM_graph)
+};
+
+/// This enum represents the regions of a graph definition. The relative
+/// ordering of the values is significant.
+enum GraphDefinitionRegion {
+  kGraphDefinitionOutside,
+  kGraphDefinitionBegin,
+  kGraphDefinitionInputs,
+  kGraphDefinitionBody,
+  kGraphDefinitionOutputs,
 };
 
 /// This class manages the state of the SPIR-V validation as it is being parsed.
@@ -213,6 +225,9 @@ class ValidationState_t {
   /// instruction
   bool in_block() const;
 
+  /// Returns the region of a graph definition we are in.
+  GraphDefinitionRegion graph_definition_region() const;
+
   struct EntryPointDescription {
     std::string name;
     std::vector<uint32_t> interfaces;
@@ -313,6 +328,16 @@ class ValidationState_t {
   /// ComputeFunctionToEntryPointMapping.
   void ComputeRecursiveEntryPoints();
 
+  /// Registers |id| as a graph entry point.
+  void RegisterGraphEntryPoint(const uint32_t id) {
+    graph_entry_points_.push_back(id);
+  }
+
+  /// Returns a list of graph entry point graph ids
+  const std::vector<uint32_t>& graph_entry_points() const {
+    return graph_entry_points_;
+  }
+
   /// Returns all the entry points that can call |func|.
   const std::vector<uint32_t>& FunctionEntryPoints(uint32_t func) const;
 
@@ -350,6 +375,9 @@ class ValidationState_t {
   /// Register a function end instruction
   spv_result_t RegisterFunctionEnd();
 
+  /// Sets the region of a graph definition we're in.
+  void SetGraphDefinitionRegion(GraphDefinitionRegion region);
+
   /// Returns true if the capability is enabled in the module.
   bool HasCapability(spv::Capability cap) const {
     return module_capabilities_.contains(cap);
@@ -632,23 +660,26 @@ class ValidationState_t {
   bool GetStructMemberTypes(uint32_t struct_type_id,
                             std::vector<uint32_t>* member_types) const;
 
-  // Returns true iff |id| is a type corresponding to the name of the function.
+  // Returns true if |id| is a type corresponding to the name of the function.
   // Only works for types not for objects.
   bool IsVoidType(uint32_t id) const;
   bool IsScalarType(uint32_t id) const;
   bool IsBfloat16ScalarType(uint32_t id) const;
   bool IsBfloat16VectorType(uint32_t id) const;
+  bool IsBfloat16CoopMatType(uint32_t id) const;
+  bool IsBfloat16Type(uint32_t id) const;
   bool IsFP8ScalarType(uint32_t id) const;
   bool IsFP8VectorType(uint32_t id) const;
-  bool IsFP8ScalarOrVectorType(uint32_t id) const;
+  bool IsFP8CoopMatType(uint32_t id) const;
+  bool IsFP8Type(uint32_t id) const;
   bool IsFloatScalarType(uint32_t id) const;
   bool IsFloatArrayType(uint32_t id) const;
   bool IsFloatVectorType(uint32_t id) const;
   bool IsFloat16Vector2Or4Type(uint32_t id) const;
   bool IsFloatScalarOrVectorType(uint32_t id) const;
   bool IsFloatMatrixType(uint32_t id) const;
-  bool IsIntScalarType(uint32_t id) const;
-  bool IsIntArrayType(uint32_t id, uint64_t length = 0) const;
+  bool IsIntScalarType(uint32_t id, uint32_t width = 0) const;
+  bool IsIntScalarTypeWithSignedness(uint32_t id, uint32_t signedness) const;
   bool IsIntVectorType(uint32_t id) const;
   bool IsIntScalarOrVectorType(uint32_t id) const;
   bool IsUnsignedIntScalarType(uint32_t id) const;
@@ -675,6 +706,36 @@ class ValidationState_t {
   bool IsFloatCooperativeVectorNVType(uint32_t id) const;
   bool IsIntCooperativeVectorNVType(uint32_t id) const;
   bool IsUnsignedIntCooperativeVectorNVType(uint32_t id) const;
+  bool IsTensorType(uint32_t id) const;
+  // When |length| is not 0, return true only if the array length is equal to
+  // |length| and the array length is not defined by a specialization constant.
+  bool IsArrayType(uint32_t id, uint64_t length = 0) const;
+  bool IsIntArrayType(uint32_t id, uint64_t length = 0) const;
+  template <unsigned int N>
+  bool IsIntNOrFP32OrFP16(unsigned int type_id) {
+    return this->ContainsType(
+        type_id,
+        [](const Instruction* inst) {
+          if (inst->opcode() == spv::Op::OpTypeInt) {
+            return inst->GetOperandAs<uint32_t>(1) == N;
+          } else if (inst->opcode() == spv::Op::OpTypeFloat) {
+            if (inst->operands().size() > 2) {
+              // Not IEEE
+              return false;
+            }
+            auto width = inst->GetOperandAs<uint32_t>(1);
+            return width == 32 || width == 16;
+          }
+          return false;
+        },
+        /* traverse_all_types = */ false);
+  }
+
+  // Will walk the type to find the largest scalar value size.
+  // The returned value is in bytes.
+  // This is designed to pass in the %type from a PSB pointer
+  //   %ptr = OpTypePointer PhysicalStorageBuffer %type
+  uint32_t GetLargestScalarType(uint32_t id) const;
 
   // Returns true if |id| is a type id that contains |type| (or integer or
   // floating point type) of |width| bits.
@@ -715,6 +776,17 @@ class ValidationState_t {
   bool GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
                           spv::StorageClass* storage_class) const;
 
+  // Returns the value associated with id via 'value' if id is an OpConstant
+  template <typename T>
+  bool GetConstantValueAs(unsigned int id, T& value) {
+    const auto inst = FindDef(id);
+    uint64_t ui64_val = 0u;
+    bool status = (inst && spvOpcodeIsConstant(inst->opcode()) &&
+                   EvalConstantValUint64(id, &ui64_val));
+    if (status == true) value = static_cast<T>(ui64_val);
+    return status;
+  }
+
   // Is the ID the type of a pointer to a uniform block: Block-decorated struct
   // in uniform storage class? The result is only valid after internal method
   // CheckDecorationsOfBuffers has been called.
@@ -772,6 +844,16 @@ class ValidationState_t {
     pointer_to_storage_image_.insert(type_id);
   }
 
+  // Is the ID the type of a pointer to a tensor?  That is, the pointee
+  // type is a tensor type.
+  bool IsPointerToTensor(uint32_t type_id) const {
+    return pointer_to_tensor_.find(type_id) != pointer_to_tensor_.cend();
+  }
+  // Save the ID of a pointer to a tensor.
+  void RegisterPointerToTensor(uint32_t type_id) {
+    pointer_to_tensor_.insert(type_id);
+  }
+
   // Tries to evaluate a any scalar integer OpConstant as uint64.
   // OpConstantNull is defined as zero for scalar int (will return true)
   // OpSpecConstant* return false since their values cannot be relied upon
@@ -844,6 +926,12 @@ class ValidationState_t {
   // Validates the storage class for the target environment.
   bool IsValidStorageClass(spv::StorageClass storage_class) const;
 
+  // Helps formulate a message to the user that setting one of the validator
+  // options might make their SPIR-V actually valid. The |hint| option exists
+  // because some checks are intertwined, making it hard to give confirmation.
+  std::string MissingFeature(const std::string& feature,
+                             const std::string& cmdline, bool hint) const;
+
   // Takes a Vulkan Valid Usage ID (VUID) as |id| and optional |reference| and
   // will return a non-empty string only if ID is known and targeting Vulkan.
   // VUIDs are found in the Vulkan-Docs repo in the form "[[VUID-ref-ref-id]]"
@@ -939,6 +1027,9 @@ class ValidationState_t {
   /// graph that recurses.
   std::set<uint32_t> recursive_entry_points_;
 
+  /// IDs that are graph entry points, ie, arguments to OpGraphEntryPointARM.
+  std::vector<uint32_t> graph_entry_points_;
+
   /// Functions IDs that are target of OpFunctionCall.
   std::unordered_set<uint32_t> function_call_targets_;
 
@@ -981,9 +1072,13 @@ class ValidationState_t {
   /// bit width of sampler/image type variables. Valid values are 32 and 64
   uint32_t sampler_image_addressing_mode_;
 
-  /// NOTE: See correspoding getter functions
+  /// NOTE: See corresponding getter functions
   bool in_function_;
 
+  /// Where in a graph definition we are
+  /// NOTE: See corresponding getter/setter functions
+  GraphDefinitionRegion graph_definition_region_;
+
   /// The state of optional features.  These are determined by capabilities
   /// declared by the module and the environment.
   Feature features_;
@@ -1030,6 +1125,9 @@ class ValidationState_t {
   // The IDs of types of pointers to storage images.  This is populated in the
   // TypePass.
   std::unordered_set<uint32_t> pointer_to_storage_image_;
+  // The IDs of types of pointers to tensors.  This is populated in the
+  // TypePass.
+  std::unordered_set<uint32_t> pointer_to_tensor_;
 
   /// Maps ids to friendly names.
   std::unique_ptr<spvtools::FriendlyNameMapper> friendly_mapper_;

Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно