Browse Source

Updated spirv-tools.

Бранимир Караџић 2 years ago
parent
commit
aaf0fdf7cf
52 changed files with 1633 additions and 1096 deletions
  1. 1 1
      3rdparty/spirv-tools/include/generated/build-version.inc
  2. 6 0
      3rdparty/spirv-tools/include/generated/core.insts-unified1.inc
  3. 2 0
      3rdparty/spirv-tools/include/generated/enum_string_mapping.inc
  4. 1 0
      3rdparty/spirv-tools/include/generated/extension_enum.inc
  5. 4 2
      3rdparty/spirv-tools/include/generated/generators.inc
  6. 33 3
      3rdparty/spirv-tools/include/generated/operand.kinds-unified1.inc
  7. 5 12
      3rdparty/spirv-tools/include/spirv-tools/instrument.hpp
  8. 7 0
      3rdparty/spirv-tools/include/spirv-tools/libspirv.h
  9. 20 10
      3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp
  10. 3 2
      3rdparty/spirv-tools/source/assembly_grammar.cpp
  11. 5 1
      3rdparty/spirv-tools/source/binary.cpp
  12. 134 43
      3rdparty/spirv-tools/source/diff/diff.cpp
  13. 2 0
      3rdparty/spirv-tools/source/opcode.cpp
  14. 11 0
      3rdparty/spirv-tools/source/operand.cpp
  15. 1 0
      3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp
  16. 118 31
      3rdparty/spirv-tools/source/opt/const_folding_rules.cpp
  17. 27 0
      3rdparty/spirv-tools/source/opt/constants.cpp
  18. 6 0
      3rdparty/spirv-tools/source/opt/constants.h
  19. 58 15
      3rdparty/spirv-tools/source/opt/fold.cpp
  20. 8 0
      3rdparty/spirv-tools/source/opt/fold.h
  21. 6 2
      3rdparty/spirv-tools/source/opt/folding_rules.cpp
  22. 238 463
      3rdparty/spirv-tools/source/opt/inst_bindless_check_pass.cpp
  23. 11 96
      3rdparty/spirv-tools/source/opt/inst_bindless_check_pass.h
  24. 16 48
      3rdparty/spirv-tools/source/opt/inst_buff_addr_check_pass.cpp
  25. 0 4
      3rdparty/spirv-tools/source/opt/inst_buff_addr_check_pass.h
  26. 4 2
      3rdparty/spirv-tools/source/opt/inst_debug_printf_pass.cpp
  27. 25 3
      3rdparty/spirv-tools/source/opt/instruction.cpp
  28. 10 0
      3rdparty/spirv-tools/source/opt/instruction.h
  29. 92 115
      3rdparty/spirv-tools/source/opt/instrument_pass.cpp
  30. 5 28
      3rdparty/spirv-tools/source/opt/instrument_pass.h
  31. 16 0
      3rdparty/spirv-tools/source/opt/ir_builder.h
  32. 2 1
      3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp
  33. 56 53
      3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp
  34. 53 50
      3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp
  35. 37 40
      3rdparty/spirv-tools/source/opt/optimizer.cpp
  36. 31 0
      3rdparty/spirv-tools/source/opt/type_manager.cpp
  37. 5 6
      3rdparty/spirv-tools/source/opt/type_manager.h
  38. 42 0
      3rdparty/spirv-tools/source/opt/types.cpp
  39. 35 0
      3rdparty/spirv-tools/source/opt/types.h
  40. 2 1
      3rdparty/spirv-tools/source/text.cpp
  41. 143 10
      3rdparty/spirv-tools/source/val/validate_arithmetics.cpp
  42. 20 0
      3rdparty/spirv-tools/source/val/validate_composites.cpp
  43. 2 0
      3rdparty/spirv-tools/source/val/validate_constants.cpp
  44. 16 2
      3rdparty/spirv-tools/source/val/validate_conversion.cpp
  45. 58 23
      3rdparty/spirv-tools/source/val/validate_decorations.cpp
  46. 10 4
      3rdparty/spirv-tools/source/val/validate_id.cpp
  47. 0 1
      3rdparty/spirv-tools/source/val/validate_image.cpp
  48. 25 6
      3rdparty/spirv-tools/source/val/validate_interfaces.cpp
  49. 128 4
      3rdparty/spirv-tools/source/val/validate_memory.cpp
  50. 20 7
      3rdparty/spirv-tools/source/val/validate_type.cpp
  51. 68 7
      3rdparty/spirv-tools/source/val/validation_state.cpp
  52. 5 0
      3rdparty/spirv-tools/source/val/validation_state.h

+ 1 - 1
3rdparty/spirv-tools/include/generated/build-version.inc

@@ -1 +1 @@
-"v2023.3", "SPIRV-Tools v2023.3 v2022.4-199-gcd9fa015"
+"v2023.3", "SPIRV-Tools v2023.3 v2022.4-246-g7e02c531"

+ 6 - 0
3rdparty/spirv-tools/include/generated/core.insts-unified1.inc

@@ -10,6 +10,7 @@ static const spv::Capability pygen_variable_caps_AtomicFloat16MinMaxEXTAtomicFlo
 static const spv::Capability pygen_variable_caps_BFloat16ConversionINTEL[] = {spv::Capability::BFloat16ConversionINTEL};
 static const spv::Capability pygen_variable_caps_BindlessTextureNV[] = {spv::Capability::BindlessTextureNV};
 static const spv::Capability pygen_variable_caps_BlockingPipesINTEL[] = {spv::Capability::BlockingPipesINTEL};
+static const spv::Capability pygen_variable_caps_CooperativeMatrixKHR[] = {spv::Capability::CooperativeMatrixKHR};
 static const spv::Capability pygen_variable_caps_CooperativeMatrixNV[] = {spv::Capability::CooperativeMatrixNV};
 static const spv::Capability pygen_variable_caps_DemoteToHelperInvocation[] = {spv::Capability::DemoteToHelperInvocation};
 static const spv::Capability pygen_variable_caps_DemoteToHelperInvocationEXT[] = {spv::Capability::DemoteToHelperInvocationEXT};
@@ -487,6 +488,11 @@ static const spv_opcode_desc_t kOpcodeTableEntries[] = {
   {"UDotAccSatKHR", spv::Op::OpUDotAccSatKHR, 1, pygen_variable_caps_DotProductKHR, 6, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT}, 1, 1, 1, pygen_variable_exts_SPV_KHR_integer_dot_product, SPV_SPIRV_VERSION_WORD(1,6), 0xffffffffu},
   {"SUDotAccSat", spv::Op::OpSUDotAccSat, 1, pygen_variable_caps_DotProduct, 6, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT}, 1, 1, 0, nullptr, SPV_SPIRV_VERSION_WORD(1,6), 0xffffffffu},
   {"SUDotAccSatKHR", spv::Op::OpSUDotAccSatKHR, 1, pygen_variable_caps_DotProductKHR, 6, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT}, 1, 1, 1, pygen_variable_exts_SPV_KHR_integer_dot_product, SPV_SPIRV_VERSION_WORD(1,6), 0xffffffffu},
+  {"TypeCooperativeMatrixKHR", spv::Op::OpTypeCooperativeMatrixKHR, 1, pygen_variable_caps_CooperativeMatrixKHR, 6, {SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_SCOPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 1, 0, 0, nullptr, 0xffffffffu, 0xffffffffu},
+  {"CooperativeMatrixLoadKHR", spv::Op::OpCooperativeMatrixLoadKHR, 1, pygen_variable_caps_CooperativeMatrixKHR, 6, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_ID, SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS}, 1, 1, 0, nullptr, 0xffffffffu, 0xffffffffu},
+  {"CooperativeMatrixStoreKHR", spv::Op::OpCooperativeMatrixStoreKHR, 1, pygen_variable_caps_CooperativeMatrixKHR, 5, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_ID, SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS}, 0, 0, 0, nullptr, 0xffffffffu, 0xffffffffu},
+  {"CooperativeMatrixMulAddKHR", spv::Op::OpCooperativeMatrixMulAddKHR, 1, pygen_variable_caps_CooperativeMatrixKHR, 6, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS}, 1, 1, 0, nullptr, 0xffffffffu, 0xffffffffu},
+  {"CooperativeMatrixLengthKHR", spv::Op::OpCooperativeMatrixLengthKHR, 1, pygen_variable_caps_CooperativeMatrixKHR, 3, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 0, nullptr, 0xffffffffu, 0xffffffffu},
   {"TypeRayQueryKHR", spv::Op::OpTypeRayQueryKHR, 1, pygen_variable_caps_RayQueryKHR, 1, {SPV_OPERAND_TYPE_RESULT_ID}, 1, 0, 1, pygen_variable_exts_SPV_KHR_ray_query, 0xffffffffu, 0xffffffffu},
   {"RayQueryInitializeKHR", spv::Op::OpRayQueryInitializeKHR, 1, pygen_variable_caps_RayQueryKHR, 8, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 0, 0, 1, pygen_variable_exts_SPV_KHR_ray_query, 0xffffffffu, 0xffffffffu},
   {"RayQueryTerminateKHR", spv::Op::OpRayQueryTerminateKHR, 1, pygen_variable_caps_RayQueryKHR, 1, {SPV_OPERAND_TYPE_ID}, 0, 0, 1, pygen_variable_exts_SPV_KHR_ray_query, 0xffffffffu, 0xffffffffu},

File diff suppressed because it is too large
+ 2 - 0
3rdparty/spirv-tools/include/generated/enum_string_mapping.inc


+ 1 - 0
3rdparty/spirv-tools/include/generated/extension_enum.inc

@@ -67,6 +67,7 @@ kSPV_INTEL_vector_compute,
 kSPV_KHR_16bit_storage,
 kSPV_KHR_8bit_storage,
 kSPV_KHR_bit_instructions,
+kSPV_KHR_cooperative_matrix,
 kSPV_KHR_device_group,
 kSPV_KHR_expect_assume,
 kSPV_KHR_float_controls,

+ 4 - 2
3rdparty/spirv-tools/include/generated/generators.inc

@@ -30,8 +30,10 @@
 {29, "Mikkosoft Productions", "MSP Shader Compiler", "Mikkosoft Productions MSP Shader Compiler"},
 {30, "SpvGenTwo community", "SpvGenTwo SPIR-V IR Tools", "SpvGenTwo community SpvGenTwo SPIR-V IR Tools"},
 {31, "Google", "Skia SkSL", "Google Skia SkSL"},
-{32, "TornadoVM", "SPIRV Beehive Toolkit", "TornadoVM SPIRV Beehive Toolkit"},
+{32, "TornadoVM", "Beehive SPIRV Toolkit", "TornadoVM Beehive SPIRV Toolkit"},
 {33, "DragonJoker", "ShaderWriter", "DragonJoker ShaderWriter"},
 {34, "Rayan Hatout", "SPIRVSmith", "Rayan Hatout SPIRVSmith"},
 {35, "Saarland University", "Shady", "Saarland University Shady"},
-{36, "Taichi Graphics", "Taichi", "Taichi Graphics Taichi"},
+{36, "Taichi Graphics", "Taichi", "Taichi Graphics Taichi"},
+{37, "heroseh", "Hero C Compiler", "heroseh Hero C Compiler"},
+{38, "Meta", "SparkSL", "Meta SparkSL"},

+ 33 - 3
3rdparty/spirv-tools/include/generated/operand.kinds-unified1.inc

@@ -6,6 +6,7 @@ static const spv::Capability pygen_variable_caps_BindlessTextureNV[] = {spv::Cap
 static const spv::Capability pygen_variable_caps_ClipDistance[] = {spv::Capability::ClipDistance};
 static const spv::Capability pygen_variable_caps_ComputeDerivativeGroupLinearNV[] = {spv::Capability::ComputeDerivativeGroupLinearNV};
 static const spv::Capability pygen_variable_caps_ComputeDerivativeGroupQuadsNV[] = {spv::Capability::ComputeDerivativeGroupQuadsNV};
+static const spv::Capability pygen_variable_caps_CooperativeMatrixKHR[] = {spv::Capability::CooperativeMatrixKHR};
 static const spv::Capability pygen_variable_caps_CoreBuiltinsARM[] = {spv::Capability::CoreBuiltinsARM};
 static const spv::Capability pygen_variable_caps_CullDistance[] = {spv::Capability::CullDistance};
 static const spv::Capability pygen_variable_caps_DenormFlushToZero[] = {spv::Capability::DenormFlushToZero};
@@ -198,6 +199,7 @@ static const spvtools::Extension pygen_variable_exts_SPV_INTEL_vector_compute[]
 static const spvtools::Extension pygen_variable_exts_SPV_KHR_16bit_storage[] = {spvtools::Extension::kSPV_KHR_16bit_storage};
 static const spvtools::Extension pygen_variable_exts_SPV_KHR_8bit_storage[] = {spvtools::Extension::kSPV_KHR_8bit_storage};
 static const spvtools::Extension pygen_variable_exts_SPV_KHR_bit_instructions[] = {spvtools::Extension::kSPV_KHR_bit_instructions};
+static const spvtools::Extension pygen_variable_exts_SPV_KHR_cooperative_matrix[] = {spvtools::Extension::kSPV_KHR_cooperative_matrix};
 static const spvtools::Extension pygen_variable_exts_SPV_KHR_device_group[] = {spvtools::Extension::kSPV_KHR_device_group};
 static const spvtools::Extension pygen_variable_exts_SPV_KHR_expect_assume[] = {spvtools::Extension::kSPV_KHR_expect_assume};
 static const spvtools::Extension pygen_variable_exts_SPV_KHR_float_controls[] = {spvtools::Extension::kSPV_KHR_float_controls};
@@ -390,7 +392,8 @@ static const spv_operand_desc_t pygen_variable_SourceLanguageEntries[] = {
   {"OpenCL_CPP", 4, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
   {"HLSL", 5, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
   {"CPP_for_OpenCL", 6, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
-  {"SYCL", 7, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu}
+  {"SYCL", 7, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
+  {"HERO_C", 8, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu}
 };
 
 static const spv_operand_desc_t pygen_variable_ExecutionModelEntries[] = {
@@ -666,7 +669,9 @@ static const spv_operand_desc_t pygen_variable_ImageChannelDataTypeEntries[] = {
   {"HalfFloat", 13, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
   {"Float", 14, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
   {"UnormInt24", 15, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
-  {"UnormInt101010_2", 16, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu}
+  {"UnormInt101010_2", 16, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
+  {"UnsignedIntRaw10EXT", 19, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
+  {"UnsignedIntRaw12EXT", 20, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu}
 };
 
 static const spv_operand_desc_t pygen_variable_FPRoundingModeEntries[] = {
@@ -1253,6 +1258,7 @@ static const spv_operand_desc_t pygen_variable_CapabilityEntries[] = {
   {"DotProduct", 6019, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_integer_dot_product, {}, SPV_SPIRV_VERSION_WORD(1,6), 0xffffffffu},
   {"DotProductKHR", 6019, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_integer_dot_product, {}, SPV_SPIRV_VERSION_WORD(1,6), 0xffffffffu},
   {"RayCullMaskKHR", 6020, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_ray_cull_mask, {}, 0xffffffffu, 0xffffffffu},
+  {"CooperativeMatrixKHR", 6022, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_cooperative_matrix, {}, 0xffffffffu, 0xffffffffu},
   {"BitInstructions", 6025, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_bit_instructions, {}, 0xffffffffu, 0xffffffffu},
   {"GroupNonUniformRotateKHR", 6026, 1, pygen_variable_caps_GroupNonUniform, 1, pygen_variable_exts_SPV_KHR_subgroup_rotate, {}, 0xffffffffu, 0xffffffffu},
   {"AtomicFloat32AddEXT", 6033, 0, nullptr, 1, pygen_variable_exts_SPV_EXT_shader_atomic_float_add, {}, 0xffffffffu, 0xffffffffu},
@@ -1290,6 +1296,26 @@ static const spv_operand_desc_t pygen_variable_PackedVectorFormatEntries[] = {
   {"PackedVectorFormat4x8BitKHR", 0, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_integer_dot_product, {}, SPV_SPIRV_VERSION_WORD(1,6), 0xffffffffu}
 };
 
+static const spv_operand_desc_t pygen_variable_CooperativeMatrixOperandsEntries[] = {
+  {"None", 0x0000, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"MatrixASignedComponents", 0x0001, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"MatrixBSignedComponents", 0x0002, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"MatrixCSignedComponents", 0x0004, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"MatrixResultSignedComponents", 0x0008, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"SaturatingAccumulation", 0x0010, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu}
+};
+
+static const spv_operand_desc_t pygen_variable_CooperativeMatrixLayoutEntries[] = {
+  {"RowMajorKHR", 0, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"ColumnMajorKHR", 1, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu}
+};
+
+static const spv_operand_desc_t pygen_variable_CooperativeMatrixUseEntries[] = {
+  {"MatrixAKHR", 0, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"MatrixBKHR", 1, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu},
+  {"MatrixAccumulatorKHR", 2, 1, pygen_variable_caps_CooperativeMatrixKHR, 0, nullptr, {}, 0xffffffffu, 0xffffffffu}
+};
+
 static const spv_operand_desc_t pygen_variable_DebugInfoFlagsEntries[] = {
   {"None", 0x0000, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
   {"FlagIsProtected", 0x01, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
@@ -1449,6 +1475,9 @@ static const spv_operand_desc_group_t pygen_variable_OperandInfoTable[] = {
   {SPV_OPERAND_TYPE_RAY_QUERY_COMMITTED_INTERSECTION_TYPE, ARRAY_SIZE(pygen_variable_RayQueryCommittedIntersectionTypeEntries), pygen_variable_RayQueryCommittedIntersectionTypeEntries},
   {SPV_OPERAND_TYPE_RAY_QUERY_CANDIDATE_INTERSECTION_TYPE, ARRAY_SIZE(pygen_variable_RayQueryCandidateIntersectionTypeEntries), pygen_variable_RayQueryCandidateIntersectionTypeEntries},
   {SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT, ARRAY_SIZE(pygen_variable_PackedVectorFormatEntries), pygen_variable_PackedVectorFormatEntries},
+  {SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS, ARRAY_SIZE(pygen_variable_CooperativeMatrixOperandsEntries), pygen_variable_CooperativeMatrixOperandsEntries},
+  {SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT, ARRAY_SIZE(pygen_variable_CooperativeMatrixLayoutEntries), pygen_variable_CooperativeMatrixLayoutEntries},
+  {SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE, ARRAY_SIZE(pygen_variable_CooperativeMatrixUseEntries), pygen_variable_CooperativeMatrixUseEntries},
   {SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS, ARRAY_SIZE(pygen_variable_DebugInfoFlagsEntries), pygen_variable_DebugInfoFlagsEntries},
   {SPV_OPERAND_TYPE_DEBUG_BASE_TYPE_ATTRIBUTE_ENCODING, ARRAY_SIZE(pygen_variable_DebugBaseTypeAttributeEncodingEntries), pygen_variable_DebugBaseTypeAttributeEncodingEntries},
   {SPV_OPERAND_TYPE_DEBUG_COMPOSITE_TYPE, ARRAY_SIZE(pygen_variable_DebugCompositeTypeEntries), pygen_variable_DebugCompositeTypeEntries},
@@ -1463,5 +1492,6 @@ static const spv_operand_desc_group_t pygen_variable_OperandInfoTable[] = {
   {SPV_OPERAND_TYPE_OPTIONAL_IMAGE, ARRAY_SIZE(pygen_variable_ImageOperandsEntries), pygen_variable_ImageOperandsEntries},
   {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, ARRAY_SIZE(pygen_variable_MemoryAccessEntries), pygen_variable_MemoryAccessEntries},
   {SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER, ARRAY_SIZE(pygen_variable_AccessQualifierEntries), pygen_variable_AccessQualifierEntries},
-  {SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT, ARRAY_SIZE(pygen_variable_PackedVectorFormatEntries), pygen_variable_PackedVectorFormatEntries}
+  {SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT, ARRAY_SIZE(pygen_variable_PackedVectorFormatEntries), pygen_variable_PackedVectorFormatEntries},
+  {SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS, ARRAY_SIZE(pygen_variable_CooperativeMatrixOperandsEntries), pygen_variable_CooperativeMatrixOperandsEntries}
 };

+ 5 - 12
3rdparty/spirv-tools/include/spirv-tools/instrument.hpp

@@ -182,18 +182,11 @@ static const int kInstMaxOutCnt = kInstStageOutCnt + 6;
 // Validation Error Codes
 //
 // These are the possible validation error codes.
-static const int kInstErrorBindlessBounds = 0;
-static const int kInstErrorBindlessUninit = 1;
-static const int kInstErrorBuffAddrUnallocRef = 2;
-// Deleted: static const int kInstErrorBindlessBuffOOB = 3;
-// This comment will will remain for 2 releases to allow
-// for the transition of all builds. Buffer OOB is
-// generating the following four differentiated codes instead:
-static const int kInstErrorBuffOOBUniform = 4;
-static const int kInstErrorBuffOOBStorage = 5;
-static const int kInstErrorBuffOOBUniformTexel = 6;
-static const int kInstErrorBuffOOBStorageTexel = 7;
-static const int kInstErrorMax = kInstErrorBuffOOBStorageTexel;
+static const int kInstErrorBindlessBounds = 1;
+static const int kInstErrorBindlessUninit = 2;
+static const int kInstErrorBuffAddrUnallocRef = 3;
+static const int kInstErrorOOB = 4;
+static const int kInstErrorMax = kInstErrorOOB;
 
 // Direct Input Buffer Offsets
 //

+ 7 - 0
3rdparty/spirv-tools/include/spirv-tools/libspirv.h

@@ -286,6 +286,13 @@ typedef enum spv_operand_type_t {
   // An optional packed vector format
   SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
 
+  // Concrete operand types for cooperative matrix.
+  SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS,
+  // An optional cooperative matrix operands
+  SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS,
+  SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT,
+  SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE,
+
   // This is a sentinel value, and does not represent an operand type.
   // It should come last.
   SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,

+ 20 - 10
3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp

@@ -97,12 +97,24 @@ class Optimizer {
   // Registers passes that attempt to improve performance of generated code.
   // This sequence of passes is subject to constant review and will change
   // from time to time.
+  //
+  // If |preserve_interface| is true, all non-io variables in the entry point
+  // interface are considered live and are not eliminated.
+  // |preserve_interface| should be true if HLSL is generated
+  // from the SPIR-V bytecode.
   Optimizer& RegisterPerformancePasses();
+  Optimizer& RegisterPerformancePasses(bool preserve_interface);
 
   // Registers passes that attempt to improve the size of generated code.
   // This sequence of passes is subject to constant review and will change
   // from time to time.
+  //
+  // If |preserve_interface| is true, all non-io variables in the entry point
+  // interface are considered live and are not eliminated.
+  // |preserve_interface| should be true if HLSL is generated
+  // from the SPIR-V bytecode.
   Optimizer& RegisterSizePasses();
+  Optimizer& RegisterSizePasses(bool preserve_interface);
 
   // Registers passes that attempt to legalize the generated code.
   //
@@ -112,7 +124,13 @@ class Optimizer {
   //
   // This sequence of passes is subject to constant review and will change
   // from time to time.
+  //
+  // If |preserve_interface| is true, all non-io variables in the entry point
+  // interface are considered live and are not eliminated.
+  // |preserve_interface| should be true if HLSL is generated
+  // from the SPIR-V bytecode.
   Optimizer& RegisterLegalizationPasses();
+  Optimizer& RegisterLegalizationPasses(bool preserve_interface);
 
   // Register passes specified in the list of |flags|.  Each flag must be a
   // string of a form accepted by Optimizer::FlagHasValidForm().
@@ -751,16 +769,8 @@ Optimizer::PassToken CreateCombineAccessChainsPass();
 // The instrumentation will read and write buffers in debug
 // descriptor set |desc_set|. It will write |shader_id| in each output record
 // to identify the shader module which generated the record.
-// |desc_length_enable| controls instrumentation of runtime descriptor array
-// references, |desc_init_enable| controls instrumentation of descriptor
-// initialization checking, and |buff_oob_enable| controls instrumentation
-// of storage and uniform buffer bounds checking, all of which require input
-// buffer support. |texbuff_oob_enable| controls instrumentation of texel
-// buffers, which does not require input buffer support.
-Optimizer::PassToken CreateInstBindlessCheckPass(
-    uint32_t desc_set, uint32_t shader_id, bool desc_length_enable = false,
-    bool desc_init_enable = false, bool buff_oob_enable = false,
-    bool texbuff_oob_enable = false);
+Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set,
+                                                 uint32_t shader_id);
 
 // Create a pass to instrument physical buffer address checking
 // This pass instruments all physical buffer address references to check that

+ 3 - 2
3rdparty/spirv-tools/source/assembly_grammar.cpp

@@ -154,11 +154,12 @@ const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
     CASE(InBoundsAccessChain),
     CASE(PtrAccessChain),
     CASE(InBoundsPtrAccessChain),
-    CASE(CooperativeMatrixLengthNV)
+    CASE(CooperativeMatrixLengthNV),
+    CASE(CooperativeMatrixLengthKHR)
 };
 
 // The 60 is determined by counting the opcodes listed in the spec.
-static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
+static_assert(61 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
               "OpSpecConstantOp opcode table is incomplete");
 #undef CASE
 // clang-format on

+ 5 - 1
3rdparty/spirv-tools/source/binary.cpp

@@ -691,7 +691,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
     case SPV_OPERAND_TYPE_SELECTION_CONTROL:
     case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
-    case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: {
+    case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
       // This operand is a mask.
 
       // Map an optional operand type to its corresponding concrete type.
@@ -699,6 +701,8 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
         parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
       else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
         parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
+      if (type == SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS)
+        parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
 
       // Check validity of set mask bits. Also prepare for operands for those
       // masks if they have any.  To get operand order correct, scan from

+ 134 - 43
3rdparty/spirv-tools/source/diff/diff.cpp

@@ -101,9 +101,12 @@ class IdMap {
     return from < id_map_.size() && id_map_[from] != 0;
   }
 
-  // Map any ids in src and dst that have not been mapped to new ids in dst and
-  // src respectively.
-  void MapUnmatchedIds(IdMap& other_way);
+  bool IsMapped(const opt::Instruction* from_inst) const {
+    assert(from_inst != nullptr);
+    assert(!from_inst->HasResultId());
+
+    return inst_map_.find(from_inst) != inst_map_.end();
+  }
 
   // Some instructions don't have result ids.  Those are mapped by pointer.
   void MapInsts(const opt::Instruction* from_inst,
@@ -117,6 +120,12 @@ class IdMap {
 
   uint32_t IdBound() const { return static_cast<uint32_t>(id_map_.size()); }
 
+  // Generate a fresh id in this mapping's domain.
+  uint32_t MakeFreshId() {
+    id_map_.push_back(0);
+    return static_cast<uint32_t>(id_map_.size()) - 1;
+  }
+
  private:
   // Given an id, returns the corresponding id in the other module, or 0 if not
   // matched yet.
@@ -150,10 +159,16 @@ class SrcDstIdMap {
 
   bool IsSrcMapped(uint32_t src) { return src_to_dst_.IsMapped(src); }
   bool IsDstMapped(uint32_t dst) { return dst_to_src_.IsMapped(dst); }
+  bool IsDstMapped(const opt::Instruction* dst_inst) {
+    return dst_to_src_.IsMapped(dst_inst);
+  }
 
   // Map any ids in src and dst that have not been mapped to new ids in dst and
-  // src respectively.
-  void MapUnmatchedIds();
+  // src respectively. Use src_insn_defined and dst_insn_defined to ignore ids
+  // that are simply never defined. (Since we assume the inputs are valid
+  // SPIR-V, this implies they are also never used.)
+  void MapUnmatchedIds(std::function<bool(uint32_t)> src_insn_defined,
+                       std::function<bool(uint32_t)> dst_insn_defined);
 
   // Some instructions don't have result ids.  Those are mapped by pointer.
   void MapInsts(const opt::Instruction* src_inst,
@@ -203,6 +218,11 @@ struct IdInstructions {
 
   void MapIdToInstruction(uint32_t id, const opt::Instruction* inst);
 
+  // Return true if id is mapped to any instruction, false otherwise.
+  bool IsDefined(uint32_t id) {
+    return id < inst_map_.size() && inst_map_[id] != nullptr;
+  }
+
   void MapIdsToInstruction(
       opt::IteratorRange<opt::Module::const_inst_iterator> section);
   void MapIdsToInfos(
@@ -338,6 +358,59 @@ class Differ {
       std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
           match_group);
 
+  // Bucket `src_ids` and `dst_ids` by the key ids returned by `get_group`, and
+  // then call `match_group` on pairs of buckets whose key ids are matched with
+  // each other.
+  //
+  // For example, suppose we want to pair up groups of instructions with the
+  // same type. Naturally, the source instructions refer to their types by their
+  // ids in the source, and the destination instructions use destination type
+  // ids, so simply comparing source and destination type ids as integers, as
+  // `GroupIdsAndMatch` would do, is meaningless. But if a prior call to
+  // `MatchTypeIds` has established type matches between the two modules, then
+  // we can consult those to pair source and destination buckets whose types are
+  // equivalent.
+  //
+  // Suppose our input groups are as follows:
+  //
+  // - src_ids: { 1 -> 100, 2 -> 300, 3 -> 100, 4 -> 200 }
+  // - dst_ids: { 5 -> 10, 6 -> 20, 7 -> 10, 8 -> 300 }
+  //
+  // Here, `X -> Y` means that the instruction with SPIR-V id `X` is a member of
+  // the group, and `Y` is the id of its type. If we use
+  // `Differ::GroupIdsHelperGetTypeId` for `get_group`, then
+  // `get_group(X) == Y`.
+  //
+  // These instructions are bucketed by type as follows:
+  //
+  // - source:      [1, 3] -> 100
+  //                [4] -> 200
+  //                [2] -> 300
+  //
+  // - destination: [5, 7] -> 10
+  //                [6] -> 20
+  //                [8] -> 300
+  //
+  // Now suppose that we have previously matched up src type 100 with dst type
+  // 10, and src type 200 with dst type 20, but no other types are matched.
+  //
+  // Then `match_group` is called twice:
+  // - Once with ([1,3], [5, 7]), corresponding to 100/10
+  // - Once with ([4],[6]), corresponding to 200/20
+  //
+  // The source type 300 isn't matched with anything, so the fact that there's a
+  // destination type 300 is irrelevant, and thus 2 and 8 are never passed to
+  // `match_group`.
+  //
+  // This function isn't specific to types; it simply buckets by the ids
+  // returned from `get_group`, and consults existing matches to pair up the
+  // resulting buckets.
+  void GroupIdsAndMatchByMappedId(
+      const IdGroup& src_ids, const IdGroup& dst_ids,
+      uint32_t (Differ::*get_group)(const IdInstructions&, uint32_t),
+      std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+          match_group);
+
   // Helper functions that determine if two instructions match
   bool DoIdsMatch(uint32_t src_id, uint32_t dst_id);
   bool DoesOperandMatch(const opt::Operand& src_operand,
@@ -504,36 +577,27 @@ class Differ {
   FunctionMap dst_funcs_;
 };
 
-void IdMap::MapUnmatchedIds(IdMap& other_way) {
-  const uint32_t src_id_bound = static_cast<uint32_t>(id_map_.size());
-  const uint32_t dst_id_bound = static_cast<uint32_t>(other_way.id_map_.size());
-
-  uint32_t next_src_id = src_id_bound;
-  uint32_t next_dst_id = dst_id_bound;
+void SrcDstIdMap::MapUnmatchedIds(
+    std::function<bool(uint32_t)> src_insn_defined,
+    std::function<bool(uint32_t)> dst_insn_defined) {
+  const uint32_t src_id_bound = static_cast<uint32_t>(src_to_dst_.IdBound());
+  const uint32_t dst_id_bound = static_cast<uint32_t>(dst_to_src_.IdBound());
 
   for (uint32_t src_id = 1; src_id < src_id_bound; ++src_id) {
-    if (!IsMapped(src_id)) {
-      MapIds(src_id, next_dst_id);
-
-      other_way.id_map_.push_back(0);
-      other_way.MapIds(next_dst_id++, src_id);
+    if (!src_to_dst_.IsMapped(src_id) && src_insn_defined(src_id)) {
+      uint32_t fresh_dst_id = dst_to_src_.MakeFreshId();
+      MapIds(src_id, fresh_dst_id);
     }
   }
 
   for (uint32_t dst_id = 1; dst_id < dst_id_bound; ++dst_id) {
-    if (!other_way.IsMapped(dst_id)) {
-      id_map_.push_back(0);
-      MapIds(next_src_id, dst_id);
-
-      other_way.MapIds(dst_id, next_src_id++);
+    if (!dst_to_src_.IsMapped(dst_id) && dst_insn_defined(dst_id)) {
+      uint32_t fresh_src_id = src_to_dst_.MakeFreshId();
+      MapIds(fresh_src_id, dst_id);
     }
   }
 }
 
-void SrcDstIdMap::MapUnmatchedIds() {
-  src_to_dst_.MapUnmatchedIds(dst_to_src_);
-}
-
 void IdInstructions::MapIdToInstruction(uint32_t id,
                                         const opt::Instruction* inst) {
   assert(id != 0);
@@ -889,6 +953,37 @@ void Differ::GroupIdsAndMatch(
   }
 }
 
+void Differ::GroupIdsAndMatchByMappedId(
+    const IdGroup& src_ids, const IdGroup& dst_ids,
+    uint32_t (Differ::*get_group)(const IdInstructions&, uint32_t),
+    std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+        match_group) {
+  // Group the ids based on a key (get_group)
+  std::map<uint32_t, IdGroup> src_groups;
+  std::map<uint32_t, IdGroup> dst_groups;
+
+  GroupIds<uint32_t>(src_ids, true, &src_groups, get_group);
+  GroupIds<uint32_t>(dst_ids, false, &dst_groups, get_group);
+
+  // Iterate over pairs of groups whose keys map to each other.
+  for (const auto& iter : src_groups) {
+    const uint32_t& src_key = iter.first;
+    const IdGroup& src_group = iter.second;
+
+    if (src_key == 0) {
+      continue;
+    }
+
+    if (id_map_.IsSrcMapped(src_key)) {
+      const uint32_t& dst_key = id_map_.MappedDstId(src_key);
+      const IdGroup& dst_group = dst_groups[dst_key];
+
+      // Let the caller match the groups as appropriate.
+      match_group(src_group, dst_group);
+    }
+  }
+}
+
 bool Differ::DoIdsMatch(uint32_t src_id, uint32_t dst_id) {
   assert(dst_id != 0);
   return id_map_.MappedDstId(src_id) == dst_id;
@@ -1419,7 +1514,6 @@ void Differ::MatchTypeForwardPointersByName(const IdGroup& src,
   GroupIdsAndMatch<std::string>(
       src, dst, "", &Differ::GetSanitizedName,
       [this](const IdGroup& src_group, const IdGroup& dst_group) {
-
         // Match only if there's a unique forward declaration with this debug
         // name.
         if (src_group.size() == 1 && dst_group.size() == 1) {
@@ -1574,6 +1668,8 @@ void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
 
     id_map_.MapIds(match_result.src_id, match_result.dst_id);
 
+    MatchFunctionParamIds(src_funcs_[match_result.src_id],
+                          dst_funcs_[match_result.dst_id]);
     MatchIdsInFunctionBodies(src_func_insts.at(match_result.src_id),
                              dst_func_insts.at(match_result.dst_id),
                              match_result.src_match, match_result.dst_match, 0);
@@ -1598,7 +1694,6 @@ void Differ::MatchFunctionParamIds(const opt::Function* src_func,
   GroupIdsAndMatch<std::string>(
       src_params, dst_params, "", &Differ::GetSanitizedName,
       [this](const IdGroup& src_group, const IdGroup& dst_group) {
-
         // There shouldn't be two parameters with the same name, so the ids
         // should match. There is nothing restricting the SPIR-V however to have
         // two parameters with the same name, so be resilient against that.
@@ -1609,17 +1704,17 @@ void Differ::MatchFunctionParamIds(const opt::Function* src_func,
 
   // Then match the parameters by their type.  If there are multiple of them,
   // match them by their order.
-  GroupIdsAndMatch<uint32_t>(
-      src_params, dst_params, 0, &Differ::GroupIdsHelperGetTypeId,
+  GroupIdsAndMatchByMappedId(
+      src_params, dst_params, &Differ::GroupIdsHelperGetTypeId,
       [this](const IdGroup& src_group_by_type_id,
              const IdGroup& dst_group_by_type_id) {
-
         const size_t shared_param_count =
             std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
 
         for (size_t param_index = 0; param_index < shared_param_count;
              ++param_index) {
-          id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
+          id_map_.MapIds(src_group_by_type_id[param_index],
+                         dst_group_by_type_id[param_index]);
         }
       });
 }
@@ -2064,9 +2159,10 @@ void Differ::MatchEntryPointIds() {
     }
 
     // Otherwise match them by name.
-    bool matched = false;
     for (const opt::Instruction* src_inst : src_insts) {
       for (const opt::Instruction* dst_inst : dst_insts) {
+        if (id_map_.IsDstMapped(dst_inst)) continue;
+
         const opt::Operand& src_name = src_inst->GetOperand(2);
         const opt::Operand& dst_name = dst_inst->GetOperand(2);
 
@@ -2075,13 +2171,9 @@ void Differ::MatchEntryPointIds() {
           uint32_t dst_id = dst_inst->GetSingleWordOperand(1);
           id_map_.MapIds(src_id, dst_id);
           id_map_.MapInsts(src_inst, dst_inst);
-          matched = true;
           break;
         }
       }
-      if (matched) {
-        break;
-      }
     }
   }
 }
@@ -2126,7 +2218,6 @@ void Differ::MatchTypeForwardPointers() {
       spv::StorageClass::Max, &Differ::GroupIdsHelperGetTypePointerStorageClass,
       [this](const IdGroup& src_group_by_storage_class,
              const IdGroup& dst_group_by_storage_class) {
-
         // Group them further by the type they are pointing to and loop over
         // them.
         GroupIdsAndMatch<spv::Op>(
@@ -2134,7 +2225,6 @@ void Differ::MatchTypeForwardPointers() {
             spv::Op::Max, &Differ::GroupIdsHelperGetTypePointerTypeOp,
             [this](const IdGroup& src_group_by_type_op,
                    const IdGroup& dst_group_by_type_op) {
-
               // Group them even further by debug info, if possible and match by
               // debug name.
               MatchTypeForwardPointersByName(src_group_by_type_op,
@@ -2199,7 +2289,9 @@ void Differ::MatchTypeIds() {
         case spv::Op::OpTypeVoid:
         case spv::Op::OpTypeBool:
         case spv::Op::OpTypeSampler:
-          // void, bool and sampler are unique, match them.
+        case spv::Op::OpTypeAccelerationStructureNV:
+        case spv::Op::OpTypeRayQueryKHR:
+          // the above types have no operands and are unique, match them.
           return true;
         case spv::Op::OpTypeInt:
         case spv::Op::OpTypeFloat:
@@ -2378,7 +2470,6 @@ void Differ::MatchFunctions() {
   GroupIdsAndMatch<std::string>(
       src_func_ids, dst_func_ids, "", &Differ::GetSanitizedName,
       [this](const IdGroup& src_group, const IdGroup& dst_group) {
-
         // If there is a single function with this name in src and dst, it's a
         // definite match.
         if (src_group.size() == 1 && dst_group.size() == 1) {
@@ -2392,7 +2483,6 @@ void Differ::MatchFunctions() {
                                    &Differ::GroupIdsHelperGetTypeId,
                                    [this](const IdGroup& src_group_by_type_id,
                                           const IdGroup& dst_group_by_type_id) {
-
                                      if (src_group_by_type_id.size() == 1 &&
                                          dst_group_by_type_id.size() == 1) {
                                        id_map_.MapIds(src_group_by_type_id[0],
@@ -2437,7 +2527,6 @@ void Differ::MatchFunctions() {
       src_func_ids, dst_func_ids, 0, &Differ::GroupIdsHelperGetTypeId,
       [this](const IdGroup& src_group_by_type_id,
              const IdGroup& dst_group_by_type_id) {
-
         BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
                                  src_func_insts_, dst_func_insts_);
       });
@@ -2647,7 +2736,9 @@ opt::Instruction Differ::ToMappedSrcIds(const opt::Instruction& dst_inst) {
 }
 
 spv_result_t Differ::Output() {
-  id_map_.MapUnmatchedIds();
+  id_map_.MapUnmatchedIds(
+      [this](uint32_t src_id) { return src_id_to_.IsDefined(src_id); },
+      [this](uint32_t dst_id) { return dst_id_to_.IsDefined(dst_id); });
   src_id_to_.inst_map_.resize(id_map_.SrcToDstMap().IdBound(), nullptr);
   dst_id_to_.inst_map_.resize(id_map_.DstToSrcMap().IdBound(), nullptr);
 

+ 2 - 0
3rdparty/spirv-tools/source/opcode.cpp

@@ -274,6 +274,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) {
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeStruct:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return true;
     default:
       return false;
@@ -340,6 +341,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
     case spv::Op::OpTypeNamedBarrier:
     case spv::Op::OpTypeAccelerationStructureNV:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
     // case spv::Op::OpTypeAccelerationStructureKHR: covered by
     // spv::Op::OpTypeAccelerationStructureNV
     case spv::Op::OpTypeRayQueryKHR:

+ 11 - 0
3rdparty/spirv-tools/source/operand.cpp

@@ -236,6 +236,13 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
     case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
       return "packed vector format";
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
+      return "cooperative matrix operands";
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+      return "cooperative matrix layout";
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
+      return "cooperative matrix use";
     case SPV_OPERAND_TYPE_IMAGE:
     case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
       return "image";
@@ -369,6 +376,8 @@ bool spvOperandIsConcrete(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
     case SPV_OPERAND_TYPE_OVERFLOW_MODES:
     case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
       return true;
     default:
       break;
@@ -387,6 +396,7 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
     case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
     case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
       return true;
     default:
       break;
@@ -405,6 +415,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
     case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
     case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
     case SPV_OPERAND_TYPE_OPTIONAL_CIV:
       return true;
     default:

+ 1 - 0
3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp

@@ -994,6 +994,7 @@ void AggressiveDCEPass::InitExtensions() {
       "SPV_KHR_non_semantic_info",
       "SPV_KHR_uniform_group_instructions",
       "SPV_KHR_fragment_shader_barycentric",
+      "SPV_NV_bindless_texture",
   });
 }
 

+ 118 - 31
3rdparty/spirv-tools/source/opt/const_folding_rules.cpp

@@ -88,6 +88,22 @@ const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
   return nullptr;
 }
 
+// Returns a constant with the value |-val| of the given type.
+const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
+                                         const analysis::Constant* val,
+                                         analysis::ConstantManager* const_mgr) {
+  const analysis::Integer* int_type = result_type->AsInteger();
+  assert(int_type != nullptr);
+
+  if (val->AsNullConstant()) {
+    return val;
+  }
+
+  uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue());
+  return const_mgr->GetIntConst(new_value, int_type->width(),
+                                int_type->IsSigned());
+}
+
 // Folds an OpcompositeExtract where input is a composite constant.
 ConstantFoldingRule FoldExtractWithConstants() {
   return [](IRContext* context, Instruction* inst,
@@ -341,6 +357,69 @@ ConstantFoldingRule FoldVectorTimesScalar() {
   };
 }
 
+// Returns the constant that results from transposing |matrix|. The result
+// will have type |result_type|, and |matrix| must exist in |context|. The
+// result constant will also exist in |context|.
+const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
+                                          analysis::Matrix* result_type,
+                                          IRContext* context) {
+  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+  if (matrix->AsNullConstant() != nullptr) {
+    return const_mgr->GetNullCompositeConstant(result_type);
+  }
+
+  const auto& columns = matrix->AsMatrixConstant()->GetComponents();
+  uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
+
+  // Collect the ids of the elements in their new positions.
+  std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
+  for (const analysis::Constant* column : columns) {
+    if (column->AsNullConstant()) {
+      column = const_mgr->GetNullCompositeConstant(column->type());
+    }
+    const auto& column_components = column->AsVectorConstant()->GetComponents();
+
+    for (uint32_t row = 0; row < number_of_rows; ++row) {
+      result_elements[row].push_back(
+          const_mgr->GetDefiningInstruction(column_components[row])
+              ->result_id());
+    }
+  }
+
+  // Create the constant for each row in the result, and collect the ids.
+  std::vector<uint32_t> result_columns(number_of_rows);
+  for (uint32_t col = 0; col < number_of_rows; ++col) {
+    auto* element = const_mgr->GetConstant(result_type->element_type(),
+                                           result_elements[col]);
+    result_columns[col] =
+        const_mgr->GetDefiningInstruction(element)->result_id();
+  }
+
+  // Create the matrix constant from the row ids, and return it.
+  return const_mgr->GetConstant(result_type, result_columns);
+}
+
+const analysis::Constant* FoldTranspose(
+    IRContext* context, Instruction* inst,
+    const std::vector<const analysis::Constant*>& constants) {
+  assert(inst->opcode() == spv::Op::OpTranspose);
+
+  analysis::TypeManager* type_mgr = context->get_type_mgr();
+  if (!inst->IsFloatingPointFoldingAllowed()) {
+    if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
+      return nullptr;
+    }
+  }
+
+  const analysis::Constant* matrix = constants[0];
+  if (matrix == nullptr) {
+    return nullptr;
+  }
+
+  auto* result_type = type_mgr->GetType(inst->type_id());
+  return TransposeMatrix(matrix, result_type->AsMatrix(), context);
+}
+
 ConstantFoldingRule FoldVectorTimesMatrix() {
   return [](IRContext* context, Instruction* inst,
             const std::vector<const analysis::Constant*>& constants)
@@ -376,13 +455,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
     assert(c1->type()->AsVector()->element_type() == element_type &&
            c2->type()->AsMatrix()->element_type() == vector_type);
 
-    // Get a float vector that is the result of vector-times-matrix.
-    std::vector<const analysis::Constant*> c1_components =
-        c1->GetVectorComponents(const_mgr);
-    std::vector<const analysis::Constant*> c2_components =
-        c2->AsMatrixConstant()->GetComponents();
     uint32_t resultVectorSize = result_type->AsVector()->element_count();
-
     std::vector<uint32_t> ids;
 
     if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
@@ -395,6 +468,12 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
       return const_mgr->GetConstant(vector_type, ids);
     }
 
+    // Get a float vector that is the result of vector-times-matrix.
+    std::vector<const analysis::Constant*> c1_components =
+        c1->GetVectorComponents(const_mgr);
+    std::vector<const analysis::Constant*> c2_components =
+        c2->AsMatrixConstant()->GetComponents();
+
     if (float_type->width() == 32) {
       for (uint32_t i = 0; i < resultVectorSize; ++i) {
         float result_scalar = 0.0f;
@@ -472,13 +551,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
     assert(c1->type()->AsMatrix()->element_type() == vector_type);
     assert(c2->type()->AsVector()->element_type() == element_type);
 
-    // Get a float vector that is the result of matrix-times-vector.
-    std::vector<const analysis::Constant*> c1_components =
-        c1->AsMatrixConstant()->GetComponents();
-    std::vector<const analysis::Constant*> c2_components =
-        c2->GetVectorComponents(const_mgr);
     uint32_t resultVectorSize = result_type->AsVector()->element_count();
-
     std::vector<uint32_t> ids;
 
     if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
@@ -491,6 +564,12 @@ ConstantFoldingRule FoldMatrixTimesVector() {
       return const_mgr->GetConstant(vector_type, ids);
     }
 
+    // Get a float vector that is the result of matrix-times-vector.
+    std::vector<const analysis::Constant*> c1_components =
+        c1->AsMatrixConstant()->GetComponents();
+    std::vector<const analysis::Constant*> c2_components =
+        c2->GetVectorComponents(const_mgr);
+
     if (float_type->width() == 32) {
       for (uint32_t i = 0; i < resultVectorSize; ++i) {
         float result_scalar = 0.0f;
@@ -587,25 +666,22 @@ using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
     const analysis::Type* result_type, const analysis::Constant* a,
     const analysis::Constant* b, analysis::ConstantManager*)>;
 
-// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
-// using |scalar_rule| and unary float point vectors ops by applying
+// Returns a |ConstantFoldingRule| that folds unary scalar ops
+// using |scalar_rule| and unary vector ops by applying
 // |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
 // that is returned assumes that |constants| contains 1 entry.  If they are
 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
 // whose element type is |Float| or |Integer|.
-ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
+ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
   return [scalar_rule](IRContext* context, Instruction* inst,
                        const std::vector<const analysis::Constant*>& constants)
              -> const analysis::Constant* {
+
     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
     analysis::TypeManager* type_mgr = context->get_type_mgr();
     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
     const analysis::Vector* vector_type = result_type->AsVector();
 
-    if (!inst->IsFloatingPointFoldingAllowed()) {
-      return nullptr;
-    }
-
     const analysis::Constant* arg =
         (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
 
@@ -640,6 +716,25 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
   };
 }
 
+// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
+// using |scalar_rule| and unary floating point vector ops by applying
+// |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
+// that is returned assumes that |constants| contains 1 entry.  If they are
+// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
+// whose element type is |Float| or |Integer|.
+ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
+  auto folding_rule = FoldUnaryOp(scalar_rule);
+  return [folding_rule](IRContext* context, Instruction* inst,
+                        const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return nullptr;
+    }
+
+    return folding_rule(context, inst, constants);
+  };
+}
+
 // Returns the result of folding the constants in |constants| according the
 // |scalar_rule|.  If |result_type| is a vector, then |scalar_rule| is applied
 // per component.
@@ -1042,18 +1137,8 @@ ConstantFoldingRule FoldOpDotWithConstants() {
   };
 }
 
-// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
-// from zero.
-UnaryScalarFoldingRule FoldFNegateOp() {
-  return [](const analysis::Type* result_type, const analysis::Constant* a,
-            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
-    assert(result_type != nullptr && a != nullptr);
-    assert(result_type == a->type());
-    return NegateFPConst(result_type, a, const_mgr);
-  };
-}
-
-ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
+ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); }
+ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); }
 
 ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
   return [cmp_opcode](IRContext* context, Instruction* inst,
@@ -1566,8 +1651,10 @@ void ConstantFoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
   rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
   rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
+  rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
 
   rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
+  rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
   rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
 
   // Add rules for GLSLstd450

+ 27 - 0
3rdparty/spirv-tools/source/opt/constants.cpp

@@ -435,6 +435,8 @@ const Constant* ConstantManager::GetNumericVectorConstantWithWords(
     words_per_element = float_type->width() / 32;
   else if (const auto* int_type = element_type->AsInteger())
     words_per_element = int_type->width() / 32;
+  else if (element_type->AsBool() != nullptr)
+    words_per_element = 1;
 
   if (words_per_element != 1 && words_per_element != 2) return nullptr;
 
@@ -487,6 +489,31 @@ uint32_t ConstantManager::GetSIntConstId(int32_t val) {
   return GetDefiningInstruction(c)->result_id();
 }
 
+const Constant* ConstantManager::GetIntConst(uint64_t val, int32_t bitWidth,
+                                             bool isSigned) {
+  Type* int_type = context()->get_type_mgr()->GetIntType(bitWidth, isSigned);
+
+  if (isSigned) {
+    // Sign extend the value.
+    int32_t num_of_bit_to_ignore = 64 - bitWidth;
+    val = static_cast<int64_t>(val << num_of_bit_to_ignore) >>
+          num_of_bit_to_ignore;
+  } else {
+    // Clear the upper bits that are not used.
+    uint64_t mask = ((1ull << bitWidth) - 1);
+    val &= mask;
+  }
+
+  if (bitWidth <= 32) {
+    return GetConstant(int_type, {static_cast<uint32_t>(val)});
+  }
+
+  // If the value is more than 32-bit, we need to split the operands into two
+  // 32-bit integers.
+  return GetConstant(
+      int_type, {static_cast<uint32_t>(val >> 32), static_cast<uint32_t>(val)});
+}
+
 uint32_t ConstantManager::GetUIntConstId(uint32_t val) {
   Type* uint_type = context()->get_type_mgr()->GetUIntType();
   const Constant* c = GetConstant(uint_type, {val});

+ 6 - 0
3rdparty/spirv-tools/source/opt/constants.h

@@ -659,6 +659,12 @@ class ConstantManager {
   // Returns the id of a 32-bit signed integer constant with value |val|.
   uint32_t GetSIntConstId(int32_t val);
 
+  // Returns an integer constant with `bitWidth` and value |val|. If `isSigned`
+  // is true, the constant will be a signed integer. Otherwise it will be
+  // unsigned. Only the `bitWidth` lower order bits of |val| will be used. The
+  // rest will be ignored.
+  const Constant* GetIntConst(uint64_t val, int32_t bitWidth, bool isSigned);
+
   // Returns the id of a 32-bit unsigned integer constant with value |val|.
   uint32_t GetUIntConstId(uint32_t val);
 

+ 58 - 15
3rdparty/spirv-tools/source/opt/fold.cpp

@@ -627,7 +627,8 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
     Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
   analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
 
-  if (!inst->IsFoldableByFoldScalar() && !HasConstFoldingRule(inst)) {
+  if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&
+      !GetConstantFoldingRules().HasFoldingRule(inst)) {
     return nullptr;
   }
   // Collect the values of the constant parameters.
@@ -661,29 +662,58 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
     }
   }
 
-  uint32_t result_val = 0;
   bool successful = false;
+
   // If all parameters are constant, fold the instruction to a constant.
-  if (!missing_constants && inst->IsFoldableByFoldScalar()) {
-    result_val = FoldScalars(inst->opcode(), constants);
-    successful = true;
-  }
+  if (inst->IsFoldableByFoldScalar()) {
+    uint32_t result_val = 0;
 
-  if (!successful && inst->IsFoldableByFoldScalar()) {
-    successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
-  }
+    if (!missing_constants) {
+      result_val = FoldScalars(inst->opcode(), constants);
+      successful = true;
+    }
+
+    if (!successful) {
+      successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
+    }
+
+    if (successful) {
+      const analysis::Constant* result_const =
+          const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
+      Instruction* folded_inst =
+          const_mgr->GetDefiningInstruction(result_const, inst->type_id());
+      return folded_inst;
+    }
+  } else if (inst->IsFoldableByFoldVector()) {
+    std::vector<uint32_t> result_val;
+
+    if (!missing_constants) {
+      if (Instruction* inst_type =
+              context_->get_def_use_mgr()->GetDef(inst->type_id())) {
+        result_val = FoldVectors(
+            inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);
+        successful = true;
+      }
+    }
 
-  if (successful) {
-    const analysis::Constant* result_const =
-        const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
-    Instruction* folded_inst =
-        const_mgr->GetDefiningInstruction(result_const, inst->type_id());
-    return folded_inst;
+    if (successful) {
+      const analysis::Constant* result_const =
+          const_mgr->GetNumericVectorConstantWithWords(
+              const_mgr->GetType(inst)->AsVector(), result_val);
+      Instruction* folded_inst =
+          const_mgr->GetDefiningInstruction(result_const, inst->type_id());
+      return folded_inst;
+    }
   }
+
   return nullptr;
 }
 
 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
+  return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);
+}
+
+bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const {
   // Support 32-bit integers.
   if (type_inst->opcode() == spv::Op::OpTypeInt) {
     return type_inst->GetSingleWordInOperand(0) == 32;
@@ -696,6 +726,19 @@ bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
   return false;
 }
 
+bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const {
+  // Support vectors with foldable components
+  if (type_inst->opcode() == spv::Op::OpTypeVector) {
+    uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);
+    Instruction* def_component_type =
+        context_->get_def_use_mgr()->GetDef(component_type_id);
+    return def_component_type != nullptr &&
+           IsFoldableScalarType(def_component_type);
+  }
+  // Nothing else yet.
+  return false;
+}
+
 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
   bool modified = false;
   Instruction* folded_inst(inst);

+ 8 - 0
3rdparty/spirv-tools/source/opt/fold.h

@@ -86,6 +86,14 @@ class InstructionFolder {
   // result type is |type_inst|.
   bool IsFoldableType(Instruction* type_inst) const;
 
+  // Returns true if |FoldInstructionToConstant| could fold an instruction whose
+  // result type is |type_inst|.
+  bool IsFoldableScalarType(Instruction* type_inst) const;
+
+  // Returns true if |FoldInstructionToConstant| could fold an instruction whose
+  // result type is |type_inst|.
+  bool IsFoldableVectorType(Instruction* type_inst) const;
+
   // Tries to fold |inst| to a single constant, when the input ids to |inst|
   // have been substituted using |id_map|.  Returns a pointer to the OpConstant*
   // instruction if successful.  If necessary, a new constant instruction is

+ 6 - 2
3rdparty/spirv-tools/source/opt/folding_rules.cpp

@@ -2884,8 +2884,12 @@ FoldingRule UpdateImageOperands() {
                "Offset and ConstOffset may not be used together");
         if (offset_operand_index < inst->NumOperands()) {
           if (constants[offset_operand_index]) {
-            image_operands =
-                image_operands | uint32_t(spv::ImageOperandsMask::ConstOffset);
+            if (constants[offset_operand_index]->IsZero()) {
+              inst->RemoveInOperand(offset_operand_index);
+            } else {
+              image_operands = image_operands |
+                               uint32_t(spv::ImageOperandsMask::ConstOffset);
+            }
             image_operands =
                 image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
             inst->SetInOperand(operand_index, {image_operands});

+ 238 - 463
3rdparty/spirv-tools/source/opt/inst_bindless_check_pass.cpp

@@ -31,15 +31,12 @@ constexpr int kSpvLoadPtrIdInIdx = 0;
 constexpr int kSpvAccessChainBaseIdInIdx = 0;
 constexpr int kSpvAccessChainIndex0IdInIdx = 1;
 constexpr int kSpvTypeArrayTypeIdInIdx = 0;
-constexpr int kSpvTypeArrayLengthIdInIdx = 1;
-constexpr int kSpvConstantValueInIdx = 0;
 constexpr int kSpvVariableStorageClassInIdx = 0;
 constexpr int kSpvTypePtrTypeIdInIdx = 1;
 constexpr int kSpvTypeImageDim = 1;
 constexpr int kSpvTypeImageDepth = 2;
 constexpr int kSpvTypeImageArrayed = 3;
 constexpr int kSpvTypeImageMS = 4;
-constexpr int kSpvTypeImageSampled = 5;
 }  // namespace
 
 void InstBindlessCheckPass::SetupInputBufferIds() {
@@ -135,228 +132,76 @@ void InstBindlessCheckPass::SetupInputBufferIds() {
 
 // clang-format off
 // GLSL:
-// uint inst_bindless_read_binding_length(uint desc_set_idx, uint binding_idx)
-// {
-//     if (desc_set_idx >= inst_bindless_input_buffer.desc_sets.length()) {
-//         return 0;
-//     }
-//
-//     DescriptorSetData set_data = inst_bindless_input_buffer.desc_sets[desc_set_idx];
-//     uvec2 ptr_as_vec = uvec2(set_data);
-//     if ((ptr_as_vec.x == 0u) && (_ptr_as_vec.y == 0u))
-//     {
-//         return 0u;
-//     }
-//     uint num_bindings = set_data.num_bindings;
-//     if (binding_idx >= num_bindings) {
-//         return 0;
-//     }
-//     return set_data.data[binding_idx];
-// }
+//bool inst_bindless_check_desc(uint shader_id, uint line, uvec4 stage_info, uint desc_set_idx, uint binding_idx, uint desc_idx,
+//                              uint offset)
+//{
+//    if (desc_set_idx >= inst_bindless_input_buffer.desc_sets.length()) {
+//        // kInstErrorBindlessBounds
+//        inst_bindless_stream_write_6(shader_id, line, stage_info, 1, desc_set_idx, binding_idx, desc_idx, 0, 0);
+//        return false;
+//    }
+//    DescriptorSetData set_data = inst_bindless_input_buffer.desc_sets[desc_set_idx];
+//    uvec2 ptr_vec = uvec2(set_data);
+//    if (ptr_vec.x == 0 && ptr_vec.y == 0) {
+//        // kInstErrorBindlessBounds
+//        inst_bindless_stream_write_6(shader_id, line, stage_info, 1, desc_set_idx, binding_idx, desc_idx, 0, 0);
+//        return false;
+//    }
+//    uint num_bindings = set_data.num_bindings;
+//    if (binding_idx >= num_bindings) {
+//        // kInstErrorBindlessBounds
+//        inst_bindless_stream_write_6(shader_id, line, stage_info, 1, desc_set_idx, binding_idx, desc_idx, 0, 0);
+//        return false;
+//    }
+//    uint binding_length = set_data.data[binding_idx];
+//    if (desc_idx >= binding_length) {
+//        // kInstErrorBindlessBounds
+//        inst_bindless_stream_write_6(shader_id, line, stage_info, 1, desc_set_idx, binding_idx, desc_idx, binding_length, 0);
+//        return false;
+//    }
+//    uint desc_records_start = set_data.data[num_bindings + binding_idx];
+//    uint init_or_len = set_data.data[desc_records_start + desc_idx];
+//    if (init_or_len == 0) {
+//        // kInstErrorBindlessUninit
+//        inst_bindless_stream_write_6(shader_id, line, stage_info, 2, desc_set_idx, binding_idx, desc_idx, 0, 0);
+//        return false;
+//    }
+//    if (offset >= init_or_len) {
+//        // kInstErrorOOB
+//        inst_bindless_stream_write_6(shader_id, line, stage_info, 4, desc_set_idx, binding_idx, desc_idx, offset,
+//                                      init_or_len);
+//        return false;
+//    }
+//    return true;
+//}
 // clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadLengthFunctionId() {
-  if (read_length_func_id_ != 0) {
-    return read_length_func_id_;
-  }
-  SetupInputBufferIds();
-  const analysis::Integer* uint_type = GetInteger(32, false);
-  const std::vector<const analysis::Type*> param_types(2, uint_type);
-
-  const uint32_t func_id = TakeNextId();
-  std::unique_ptr<Function> func =
-      StartFunction(func_id, uint_type, param_types);
-
-  const std::vector<uint32_t> param_ids = AddParameters(*func, param_types);
-
-  // Create block
-  auto new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(TakeNextId()));
-  InstructionBuilder builder(
-      context(), new_blk_ptr.get(),
-      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-  Instruction* inst;
-
-  inst = builder.AddBinaryOp(
-      GetBoolId(), spv::Op::OpUGreaterThanEqual, param_ids[0],
-      builder.GetUintConstantId(kDebugInputBindlessMaxDescSets));
-  const uint32_t desc_cmp_id = inst->result_id();
-
-  uint32_t error_blk_id = TakeNextId();
-  uint32_t merge_blk_id = TakeNextId();
-  std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
-  std::unique_ptr<Instruction> error_label(NewLabel(error_blk_id));
-  (void)builder.AddConditionalBranch(desc_cmp_id, error_blk_id, merge_blk_id,
-                                     merge_blk_id);
-
-  func->AddBasicBlock(std::move(new_blk_ptr));
-
-  // error return
-  new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
-  builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
-  func->AddBasicBlock(std::move(new_blk_ptr));
-
-  // check descriptor set table entry is non-null
-  new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
-  builder.SetInsertPoint(&*new_blk_ptr);
-
-  analysis::TypeManager* type_mgr = context()->get_type_mgr();
-  const uint32_t desc_set_ptr_ptr = type_mgr->FindPointerToType(
-      desc_set_ptr_id_, spv::StorageClass::StorageBuffer);
-
-  inst = builder.AddAccessChain(desc_set_ptr_ptr, input_buffer_id_,
-                                {builder.GetUintConstantId(0), param_ids[0]});
-  const uint32_t set_access_chain_id = inst->result_id();
-
-  inst = builder.AddLoad(desc_set_ptr_id_, set_access_chain_id);
-  const uint32_t desc_set_ptr_id = inst->result_id();
-
-  inst =
-      builder.AddUnaryOp(GetVecUintId(2), spv::Op::OpBitcast, desc_set_ptr_id);
-  const uint32_t ptr_as_uvec_id = inst->result_id();
 
-  inst = builder.AddCompositeExtract(GetUintId(), ptr_as_uvec_id, {0});
-  const uint32_t uvec_x = inst->result_id();
-
-  inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, uvec_x,
-                             builder.GetUintConstantId(0));
-  const uint32_t x_is_zero_id = inst->result_id();
-
-  inst = builder.AddCompositeExtract(GetUintId(), ptr_as_uvec_id, {1});
-  const uint32_t uvec_y = inst->result_id();
-
-  inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, uvec_y,
-                             builder.GetUintConstantId(0));
-  const uint32_t y_is_zero_id = inst->result_id();
-
-  inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpLogicalAnd, x_is_zero_id,
-                             y_is_zero_id);
-  const uint32_t is_null_id = inst->result_id();
-
-  error_blk_id = TakeNextId();
-  merge_blk_id = TakeNextId();
-  merge_label = NewLabel(merge_blk_id);
-  error_label = NewLabel(error_blk_id);
-  (void)builder.AddConditionalBranch(is_null_id, error_blk_id, merge_blk_id,
-                                     merge_blk_id);
-  func->AddBasicBlock(std::move(new_blk_ptr));
-  // error return
-  new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
-  builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
-  func->AddBasicBlock(std::move(new_blk_ptr));
-
-  // check binding is in range
-  new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
-  builder.SetInsertPoint(&*new_blk_ptr);
-
-  const uint32_t uint_ptr = type_mgr->FindPointerToType(
-      GetUintId(), spv::StorageClass::PhysicalStorageBuffer);
-
-  inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
-                                {builder.GetUintConstantId(0)});
-  const uint32_t binding_access_chain_id = inst->result_id();
-
-  inst = builder.AddLoad(GetUintId(), binding_access_chain_id, 8);
-  const uint32_t num_bindings_id = inst->result_id();
-
-  inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
-                             param_ids[1], num_bindings_id);
-  const uint32_t bindings_cmp_id = inst->result_id();
-
-  error_blk_id = TakeNextId();
-  merge_blk_id = TakeNextId();
-  merge_label = NewLabel(merge_blk_id);
-  error_label = NewLabel(error_blk_id);
-  (void)builder.AddConditionalBranch(bindings_cmp_id, error_blk_id,
-                                     merge_blk_id, merge_blk_id);
-  func->AddBasicBlock(std::move(new_blk_ptr));
-  // error return
-  new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
-  builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
-  func->AddBasicBlock(std::move(new_blk_ptr));
-
-  // read binding length
-  new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
-  builder.SetInsertPoint(&*new_blk_ptr);
-
-  inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
-                                {{builder.GetUintConstantId(1), param_ids[1]}});
-  const uint32_t length_ac_id = inst->result_id();
-
-  inst = builder.AddLoad(GetUintId(), length_ac_id, sizeof(uint32_t));
-  const uint32_t length_id = inst->result_id();
-
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, length_id);
-
-  func->AddBasicBlock(std::move(new_blk_ptr));
-  func->SetFunctionEnd(EndFunction());
-
-  context()->AddFunction(std::move(func));
-  context()->AddDebug2Inst(NewGlobalName(func_id, "read_binding_length"));
-
-  read_length_func_id_ = func_id;
-  // Make sure this function doesn't get processed by
-  // InstrumentPass::InstProcessCallTreeFromRoots()
-  param2output_func_id_[2] = func_id;
-  return read_length_func_id_;
-}
-
-// clang-format off
-// GLSL:
-// result = inst_bindless_read_binding_length(desc_set_id, binding_id);
-// clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadLength(
-    uint32_t var_id, InstructionBuilder* builder) {
-  const uint32_t func_id = GenDebugReadLengthFunctionId();
-
-  const std::vector<uint32_t> args = {
-      builder->GetUintConstantId(var2desc_set_[var_id]),
-      builder->GetUintConstantId(var2binding_[var_id]),
+uint32_t InstBindlessCheckPass::GenDescCheckFunctionId() {
+  enum {
+    kShaderId = 0,
+    kInstructionIndex = 1,
+    kStageInfo = 2,
+    kDescSet = 3,
+    kDescBinding = 4,
+    kDescIndex = 5,
+    kByteOffset = 6,
+    kNumArgs
   };
-  return GenReadFunctionCall(func_id, args, builder);
-}
-
-// clang-format off
-// GLSL:
-// uint inst_bindless_read_desc_init(uint desc_set_idx, uint binding_idx, uint desc_idx)
-// {
-//     if (desc_set_idx >= uint(inst_bindless_input_buffer.desc_sets.length()))
-//     {
-//         return 0u;
-//     }
-//     DescriptorSetData set_data = inst_bindless_input_buffer.desc_sets[desc_set_idx];
-//     uvec2 ptr_as_vec = uvec2(set_data)
-//     if ((ptr_as_vec .x == 0u) && (ptr_as_vec.y == 0u))
-//     {
-//         return 0u;
-//     }
-//     if (binding_idx >= set_data.num_bindings)
-//     {
-//         return 0u;
-//     }
-//     if (desc_idx >= set_data.data[binding_idx])
-//     {
-//         return 0u;
-//     }
-//     uint desc_records_start = set_data.data[set_data.num_bindings + binding_idx];
-//     return set_data.data[desc_records_start + desc_idx];
-// }
-// clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
-  if (read_init_func_id_ != 0) {
-    return read_init_func_id_;
+  if (desc_check_func_id_ != 0) {
+    return desc_check_func_id_;
   }
+
   SetupInputBufferIds();
+  analysis::TypeManager* type_mgr = context()->get_type_mgr();
   const analysis::Integer* uint_type = GetInteger(32, false);
-  const std::vector<const analysis::Type*> param_types(3, uint_type);
+  const analysis::Vector v4uint(uint_type, 4);
+  const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
+  std::vector<const analysis::Type*> param_types(kNumArgs, uint_type);
+  param_types[2] = v4uint_type;
 
   const uint32_t func_id = TakeNextId();
   std::unique_ptr<Function> func =
-      StartFunction(func_id, uint_type, param_types);
+      StartFunction(func_id, type_mgr->GetBoolType(), param_types);
 
   const std::vector<uint32_t> param_ids = AddParameters(*func, param_types);
 
@@ -365,10 +210,12 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   InstructionBuilder builder(
       context(), new_blk_ptr.get(),
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  const uint32_t false_id = builder.GetBoolConstantId(false);
+  const uint32_t true_id = builder.GetBoolConstantId(true);
   Instruction* inst;
 
   inst = builder.AddBinaryOp(
-      GetBoolId(), spv::Op::OpUGreaterThanEqual, param_ids[0],
+      GetBoolId(), spv::Op::OpUGreaterThanEqual, param_ids[kDescSet],
       builder.GetUintConstantId(kDebugInputBindlessMaxDescSets));
   const uint32_t desc_cmp_id = inst->result_id();
 
@@ -383,20 +230,19 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   // error return
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
   builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, false_id);
   func->AddBasicBlock(std::move(new_blk_ptr));
 
   // check descriptor set table entry is non-null
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
   builder.SetInsertPoint(&*new_blk_ptr);
 
-  analysis::TypeManager* type_mgr = context()->get_type_mgr();
   const uint32_t desc_set_ptr_ptr = type_mgr->FindPointerToType(
       desc_set_ptr_id_, spv::StorageClass::StorageBuffer);
 
-  inst = builder.AddAccessChain(desc_set_ptr_ptr, input_buffer_id_,
-                                {builder.GetUintConstantId(0), param_ids[0]});
+  inst = builder.AddAccessChain(
+      desc_set_ptr_ptr, input_buffer_id_,
+      {builder.GetUintConstantId(0), param_ids[kDescSet]});
   const uint32_t set_access_chain_id = inst->result_id();
 
   inst = builder.AddLoad(desc_set_ptr_id_, set_access_chain_id);
@@ -434,8 +280,13 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   // error return
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
   builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
+  GenDebugStreamWrite(
+      param_ids[kShaderId], param_ids[kInstructionIndex], param_ids[kStageInfo],
+      {builder.GetUintConstantId(kInstErrorBindlessBounds), param_ids[kDescSet],
+       param_ids[kDescBinding], param_ids[kDescIndex],
+       builder.GetUintConstantId(0), builder.GetUintConstantId(0)},
+      &builder);
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, false_id);
   func->AddBasicBlock(std::move(new_blk_ptr));
 
   // check binding is in range
@@ -453,7 +304,7 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   const uint32_t num_bindings_id = inst->result_id();
 
   inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
-                             param_ids[1], num_bindings_id);
+                             param_ids[kDescBinding], num_bindings_id);
   const uint32_t bindings_cmp_id = inst->result_id();
 
   error_blk_id = TakeNextId();
@@ -466,16 +317,22 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   // error return
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
   builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
+  GenDebugStreamWrite(
+      param_ids[kShaderId], param_ids[kInstructionIndex], param_ids[kStageInfo],
+      {builder.GetUintConstantId(kInstErrorBindlessBounds), param_ids[kDescSet],
+       param_ids[kDescBinding], param_ids[kDescIndex],
+       builder.GetUintConstantId(0), builder.GetUintConstantId(0)},
+      &builder);
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, false_id);
   func->AddBasicBlock(std::move(new_blk_ptr));
 
   // read binding length
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
   builder.SetInsertPoint(&*new_blk_ptr);
 
-  inst = builder.AddAccessChain(uint_ptr, desc_set_ptr_id,
-                                {{builder.GetUintConstantId(1), param_ids[1]}});
+  inst = builder.AddAccessChain(
+      uint_ptr, desc_set_ptr_id,
+      {{builder.GetUintConstantId(1), param_ids[kDescBinding]}});
   const uint32_t length_ac_id = inst->result_id();
 
   inst = builder.AddLoad(GetUintId(), length_ac_id, sizeof(uint32_t));
@@ -483,7 +340,7 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
 
   // Check descriptor index in bounds
   inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
-                             param_ids[2], length_id);
+                             param_ids[kDescIndex], length_id);
   const uint32_t desc_idx_range_id = inst->result_id();
 
   error_blk_id = TakeNextId();
@@ -496,15 +353,20 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   // Error return
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
   builder.SetInsertPoint(&*new_blk_ptr);
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue,
-                           builder.GetUintConstantId(0));
+  GenDebugStreamWrite(
+      param_ids[kShaderId], param_ids[kInstructionIndex], param_ids[kStageInfo],
+      {builder.GetUintConstantId(kInstErrorBindlessBounds), param_ids[kDescSet],
+       param_ids[kDescBinding], param_ids[kDescIndex], length_id,
+       builder.GetUintConstantId(0)},
+      &builder);
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, false_id);
   func->AddBasicBlock(std::move(new_blk_ptr));
 
   // Read descriptor init status
   new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
   builder.SetInsertPoint(&*new_blk_ptr);
 
-  inst = builder.AddIAdd(GetUintId(), num_bindings_id, param_ids[1]);
+  inst = builder.AddIAdd(GetUintId(), num_bindings_id, param_ids[kDescBinding]);
   const uint32_t state_offset_id = inst->result_id();
 
   inst =
@@ -515,7 +377,7 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   inst = builder.AddLoad(GetUintId(), state_start_ac_id, sizeof(uint32_t));
   const uint32_t state_start_id = inst->result_id();
 
-  inst = builder.AddIAdd(GetUintId(), state_start_id, param_ids[2]);
+  inst = builder.AddIAdd(GetUintId(), state_start_id, param_ids[kDescIndex]);
   const uint32_t state_entry_id = inst->result_id();
 
   // Note: length starts from the beginning of the buffer, not the beginning of
@@ -528,35 +390,90 @@ uint32_t InstBindlessCheckPass::GenDebugReadInitFunctionId() {
   inst = builder.AddLoad(GetUintId(), init_ac_id, sizeof(uint32_t));
   const uint32_t init_status_id = inst->result_id();
 
-  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, init_status_id);
+  // Check for uninitialized descriptor
+  inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpIEqual, init_status_id,
+                             builder.GetUintConstantId(0));
+  const uint32_t uninit_check_id = inst->result_id();
+  error_blk_id = TakeNextId();
+  merge_blk_id = TakeNextId();
+  merge_label = NewLabel(merge_blk_id);
+  error_label = NewLabel(error_blk_id);
+  (void)builder.AddConditionalBranch(uninit_check_id, error_blk_id,
+                                     merge_blk_id, merge_blk_id);
+  func->AddBasicBlock(std::move(new_blk_ptr));
+  new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  GenDebugStreamWrite(
+      param_ids[kShaderId], param_ids[kInstructionIndex], param_ids[kStageInfo],
+      {builder.GetUintConstantId(kInstErrorBindlessUninit), param_ids[kDescSet],
+       param_ids[kDescBinding], param_ids[kDescIndex],
+       builder.GetUintConstantId(0), builder.GetUintConstantId(0)},
+      &builder);
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, false_id);
+  func->AddBasicBlock(std::move(new_blk_ptr));
 
+  // Check for OOB.
+  new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpUGreaterThanEqual,
+                             param_ids[kByteOffset], init_status_id);
+  const uint32_t buf_offset_range_id = inst->result_id();
+
+  error_blk_id = TakeNextId();
+  merge_blk_id = TakeNextId();
+  merge_label = NewLabel(merge_blk_id);
+  error_label = NewLabel(error_blk_id);
+  (void)builder.AddConditionalBranch(buf_offset_range_id, error_blk_id,
+                                     merge_blk_id, merge_blk_id);
   func->AddBasicBlock(std::move(new_blk_ptr));
+  // Error return
+  new_blk_ptr = MakeUnique<BasicBlock>(std::move(error_label));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  GenDebugStreamWrite(
+      param_ids[kShaderId], param_ids[kInstructionIndex], param_ids[kStageInfo],
+      {builder.GetUintConstantId(kInstErrorOOB), param_ids[kDescSet],
+       param_ids[kDescBinding], param_ids[kDescIndex], param_ids[kByteOffset],
+       init_status_id},
+      &builder);
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, false_id);
+  func->AddBasicBlock(std::move(new_blk_ptr));
+
+  // Success return
+  new_blk_ptr = MakeUnique<BasicBlock>(std::move(merge_label));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  (void)builder.AddUnaryOp(0, spv::Op::OpReturnValue, true_id);
+  func->AddBasicBlock(std::move(new_blk_ptr));
+
   func->SetFunctionEnd(EndFunction());
 
   context()->AddFunction(std::move(func));
-  context()->AddDebug2Inst(NewGlobalName(func_id, "read_desc_init"));
+  context()->AddDebug2Inst(NewGlobalName(func_id, "desc_check"));
 
-  read_init_func_id_ = func_id;
+  desc_check_func_id_ = func_id;
   // Make sure function doesn't get processed by
   // InstrumentPass::InstProcessCallTreeFromRoots()
   param2output_func_id_[3] = func_id;
-  return read_init_func_id_;
+  return desc_check_func_id_;
 }
 
 // clang-format off
 // GLSL:
-// result = inst_bindless_read_desc_init(desc_set_id, binding_id, desc_idx_id);
+// result = inst_bindless_desc_check(shader_id, inst_idx, stage_info, desc_set, binding, desc_idx, offset);
 //
 // clang-format on
-uint32_t InstBindlessCheckPass::GenDebugReadInit(uint32_t var_id,
-                                                 uint32_t desc_idx_id,
-                                                 InstructionBuilder* builder) {
-  const uint32_t func_id = GenDebugReadInitFunctionId();
+uint32_t InstBindlessCheckPass::GenDescCheckCall(
+    uint32_t inst_idx, uint32_t stage_idx, uint32_t var_id,
+    uint32_t desc_idx_id, uint32_t offset_id, InstructionBuilder* builder) {
+  const uint32_t func_id = GenDescCheckFunctionId();
   const std::vector<uint32_t> args = {
+      builder->GetUintConstantId(shader_id_),
+      builder->GetUintConstantId(inst_idx),
+      GenStageInfo(stage_idx, builder),
       builder->GetUintConstantId(var2desc_set_[var_id]),
       builder->GetUintConstantId(var2binding_[var_id]),
-      GenUintCastCode(desc_idx_id, builder)};
-  return GenReadFunctionCall(func_id, args, builder);
+      GenUintCastCode(desc_idx_id, builder),
+      offset_id};
+  return GenReadFunctionCall(GetBoolId(), func_id, args, builder);
 }
 
 uint32_t InstBindlessCheckPass::CloneOriginalImage(
@@ -1047,29 +964,30 @@ void InstBindlessCheckPass::GenCheckCode(
   // Gen invalid block
   new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
   builder.SetInsertPoint(&*new_blk_ptr);
-  const uint32_t u_set_id = builder.GetUintConstantId(ref->set);
-  const uint32_t u_binding_id = builder.GetUintConstantId(ref->binding);
-  const uint32_t u_index_id = GenUintCastCode(ref->desc_idx_id, &builder);
-  const uint32_t u_length_id = GenUintCastCode(length_id, &builder);
-  if (offset_id != 0) {
-    const uint32_t u_offset_id = GenUintCastCode(offset_id, &builder);
-    // Buffer OOB
-    GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
-                        {error_id, u_set_id, u_binding_id, u_index_id,
-                         u_offset_id, u_length_id},
-                        &builder);
-  } else if (buffer_bounds_enabled_ || texel_buffer_enabled_) {
-    // Uninitialized Descriptor - Return additional unused zero so all error
-    // modes will use same debug stream write function
-    GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
-                        {error_id, u_set_id, u_binding_id, u_index_id,
-                         u_length_id, builder.GetUintConstantId(0)},
-                        &builder);
-  } else {
-    // Uninitialized Descriptor - Normal error return
-    GenDebugStreamWrite(
-        uid2offset_[ref->ref_inst->unique_id()], stage_idx,
-        {error_id, u_set_id, u_binding_id, u_index_id, u_length_id}, &builder);
+  if (error_id != 0) {
+    const uint32_t u_shader_id = builder.GetUintConstantId(shader_id_);
+    const uint32_t u_inst_id =
+        builder.GetUintConstantId(ref->ref_inst->unique_id());
+    const uint32_t shader_info_id = GenStageInfo(stage_idx, &builder);
+    const uint32_t u_set_id = builder.GetUintConstantId(ref->set);
+    const uint32_t u_binding_id = builder.GetUintConstantId(ref->binding);
+    const uint32_t u_index_id = GenUintCastCode(ref->desc_idx_id, &builder);
+    const uint32_t u_length_id = GenUintCastCode(length_id, &builder);
+    if (offset_id != 0) {
+      const uint32_t u_offset_id = GenUintCastCode(offset_id, &builder);
+      // Buffer OOB
+      GenDebugStreamWrite(u_shader_id, u_inst_id, shader_info_id,
+                          {error_id, u_set_id, u_binding_id, u_index_id,
+                           u_offset_id, u_length_id},
+                          &builder);
+    } else {
+      // Uninitialized Descriptor - Return additional unused zero so all error
+      // modes will use same debug stream write function
+      GenDebugStreamWrite(u_shader_id, u_inst_id, shader_info_id,
+                          {error_id, u_set_id, u_binding_id, u_index_id,
+                           u_length_id, builder.GetUintConstantId(0)},
+                          &builder);
+    }
   }
   // Generate a ConstantNull, converting to uint64 if the type cannot be a null.
   if (new_ref_id != 0) {
@@ -1106,77 +1024,42 @@ void InstBindlessCheckPass::GenCheckCode(
   context()->KillInst(ref->ref_inst);
 }
 
-void InstBindlessCheckPass::GenDescIdxCheckCode(
+void InstBindlessCheckPass::GenDescCheckCode(
     BasicBlock::iterator ref_inst_itr,
     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
-  // Look for reference through indexed descriptor. If found, analyze and
-  // save components. If not, return.
+  // Look for reference through descriptor. If not, return.
   RefAnalysis ref;
   if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
-  Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
-  if (ptr_inst->opcode() != spv::Op::OpAccessChain) return;
-  // If index and bound both compile-time constants and index < bound,
-  // return without changing
-  Instruction* var_inst = get_def_use_mgr()->GetDef(ref.var_id);
-  Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
-  uint32_t length_id = 0;
-  if (desc_type_inst->opcode() == spv::Op::OpTypeArray) {
-    length_id =
-        desc_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
-    Instruction* index_inst = get_def_use_mgr()->GetDef(ref.desc_idx_id);
-    Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
-    if (index_inst->opcode() == spv::Op::OpConstant &&
-        length_inst->opcode() == spv::Op::OpConstant &&
-        index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
-            length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
-      return;
-  } else if (!desc_idx_enabled_ ||
-             desc_type_inst->opcode() != spv::Op::OpTypeRuntimeArray) {
-    return;
-  }
-  // Move original block's preceding instructions into first new block
   std::unique_ptr<BasicBlock> new_blk_ptr;
+  // Move original block's preceding instructions into first new block
   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
   InstructionBuilder builder(
       context(), &*new_blk_ptr,
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
   new_blocks->push_back(std::move(new_blk_ptr));
-  uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
-  // If length id not yet set, descriptor array is runtime size so
-  // generate load of length from stage's debug input buffer.
-  if (length_id == 0) {
-    assert(desc_type_inst->opcode() == spv::Op::OpTypeRuntimeArray &&
-           "unexpected bindless type");
-    length_id = GenDebugReadLength(ref.var_id, &builder);
-  }
-  // Generate full runtime bounds test code with true branch
-  // being full reference and false branch being debug output and zero
-  // for the referenced value.
-  uint32_t desc_idx_32b_id = Gen32BitCvtCode(ref.desc_idx_id, &builder);
-  uint32_t length_32b_id = Gen32BitCvtCode(length_id, &builder);
-  Instruction* ult_inst = builder.AddBinaryOp(GetBoolId(), spv::Op::OpULessThan,
-                                              desc_idx_32b_id, length_32b_id);
-  ref.desc_idx_id = desc_idx_32b_id;
-  GenCheckCode(ult_inst->result_id(), error_id, 0u, length_id, stage_idx, &ref,
-               new_blocks);
-  // Move original block's remaining code into remainder/merge block and add
-  // to new blocks
-  BasicBlock* back_blk_ptr = &*new_blocks->back();
-  MovePostludeCode(ref_block_itr, back_blk_ptr);
-}
-
-void InstBindlessCheckPass::GenDescInitCheckCode(
-    BasicBlock::iterator ref_inst_itr,
-    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
-    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
-  // Look for reference through descriptor. If not, return.
-  RefAnalysis ref;
-  if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
   // Determine if we can only do initialization check
-  bool init_check = false;
-  if (ref.desc_load_id != 0 || !buffer_bounds_enabled_) {
-    init_check = true;
+  uint32_t ref_id = builder.GetUintConstantId(0u);
+  spv::Op op = ref.ref_inst->opcode();
+  if (ref.desc_load_id != 0) {
+    uint32_t num_in_oprnds = ref.ref_inst->NumInOperands();
+    if ((op == spv::Op::OpImageRead && num_in_oprnds == 2) ||
+        (op == spv::Op::OpImageFetch && num_in_oprnds == 2) ||
+        (op == spv::Op::OpImageWrite && num_in_oprnds == 3)) {
+      Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
+      uint32_t image_ty_id = image_inst->type_id();
+      Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
+      if (spv::Dim(image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim)) ==
+          spv::Dim::Buffer) {
+        if ((image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) == 0) &&
+            (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) ==
+             0) &&
+            (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) == 0)) {
+          ref_id = GenUintCastCode(ref.ref_inst->GetSingleWordInOperand(1),
+                                   &builder);
+        }
+      }
+    }
   } else {
     // For now, only do bounds check for non-aggregate types. Otherwise
     // just do descriptor initialization check.
@@ -1184,106 +1067,24 @@ void InstBindlessCheckPass::GenDescInitCheckCode(
     Instruction* ref_ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
     Instruction* pte_type_inst = GetPointeeTypeInst(ref_ptr_inst);
     spv::Op pte_type_op = pte_type_inst->opcode();
-    if (pte_type_op == spv::Op::OpTypeArray ||
-        pte_type_op == spv::Op::OpTypeRuntimeArray ||
-        pte_type_op == spv::Op::OpTypeStruct)
-      init_check = true;
+    if (pte_type_op != spv::Op::OpTypeArray &&
+        pte_type_op != spv::Op::OpTypeRuntimeArray &&
+        pte_type_op != spv::Op::OpTypeStruct) {
+      ref_id = GenLastByteIdx(&ref, &builder);
+    }
   }
-  // If initialization check and not enabled, return
-  if (init_check && !desc_init_enabled_) return;
-  // Move original block's preceding instructions into first new block
-  std::unique_ptr<BasicBlock> new_blk_ptr;
-  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
-  InstructionBuilder builder(
-      context(), &*new_blk_ptr,
-      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-  new_blocks->push_back(std::move(new_blk_ptr));
-  // If initialization check, use reference value of zero.
-  // Else use the index of the last byte referenced.
-  uint32_t ref_id = init_check ? builder.GetUintConstantId(0u)
-                               : GenLastByteIdx(&ref, &builder);
   // Read initialization/bounds from debug input buffer. If index id not yet
   // set, binding is single descriptor, so set index to constant 0.
   if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
-  uint32_t init_id = GenDebugReadInit(ref.var_id, ref.desc_idx_id, &builder);
-  // Generate runtime initialization/bounds test code with true branch
-  // being full reference and false branch being debug output and zero
-  // for the referenced value.
-  Instruction* ult_inst =
-      builder.AddBinaryOp(GetBoolId(), spv::Op::OpULessThan, ref_id, init_id);
-  uint32_t error =
-      init_check
-          ? kInstErrorBindlessUninit
-          : (spv::StorageClass(ref.strg_class) == spv::StorageClass::Uniform
-                 ? kInstErrorBuffOOBUniform
-                 : kInstErrorBuffOOBStorage);
-  uint32_t error_id = builder.GetUintConstantId(error);
-  GenCheckCode(ult_inst->result_id(), error_id, init_check ? 0 : ref_id,
-               init_check ? builder.GetUintConstantId(0u) : init_id, stage_idx,
-               &ref, new_blocks);
-  // Move original block's remaining code into remainder/merge block and add
-  // to new blocks
-  BasicBlock* back_blk_ptr = &*new_blocks->back();
-  MovePostludeCode(ref_block_itr, back_blk_ptr);
-}
+  uint32_t check_id =
+      GenDescCheckCall(ref.ref_inst->unique_id(), stage_idx, ref.var_id,
+                       ref.desc_idx_id, ref_id, &builder);
 
-void InstBindlessCheckPass::GenTexBuffCheckCode(
-    BasicBlock::iterator ref_inst_itr,
-    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
-    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
-  // Only process OpImageRead and OpImageWrite with no optional operands
-  Instruction* ref_inst = &*ref_inst_itr;
-  spv::Op op = ref_inst->opcode();
-  uint32_t num_in_oprnds = ref_inst->NumInOperands();
-  if (!((op == spv::Op::OpImageRead && num_in_oprnds == 2) ||
-        (op == spv::Op::OpImageFetch && num_in_oprnds == 2) ||
-        (op == spv::Op::OpImageWrite && num_in_oprnds == 3)))
-    return;
-  // Pull components from descriptor reference
-  RefAnalysis ref;
-  if (!AnalyzeDescriptorReference(ref_inst, &ref)) return;
-  // Only process if image is texel buffer
-  Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
-  uint32_t image_ty_id = image_inst->type_id();
-  Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
-  if (spv::Dim(image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim)) !=
-      spv::Dim::Buffer) {
-    return;
-  }
-  if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) != 0) return;
-  if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) != 0) return;
-  if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) != 0) return;
-  // Enable ImageQuery Capability if not yet enabled
-  context()->AddCapability(spv::Capability::ImageQuery);
-  // Move original block's preceding instructions into first new block
-  std::unique_ptr<BasicBlock> new_blk_ptr;
-  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
-  InstructionBuilder builder(
-      context(), &*new_blk_ptr,
-      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-  new_blocks->push_back(std::move(new_blk_ptr));
-  // Get texel coordinate
-  uint32_t coord_id =
-      GenUintCastCode(ref_inst->GetSingleWordInOperand(1), &builder);
-  // If index id not yet set, binding is single descriptor, so set index to
-  // constant 0.
-  if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
-  // Get texel buffer size.
-  Instruction* size_inst =
-      builder.AddUnaryOp(GetUintId(), spv::Op::OpImageQuerySize, ref.image_id);
-  uint32_t size_id = size_inst->result_id();
   // Generate runtime initialization/bounds test code with true branch
-  // being full reference and false branch being debug output and zero
+  // being full reference and false branch being zero
   // for the referenced value.
-  Instruction* ult_inst =
-      builder.AddBinaryOp(GetBoolId(), spv::Op::OpULessThan, coord_id, size_id);
-  uint32_t error =
-      (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageSampled) == 2)
-          ? kInstErrorBuffOOBStorageTexel
-          : kInstErrorBuffOOBUniformTexel;
-  uint32_t error_id = builder.GetUintConstantId(error);
-  GenCheckCode(ult_inst->result_id(), error_id, coord_id, size_id, stage_idx,
-               &ref, new_blocks);
+  GenCheckCode(check_id, 0, 0, 0, stage_idx, &ref, new_blocks);
+
   // Move original block's remaining code into remainder/merge block and add
   // to new blocks
   BasicBlock* back_blk_ptr = &*new_blocks->back();
@@ -1293,58 +1094,32 @@ void InstBindlessCheckPass::GenTexBuffCheckCode(
 void InstBindlessCheckPass::InitializeInstBindlessCheck() {
   // Initialize base class
   InitializeInstrument();
-  // If runtime array length support or buffer bounds checking are enabled,
-  // create variable mappings. Length support is always enabled if descriptor
-  // init check is enabled.
-  if (desc_idx_enabled_ || buffer_bounds_enabled_ || texel_buffer_enabled_)
-    for (auto& anno : get_module()->annotations())
-      if (anno.opcode() == spv::Op::OpDecorate) {
-        if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
-            spv::Decoration::DescriptorSet) {
-          var2desc_set_[anno.GetSingleWordInOperand(0u)] =
-              anno.GetSingleWordInOperand(2u);
-        } else if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
-                   spv::Decoration::Binding) {
-          var2binding_[anno.GetSingleWordInOperand(0u)] =
-              anno.GetSingleWordInOperand(2u);
-        }
+  for (auto& anno : get_module()->annotations()) {
+    if (anno.opcode() == spv::Op::OpDecorate) {
+      if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
+          spv::Decoration::DescriptorSet) {
+        var2desc_set_[anno.GetSingleWordInOperand(0u)] =
+            anno.GetSingleWordInOperand(2u);
+      } else if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
+                 spv::Decoration::Binding) {
+        var2binding_[anno.GetSingleWordInOperand(0u)] =
+            anno.GetSingleWordInOperand(2u);
       }
+    }
+  }
 }
 
 Pass::Status InstBindlessCheckPass::ProcessImpl() {
-  // Perform bindless bounds check on each entry point function in module
+  bool modified = false;
   InstProcessFunction pfn =
       [this](BasicBlock::iterator ref_inst_itr,
              UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
              std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
-        return GenDescIdxCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
-                                   new_blocks);
+        return GenDescCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
+                                new_blocks);
       };
-  bool modified = InstProcessEntryPointCallTree(pfn);
-  if (desc_init_enabled_ || buffer_bounds_enabled_) {
-    // Perform descriptor initialization and/or buffer bounds check on each
-    // entry point function in module
-    pfn = [this](BasicBlock::iterator ref_inst_itr,
-                 UptrVectorIterator<BasicBlock> ref_block_itr,
-                 uint32_t stage_idx,
-                 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
-      return GenDescInitCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
-                                  new_blocks);
-    };
-    modified |= InstProcessEntryPointCallTree(pfn);
-  }
-  if (texel_buffer_enabled_) {
-    // Perform texel buffer bounds check on each entry point function in
-    // module. Generate after descriptor bounds and initialization checks.
-    pfn = [this](BasicBlock::iterator ref_inst_itr,
-                 UptrVectorIterator<BasicBlock> ref_block_itr,
-                 uint32_t stage_idx,
-                 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
-      return GenTexBuffCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
-                                 new_blocks);
-    };
-    modified |= InstProcessEntryPointCallTree(pfn);
-  }
+
+  modified = InstProcessEntryPointCallTree(pfn);
   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 

+ 11 - 96
3rdparty/spirv-tools/source/opt/inst_bindless_check_pass.h

@@ -28,16 +28,8 @@ namespace opt {
 // external design may change as the layer evolves.
 class InstBindlessCheckPass : public InstrumentPass {
  public:
-  InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id,
-                        bool desc_idx_enable, bool desc_init_enable,
-                        bool buffer_bounds_enable, bool texel_buffer_enable,
-                        bool opt_direct_reads)
-      : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless,
-                       opt_direct_reads),
-        desc_idx_enabled_(desc_idx_enable),
-        desc_init_enabled_(desc_init_enable),
-        buffer_bounds_enabled_(buffer_bounds_enable),
-        texel_buffer_enabled_(texel_buffer_enable) {}
+  InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id)
+      : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless, true) {}
 
   ~InstBindlessCheckPass() override = default;
 
@@ -47,82 +39,18 @@ class InstBindlessCheckPass : public InstrumentPass {
   const char* name() const override { return "inst-bindless-check-pass"; }
 
  private:
-  // These functions do bindless checking instrumentation on a single
-  // instruction which references through a descriptor (ie references into an
-  // image or buffer). Refer to Vulkan API for further information on
-  // descriptors. GenDescIdxCheckCode checks that an index into a descriptor
-  // array (array of images or buffers) is in-bounds. GenDescInitCheckCode
-  // checks that the referenced descriptor has been initialized, if the
-  // SPV_EXT_descriptor_indexing extension is enabled, and initialized large
-  // enough to handle the reference, if RobustBufferAccess is disabled.
-  // GenDescInitCheckCode checks for uniform and storage buffer overrun.
-  // GenTexBuffCheckCode checks for texel buffer overrun and should be
-  // run after GenDescInitCheckCode to first make sure that the descriptor
-  // is initialized because it uses OpImageQuerySize on the descriptor.
-  //
-  // The functions are designed to be passed to
-  // InstrumentPass::InstProcessEntryPointCallTree(), which applies the
-  // function to each instruction in a module and replaces the instruction
-  // if warranted.
-  //
-  // If |ref_inst_itr| is a bindless reference, return in |new_blocks| the
-  // result of instrumenting it with validation code within its block at
-  // |ref_block_itr|.  The validation code first executes a check for the
-  // specific condition called for. If the check passes, it executes
-  // the remainder of the reference, otherwise writes a record to the debug
-  // output buffer stream including |function_idx, instruction_idx, stage_idx|
-  // and replaces the reference with the null value of the original type. The
-  // block at |ref_block_itr| can just be replaced with the blocks in
-  // |new_blocks|, which will contain at least two blocks. The last block will
-  // comprise all instructions following |ref_inst_itr|,
-  // preceded by a phi instruction.
-  //
-  // These instrumentation functions utilize GenDebugDirectRead() to read data
-  // from the debug input buffer, specifically the lengths of variable length
-  // descriptor arrays, and the initialization status of each descriptor.
-  // The format of the debug input buffer is documented in instrument.hpp.
-  //
-  // These instrumentation functions utilize GenDebugStreamWrite() to write its
-  // error records. The validation-specific part of the error record will
-  // have the format:
-  //
-  //    Validation Error Code (=kInstErrorBindlessBounds)
-  //    Descriptor Index
-  //    Descriptor Array Size
-  //
-  // The Descriptor Index is the index which has been determined to be
-  // out-of-bounds.
-  //
-  // The Descriptor Array Size is the size of the descriptor array which was
-  // indexed.
-  void GenDescIdxCheckCode(
-      BasicBlock::iterator ref_inst_itr,
-      UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
-      std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
-
-  void GenDescInitCheckCode(
-      BasicBlock::iterator ref_inst_itr,
-      UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
-      std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
-
-  void GenTexBuffCheckCode(
-      BasicBlock::iterator ref_inst_itr,
-      UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
-      std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+  void GenDescCheckCode(BasicBlock::iterator ref_inst_itr,
+                        UptrVectorIterator<BasicBlock> ref_block_itr,
+                        uint32_t stage_idx,
+                        std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
 
   void SetupInputBufferIds();
-  uint32_t GenDebugReadLengthFunctionId();
 
-  // Generate instructions into |builder| to read length of runtime descriptor
-  // array |var_id| from debug input buffer and return id of value.
-  uint32_t GenDebugReadLength(uint32_t var_id, InstructionBuilder* builder);
+  uint32_t GenDescCheckFunctionId();
 
-  uint32_t GenDebugReadInitFunctionId();
-  // Generate instructions into |builder| to read initialization status of
-  // descriptor array |image_id| at |index_id| from debug input buffer and
-  // return id of value.
-  uint32_t GenDebugReadInit(uint32_t image_id, uint32_t index_id,
-                            InstructionBuilder* builder);
+  uint32_t GenDescCheckCall(uint32_t inst_idx, uint32_t stage_idx,
+                            uint32_t var_id, uint32_t index_id,
+                            uint32_t byte_offset, InstructionBuilder* builder);
 
   // Analysis data for descriptor reference components, generated by
   // AnalyzeDescriptorReference. It is necessary and sufficient for further
@@ -190,26 +118,13 @@ class InstBindlessCheckPass : public InstrumentPass {
   // GenDescInitCheckCode to every instruction in module.
   Pass::Status ProcessImpl();
 
-  // Enable instrumentation of runtime array length checking
-  bool desc_idx_enabled_;
-
-  // Enable instrumentation of descriptor initialization checking
-  bool desc_init_enabled_;
-
-  // Enable instrumentation of uniform and storage buffer overrun checking
-  bool buffer_bounds_enabled_;
-
-  // Enable instrumentation of texel buffer overrun checking
-  bool texel_buffer_enabled_;
-
   // Mapping from variable to descriptor set
   std::unordered_map<uint32_t, uint32_t> var2desc_set_;
 
   // Mapping from variable to binding
   std::unordered_map<uint32_t, uint32_t> var2binding_;
 
-  uint32_t read_length_func_id_{0};
-  uint32_t read_init_func_id_{0};
+  uint32_t desc_check_func_id_{0};
   uint32_t desc_set_type_id_{0};
   uint32_t desc_set_ptr_id_{0};
   uint32_t input_buffer_struct_id_{0};

+ 16 - 48
3rdparty/spirv-tools/source/opt/inst_buff_addr_check_pass.cpp

@@ -113,7 +113,9 @@ void InstBuffAddrCheckPass::GenCheckCode(
   Instruction* hi_uptr_inst = builder.AddUnaryOp(
       GetUintId(), spv::Op::OpUConvert, rshift_uptr_inst->result_id());
   GenDebugStreamWrite(
-      uid2offset_[ref_inst->unique_id()], stage_idx,
+      builder.GetUintConstantId(shader_id_),
+      builder.GetUintConstantId(uid2offset_[ref_inst->unique_id()]),
+      GenStageInfo(stage_idx, &builder),
       {error_id, lo_uptr_inst->result_id(), hi_uptr_inst->result_id()},
       &builder);
   // Gen zero for invalid load. If pointer type, need to convert uint64
@@ -150,48 +152,13 @@ void InstBuffAddrCheckPass::GenCheckCode(
   context()->KillInst(ref_inst);
 }
 
-uint32_t InstBuffAddrCheckPass::GetTypeAlignment(uint32_t type_id) {
-  Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
-  switch (type_inst->opcode()) {
-    case spv::Op::OpTypeFloat:
-    case spv::Op::OpTypeInt:
-    case spv::Op::OpTypeVector:
-      return GetTypeLength(type_id);
-    case spv::Op::OpTypeMatrix:
-      return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
-    case spv::Op::OpTypeArray:
-    case spv::Op::OpTypeRuntimeArray:
-      return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
-    case spv::Op::OpTypeStruct: {
-      uint32_t max = 0;
-      type_inst->ForEachInId([&max, this](const uint32_t* iid) {
-        uint32_t alignment = GetTypeAlignment(*iid);
-        max = (alignment > max) ? alignment : max;
-      });
-      return max;
-    }
-    case spv::Op::OpTypePointer:
-      assert(spv::StorageClass(type_inst->GetSingleWordInOperand(0)) ==
-                 spv::StorageClass::PhysicalStorageBufferEXT &&
-             "unexpected pointer type");
-      return 8u;
-    default:
-      assert(false && "unexpected type");
-      return 0;
-  }
-}
-
 uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
   switch (type_inst->opcode()) {
     case spv::Op::OpTypeFloat:
     case spv::Op::OpTypeInt:
       return type_inst->GetSingleWordInOperand(0) / 8u;
-    case spv::Op::OpTypeVector: {
-      uint32_t raw_cnt = type_inst->GetSingleWordInOperand(1);
-      uint32_t adj_cnt = (raw_cnt == 3u) ? 4u : raw_cnt;
-      return adj_cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
-    }
+    case spv::Op::OpTypeVector:
     case spv::Op::OpTypeMatrix:
       return type_inst->GetSingleWordInOperand(1) *
              GetTypeLength(type_inst->GetSingleWordInOperand(0));
@@ -207,18 +174,19 @@ uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
       return cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
     }
     case spv::Op::OpTypeStruct: {
-      uint32_t len = 0;
-      type_inst->ForEachInId([&len, this](const uint32_t* iid) {
-        // Align struct length
-        uint32_t alignment = GetTypeAlignment(*iid);
-        uint32_t mod = len % alignment;
-        uint32_t diff = (mod != 0) ? alignment - mod : 0;
-        len += diff;
-        // Increment struct length by component length
-        uint32_t comp_len = GetTypeLength(*iid);
-        len += comp_len;
+      // Figure out the location of the last byte of the last member of the
+      // structure.
+      uint32_t last_offset = 0, last_len = 0;
+
+      get_decoration_mgr()->ForEachDecoration(
+          type_id, uint32_t(spv::Decoration::Offset),
+          [&last_offset](const Instruction& deco_inst) {
+            last_offset = deco_inst.GetSingleWordInOperand(3);
+          });
+      type_inst->ForEachInId([&last_len, this](const uint32_t* iid) {
+        last_len = GetTypeLength(*iid);
       });
-      return len;
+      return last_offset + last_len;
     }
     case spv::Op::OpTypeRuntimeArray:
     default:

+ 0 - 4
3rdparty/spirv-tools/source/opt/inst_buff_addr_check_pass.h

@@ -45,10 +45,6 @@ class InstBuffAddrCheckPass : public InstrumentPass {
                           InstProcessFunction& pfn) override;
 
  private:
-  // Return byte alignment of type |type_id|. Must be int, float, vector,
-  // matrix, struct, array or physical pointer. Uses std430 alignment.
-  uint32_t GetTypeAlignment(uint32_t type_id);
-
   // Return byte length of type |type_id|. Must be int, float, vector, matrix,
   // struct, array or physical pointer. Uses std430 alignment and sizes.
   uint32_t GetTypeLength(uint32_t type_id);

+ 4 - 2
3rdparty/spirv-tools/source/opt/inst_debug_printf_pass.cpp

@@ -165,8 +165,10 @@ void InstDebugPrintfPass::GenOutputCode(
           GenOutputValues(opnd_inst, &val_ids, &builder);
         }
       });
-  GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids,
-                      &builder);
+  GenDebugStreamWrite(
+      builder.GetUintConstantId(shader_id_),
+      builder.GetUintConstantId(uid2offset_[printf_inst->unique_id()]),
+      GenStageInfo(stage_idx, &builder), val_ids, &builder);
   context()->KillInst(printf_inst);
 }
 

+ 25 - 3
3rdparty/spirv-tools/source/opt/instruction.cpp

@@ -751,7 +751,7 @@ bool Instruction::IsOpaqueType() const {
 }
 
 bool Instruction::IsFoldable() const {
-  return IsFoldableByFoldScalar() ||
+  return IsFoldableByFoldScalar() || IsFoldableByFoldVector() ||
          context()->get_instruction_folder().HasConstFoldingRule(this);
 }
 
@@ -762,7 +762,7 @@ bool Instruction::IsFoldableByFoldScalar() const {
   }
 
   Instruction* type = context()->get_def_use_mgr()->GetDef(type_id());
-  if (!folder.IsFoldableType(type)) {
+  if (!folder.IsFoldableScalarType(type)) {
     return false;
   }
 
@@ -773,7 +773,29 @@ bool Instruction::IsFoldableByFoldScalar() const {
     Instruction* def_inst = context()->get_def_use_mgr()->GetDef(*op_id);
     Instruction* def_inst_type =
         context()->get_def_use_mgr()->GetDef(def_inst->type_id());
-    return folder.IsFoldableType(def_inst_type);
+    return folder.IsFoldableScalarType(def_inst_type);
+  });
+}
+
+bool Instruction::IsFoldableByFoldVector() const {
+  const InstructionFolder& folder = context()->get_instruction_folder();
+  if (!folder.IsFoldableOpcode(opcode())) {
+    return false;
+  }
+
+  Instruction* type = context()->get_def_use_mgr()->GetDef(type_id());
+  if (!folder.IsFoldableVectorType(type)) {
+    return false;
+  }
+
+  // Even if the type of the instruction is foldable, its operands may not be
+  // foldable (e.g., comparisons of 64bit types).  Check that all operand types
+  // are foldable before accepting the instruction.
+  return WhileEachInOperand([&folder, this](const uint32_t* op_id) {
+    Instruction* def_inst = context()->get_def_use_mgr()->GetDef(*op_id);
+    Instruction* def_inst_type =
+        context()->get_def_use_mgr()->GetDef(def_inst->type_id());
+    return folder.IsFoldableVectorType(def_inst_type);
   });
 }
 

+ 10 - 0
3rdparty/spirv-tools/source/opt/instruction.h

@@ -294,6 +294,8 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // It is the responsibility of the caller to make sure
   // that the instruction remains valid.
   inline void AddOperand(Operand&& operand);
+  // Adds a copy of |operand| to the list of operands of this instruction.
+  inline void AddOperand(const Operand& operand);
   // Gets the |index|-th logical operand as a single SPIR-V word. This method is
   // not expected to be used with logical operands consisting of multiple SPIR-V
   // words.
@@ -522,6 +524,10 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // constant value by |FoldScalar|.
   bool IsFoldableByFoldScalar() const;
 
+  // Returns true if |this| is an instruction which could be folded into a
+  // constant value by |FoldVector|.
+  bool IsFoldableByFoldVector() const;
+
   // Returns true if we are allowed to fold or otherwise manipulate the
   // instruction that defines |id| in the given context. This includes not
   // handling NaN values.
@@ -676,6 +682,10 @@ inline void Instruction::AddOperand(Operand&& operand) {
   operands_.push_back(std::move(operand));
 }
 
+inline void Instruction::AddOperand(const Operand& operand) {
+  operands_.push_back(operand);
+}
+
 inline void Instruction::SetInOperand(uint32_t index,
                                       Operand::OperandData&& data) {
   SetOperand(index + TypeResultIdCount(), std::move(data));

+ 92 - 115
3rdparty/spirv-tools/source/opt/instrument_pass.cpp

@@ -22,9 +22,6 @@
 namespace spvtools {
 namespace opt {
 namespace {
-// Common Parameter Positions
-constexpr int kInstCommonParamInstIdx = 0;
-constexpr int kInstCommonParamCnt = 1;
 // Indices of operands in SPIR-V instructions
 constexpr int kEntryPointFunctionIdInIdx = 1;
 }  // namespace
@@ -216,34 +213,6 @@ void InstrumentPass::GenDebugOutputFieldCode(uint32_t base_offset_id,
   (void)builder->AddStore(achain_inst->result_id(), val_id);
 }
 
-void InstrumentPass::GenCommonStreamWriteCode(uint32_t record_sz,
-                                              uint32_t inst_id,
-                                              uint32_t stage_idx,
-                                              uint32_t base_offset_id,
-                                              InstructionBuilder* builder) {
-  // Store record size
-  GenDebugOutputFieldCode(base_offset_id, kInstCommonOutSize,
-                          builder->GetUintConstantId(record_sz), builder);
-  // Store Shader Id
-  GenDebugOutputFieldCode(base_offset_id, kInstCommonOutShaderId,
-                          builder->GetUintConstantId(shader_id_), builder);
-  // Store Instruction Idx
-  GenDebugOutputFieldCode(base_offset_id, kInstCommonOutInstructionIdx, inst_id,
-                          builder);
-  // Store Stage Idx
-  GenDebugOutputFieldCode(base_offset_id, kInstCommonOutStageIdx,
-                          builder->GetUintConstantId(stage_idx), builder);
-}
-
-void InstrumentPass::GenFragCoordEltDebugOutputCode(
-    uint32_t base_offset_id, uint32_t uint_frag_coord_id, uint32_t element,
-    InstructionBuilder* builder) {
-  Instruction* element_val_inst =
-      builder->AddCompositeExtract(GetUintId(), uint_frag_coord_id, {element});
-  GenDebugOutputFieldCode(base_offset_id, kInstFragOutFragCoordX + element,
-                          element_val_inst->result_id(), builder);
-}
-
 uint32_t InstrumentPass::GenVarLoad(uint32_t var_id,
                                     InstructionBuilder* builder) {
   Instruction* var_inst = get_def_use_mgr()->GetDef(var_id);
@@ -252,28 +221,23 @@ uint32_t InstrumentPass::GenVarLoad(uint32_t var_id,
   return load_inst->result_id();
 }
 
-void InstrumentPass::GenBuiltinOutputCode(uint32_t builtin_id,
-                                          uint32_t builtin_off,
-                                          uint32_t base_offset_id,
-                                          InstructionBuilder* builder) {
-  // Load and store builtin
-  uint32_t load_id = GenVarLoad(builtin_id, builder);
-  GenDebugOutputFieldCode(base_offset_id, builtin_off, load_id, builder);
-}
-
-void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
-                                             uint32_t base_offset_id,
-                                             InstructionBuilder* builder) {
+uint32_t InstrumentPass::GenStageInfo(uint32_t stage_idx,
+                                      InstructionBuilder* builder) {
+  std::vector<uint32_t> ids(4, builder->GetUintConstantId(0));
+  ids[0] = builder->GetUintConstantId(stage_idx);
+  // %289 = OpCompositeConstruct %v4uint %uint_0 %285 %288 %uint_0
   // TODO(greg-lunarg): Add support for all stages
   switch (spv::ExecutionModel(stage_idx)) {
     case spv::ExecutionModel::Vertex: {
       // Load and store VertexId and InstanceId
-      GenBuiltinOutputCode(
+      uint32_t load_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::VertexIndex)),
-          kInstVertOutVertexIndex, base_offset_id, builder);
-      GenBuiltinOutputCode(context()->GetBuiltinInputVarId(
+          builder);
+      ids[1] = load_id;
+      load_id = GenVarLoad(context()->GetBuiltinInputVarId(
                                uint32_t(spv::BuiltIn::InstanceIndex)),
-                           kInstVertOutInstanceIndex, base_offset_id, builder);
+                           builder);
+      ids[2] = load_id;
     } break;
     case spv::ExecutionModel::GLCompute:
     case spv::ExecutionModel::TaskNV:
@@ -284,56 +248,50 @@ void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
       uint32_t load_id = GenVarLoad(context()->GetBuiltinInputVarId(uint32_t(
                                         spv::BuiltIn::GlobalInvocationId)),
                                     builder);
-      Instruction* x_inst =
-          builder->AddCompositeExtract(GetUintId(), load_id, {0});
-      Instruction* y_inst =
-          builder->AddCompositeExtract(GetUintId(), load_id, {1});
-      Instruction* z_inst =
-          builder->AddCompositeExtract(GetUintId(), load_id, {2});
-      GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdX,
-                              x_inst->result_id(), builder);
-      GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdY,
-                              y_inst->result_id(), builder);
-      GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdZ,
-                              z_inst->result_id(), builder);
+      for (uint32_t u = 0; u < 3u; ++u) {
+        ids[u + 1] = builder->AddCompositeExtract(GetUintId(), load_id, {u})
+                         ->result_id();
+      }
     } break;
     case spv::ExecutionModel::Geometry: {
       // Load and store PrimitiveId and InvocationId.
-      GenBuiltinOutputCode(
+      uint32_t load_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
-          kInstGeomOutPrimitiveId, base_offset_id, builder);
-      GenBuiltinOutputCode(
+          builder);
+      ids[1] = load_id;
+      load_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::InvocationId)),
-          kInstGeomOutInvocationId, base_offset_id, builder);
+          builder);
+      ids[2] = load_id;
     } break;
     case spv::ExecutionModel::TessellationControl: {
       // Load and store InvocationId and PrimitiveId
-      GenBuiltinOutputCode(
+      uint32_t load_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::InvocationId)),
-          kInstTessCtlOutInvocationId, base_offset_id, builder);
-      GenBuiltinOutputCode(
+          builder);
+      ids[1] = load_id;
+      load_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
-          kInstTessCtlOutPrimitiveId, base_offset_id, builder);
+          builder);
+      ids[2] = load_id;
     } break;
     case spv::ExecutionModel::TessellationEvaluation: {
       // Load and store PrimitiveId and TessCoord.uv
-      GenBuiltinOutputCode(
-          context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
-          kInstTessEvalOutPrimitiveId, base_offset_id, builder);
       uint32_t load_id = GenVarLoad(
+          context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
+          builder);
+      ids[1] = load_id;
+      load_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::TessCoord)),
           builder);
       Instruction* uvec3_cast_inst =
           builder->AddUnaryOp(GetVec3UintId(), spv::Op::OpBitcast, load_id);
       uint32_t uvec3_cast_id = uvec3_cast_inst->result_id();
-      Instruction* u_inst =
-          builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {0});
-      Instruction* v_inst =
-          builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {1});
-      GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordU,
-                              u_inst->result_id(), builder);
-      GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordV,
-                              v_inst->result_id(), builder);
+      for (uint32_t u = 0; u < 2u; ++u) {
+        ids[u + 2] =
+            builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {u})
+                ->result_id();
+      }
     } break;
     case spv::ExecutionModel::Fragment: {
       // Load FragCoord and convert to Uint
@@ -342,9 +300,13 @@ void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::FragCoord)));
       Instruction* uint_frag_coord_inst = builder->AddUnaryOp(
           GetVec4UintId(), spv::Op::OpBitcast, frag_coord_inst->result_id());
-      for (uint32_t u = 0; u < 2u; ++u)
-        GenFragCoordEltDebugOutputCode(
-            base_offset_id, uint_frag_coord_inst->result_id(), u, builder);
+      for (uint32_t u = 0; u < 2u; ++u) {
+        ids[u + 1] =
+            builder
+                ->AddCompositeExtract(GetUintId(),
+                                      uint_frag_coord_inst->result_id(), {u})
+                ->result_id();
+      }
     } break;
     case spv::ExecutionModel::RayGenerationNV:
     case spv::ExecutionModel::IntersectionNV:
@@ -356,33 +318,26 @@ void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
       uint32_t launch_id = GenVarLoad(
           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::LaunchIdNV)),
           builder);
-      Instruction* x_launch_inst =
-          builder->AddCompositeExtract(GetUintId(), launch_id, {0});
-      Instruction* y_launch_inst =
-          builder->AddCompositeExtract(GetUintId(), launch_id, {1});
-      Instruction* z_launch_inst =
-          builder->AddCompositeExtract(GetUintId(), launch_id, {2});
-      GenDebugOutputFieldCode(base_offset_id, kInstRayTracingOutLaunchIdX,
-                              x_launch_inst->result_id(), builder);
-      GenDebugOutputFieldCode(base_offset_id, kInstRayTracingOutLaunchIdY,
-                              y_launch_inst->result_id(), builder);
-      GenDebugOutputFieldCode(base_offset_id, kInstRayTracingOutLaunchIdZ,
-                              z_launch_inst->result_id(), builder);
+      for (uint32_t u = 0; u < 3u; ++u) {
+        ids[u + 1] = builder->AddCompositeExtract(GetUintId(), launch_id, {u})
+                         ->result_id();
+      }
     } break;
     default: { assert(false && "unsupported stage"); } break;
   }
+  return builder->AddCompositeConstruct(GetVec4UintId(), ids)->result_id();
 }
 
 void InstrumentPass::GenDebugStreamWrite(
-    uint32_t instruction_idx, uint32_t stage_idx,
+    uint32_t shader_id, uint32_t instruction_idx_id, uint32_t stage_info_id,
     const std::vector<uint32_t>& validation_ids, InstructionBuilder* builder) {
   // Call debug output function. Pass func_idx, instruction_idx and
   // validation ids as args.
   uint32_t val_id_cnt = static_cast<uint32_t>(validation_ids.size());
-  std::vector<uint32_t> args = {builder->GetUintConstantId(instruction_idx)};
+  std::vector<uint32_t> args = {shader_id, instruction_idx_id, stage_info_id};
   (void)args.insert(args.end(), validation_ids.begin(), validation_ids.end());
-  (void)builder->AddFunctionCall(
-      GetVoidId(), GetStreamWriteFunctionId(stage_idx, val_id_cnt), args);
+  (void)builder->AddFunctionCall(GetVoidId(),
+                                 GetStreamWriteFunctionId(val_id_cnt), args);
 }
 
 bool InstrumentPass::AllConstant(const std::vector<uint32_t>& ids) {
@@ -398,11 +353,12 @@ uint32_t InstrumentPass::GenDebugDirectRead(
   // Call debug input function. Pass func_idx and offset ids as args.
   const uint32_t off_id_cnt = static_cast<uint32_t>(offset_ids.size());
   const uint32_t input_func_id = GetDirectReadFunctionId(off_id_cnt);
-  return GenReadFunctionCall(input_func_id, offset_ids, builder);
+  return GenReadFunctionCall(GetUintId(), input_func_id, offset_ids, builder);
 }
 
 uint32_t InstrumentPass::GenReadFunctionCall(
-    uint32_t func_id, const std::vector<uint32_t>& func_call_args,
+    uint32_t return_id, uint32_t func_id,
+    const std::vector<uint32_t>& func_call_args,
     InstructionBuilder* ref_builder) {
   // If optimizing direct reads and the call has already been generated,
   // use its result
@@ -423,8 +379,7 @@ uint32_t InstrumentPass::GenReadFunctionCall(
     builder.SetInsertPoint(insert_before);
   }
   uint32_t res_id =
-      builder.AddFunctionCall(GetUintId(), func_id, func_call_args)
-          ->result_id();
+      builder.AddFunctionCall(return_id, func_id, func_call_args)->result_id();
   if (insert_in_first_block) call2id_[func_call_args] = res_id;
   return res_id;
 }
@@ -817,18 +772,27 @@ uint32_t InstrumentPass::GetVoidId() {
   return void_id_;
 }
 
-uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx,
-                                                  uint32_t val_spec_param_cnt) {
+uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t param_cnt) {
+  enum {
+    kShaderId = 0,
+    kInstructionIndex = 1,
+    kStageInfo = 2,
+    kFirstParam = 3,
+  };
   // Total param count is common params plus validation-specific
   // params
-  uint32_t param_cnt = kInstCommonParamCnt + val_spec_param_cnt;
   if (param2output_func_id_[param_cnt] == 0) {
     // Create function
     param2output_func_id_[param_cnt] = TakeNextId();
     analysis::TypeManager* type_mgr = context()->get_type_mgr();
 
-    const std::vector<const analysis::Type*> param_types(param_cnt,
-                                                         GetInteger(32, false));
+    const analysis::Type* uint_type = GetInteger(32, false);
+    const analysis::Vector v4uint(uint_type, 4);
+    const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
+
+    std::vector<const analysis::Type*> param_types(kFirstParam + param_cnt,
+                                                   uint_type);
+    param_types[kStageInfo] = v4uint_type;
     std::unique_ptr<Function> output_func = StartFunction(
         param2output_func_id_[param_cnt], type_mgr->GetVoidType(), param_types);
 
@@ -841,10 +805,10 @@ uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx,
         context(), &*new_blk_ptr,
         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
     // Gen test if debug output buffer size will not be exceeded.
-    uint32_t val_spec_offset = kInstStageOutCnt;
-    uint32_t obuf_record_sz = val_spec_offset + val_spec_param_cnt;
-    uint32_t buf_id = GetOutputBufferId();
-    uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
+    const uint32_t val_spec_offset = kInstStageOutCnt;
+    const uint32_t obuf_record_sz = val_spec_offset + param_cnt;
+    const uint32_t buf_id = GetOutputBufferId();
+    const uint32_t buf_uint_ptr_id = GetOutputBufferPtrId();
     Instruction* obuf_curr_sz_ac_inst = builder.AddAccessChain(
         buf_uint_ptr_id, buf_id,
         {builder.GetUintConstantId(kDebugOutputSizeOffset)});
@@ -884,13 +848,26 @@ uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx,
     new_blk_ptr = MakeUnique<BasicBlock>(std::move(write_label));
     builder.SetInsertPoint(&*new_blk_ptr);
     // Generate common and stage-specific debug record members
-    GenCommonStreamWriteCode(obuf_record_sz, param_ids[kInstCommonParamInstIdx],
-                             stage_idx, obuf_curr_sz_id, &builder);
-    GenStageStreamWriteCode(stage_idx, obuf_curr_sz_id, &builder);
+    GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutSize,
+                            builder.GetUintConstantId(obuf_record_sz),
+                            &builder);
+    // Store Shader Id
+    GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutShaderId,
+                            param_ids[kShaderId], &builder);
+    // Store Instruction Idx
+    GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutInstructionIdx,
+                            param_ids[kInstructionIndex], &builder);
+    // Store stage info. Stage Idx + 3 words of stage-specific data.
+    for (uint32_t i = 0; i < 4; ++i) {
+      Instruction* field =
+          builder.AddCompositeExtract(GetUintId(), param_ids[kStageInfo], {i});
+      GenDebugOutputFieldCode(obuf_curr_sz_id, kInstCommonOutStageIdx + i,
+                              field->result_id(), &builder);
+    }
     // Gen writes of validation specific data
-    for (uint32_t i = 0; i < val_spec_param_cnt; ++i) {
+    for (uint32_t i = 0; i < param_cnt; ++i) {
       GenDebugOutputFieldCode(obuf_curr_sz_id, val_spec_offset + i,
-                              param_ids[kInstCommonParamCnt + i], &builder);
+                              param_ids[kFirstParam + i], &builder);
     }
     // Close write block and gen merge block
     (void)builder.AddBranch(merge_blk_id);

+ 5 - 28
3rdparty/spirv-tools/source/opt/instrument_pass.h

@@ -196,7 +196,8 @@ class InstrumentPass : public Pass {
   // Because the code that is generated checks against the size of the buffer
   // before writing, the size of the debug out buffer can be used by the
   // validation layer to control the number of error records that are written.
-  void GenDebugStreamWrite(uint32_t instruction_idx, uint32_t stage_idx,
+  void GenDebugStreamWrite(uint32_t shader_id, uint32_t instruction_idx_id,
+                           uint32_t stage_info_id,
                            const std::vector<uint32_t>& validation_ids,
                            InstructionBuilder* builder);
 
@@ -214,7 +215,7 @@ class InstrumentPass : public Pass {
   uint32_t GenDebugDirectRead(const std::vector<uint32_t>& offset_ids,
                               InstructionBuilder* builder);
 
-  uint32_t GenReadFunctionCall(uint32_t func_id,
+  uint32_t GenReadFunctionCall(uint32_t return_id, uint32_t func_id,
                                const std::vector<uint32_t>& args,
                                InstructionBuilder* builder);
 
@@ -323,8 +324,7 @@ class InstrumentPass : public Pass {
 
   // Return id for output function. Define if it doesn't exist with
   // |val_spec_param_cnt| validation-specific uint32 parameters.
-  uint32_t GetStreamWriteFunctionId(uint32_t stage_idx,
-                                    uint32_t val_spec_param_cnt);
+  uint32_t GetStreamWriteFunctionId(uint32_t val_spec_param_cnt);
 
   // Return id for input function taking |param_cnt| uint32 parameters. Define
   // if it doesn't exist.
@@ -355,34 +355,11 @@ class InstrumentPass : public Pass {
                                uint32_t field_value_id,
                                InstructionBuilder* builder);
 
-  // Generate instructions into |builder| which will write the members
-  // of the debug output record common for all stages and validations at
-  // |base_off|.
-  void GenCommonStreamWriteCode(uint32_t record_sz, uint32_t instruction_idx,
-                                uint32_t stage_idx, uint32_t base_off,
-                                InstructionBuilder* builder);
-
-  // Generate instructions into |builder| which will write
-  // |uint_frag_coord_id| at |component| of the record at |base_offset_id| of
-  // the debug output buffer .
-  void GenFragCoordEltDebugOutputCode(uint32_t base_offset_id,
-                                      uint32_t uint_frag_coord_id,
-                                      uint32_t component,
-                                      InstructionBuilder* builder);
-
   // Generate instructions into |builder| which will load |var_id| and return
   // its result id.
   uint32_t GenVarLoad(uint32_t var_id, InstructionBuilder* builder);
 
-  // Generate instructions into |builder| which will load the uint |builtin_id|
-  // and write it into the debug output buffer at |base_off| + |builtin_off|.
-  void GenBuiltinOutputCode(uint32_t builtin_id, uint32_t builtin_off,
-                            uint32_t base_off, InstructionBuilder* builder);
-
-  // Generate instructions into |builder| which will write the |stage_idx|-
-  // specific members of the debug output stream at |base_off|.
-  void GenStageStreamWriteCode(uint32_t stage_idx, uint32_t base_off,
-                               InstructionBuilder* builder);
+  uint32_t GenStageInfo(uint32_t stage_idx, InstructionBuilder* builder);
 
   // Return true if instruction must be in the same block that its result
   // is used.

+ 16 - 0
3rdparty/spirv-tools/source/opt/ir_builder.h

@@ -440,6 +440,22 @@ class InstructionBuilder {
     return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
   }
 
+  Instruction* GetBoolConstant(bool value) {
+    analysis::Bool type;
+    uint32_t type_id = GetContext()->get_type_mgr()->GetTypeInstruction(&type);
+    analysis::Type* rebuilt_type =
+        GetContext()->get_type_mgr()->GetType(type_id);
+    uint32_t word = value;
+    const analysis::Constant* constant =
+        GetContext()->get_constant_mgr()->GetConstant(rebuilt_type, {word});
+    return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
+  }
+
+  uint32_t GetBoolConstantId(bool value) {
+    Instruction* inst = GetBoolConstant(value);
+    return (inst != nullptr ? inst->result_id() : 0);
+  }
+
   Instruction* AddCompositeExtract(uint32_t type, uint32_t id_of_composite,
                                    const std::vector<uint32_t>& index_list) {
     std::vector<Operand> operands;

+ 2 - 1
3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp

@@ -426,7 +426,8 @@ void LocalAccessChainConvertPass::InitExtensions() {
        "SPV_KHR_subgroup_uniform_control_flow", "SPV_KHR_integer_dot_product",
        "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info",
        "SPV_KHR_uniform_group_instructions",
-       "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model"});
+       "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
+       "SPV_NV_bindless_texture"});
 }
 
 bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(

+ 56 - 53
3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp

@@ -233,59 +233,62 @@ Pass::Status LocalSingleBlockLoadStoreElimPass::Process() {
 
 void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
   extensions_allowlist_.clear();
-  extensions_allowlist_.insert({"SPV_AMD_shader_explicit_vertex_parameter",
-                                "SPV_AMD_shader_trinary_minmax",
-                                "SPV_AMD_gcn_shader",
-                                "SPV_KHR_shader_ballot",
-                                "SPV_AMD_shader_ballot",
-                                "SPV_AMD_gpu_shader_half_float",
-                                "SPV_KHR_shader_draw_parameters",
-                                "SPV_KHR_subgroup_vote",
-                                "SPV_KHR_8bit_storage",
-                                "SPV_KHR_16bit_storage",
-                                "SPV_KHR_device_group",
-                                "SPV_KHR_multiview",
-                                "SPV_NVX_multiview_per_view_attributes",
-                                "SPV_NV_viewport_array2",
-                                "SPV_NV_stereo_view_rendering",
-                                "SPV_NV_sample_mask_override_coverage",
-                                "SPV_NV_geometry_shader_passthrough",
-                                "SPV_AMD_texture_gather_bias_lod",
-                                "SPV_KHR_storage_buffer_storage_class",
-                                "SPV_KHR_variable_pointers",
-                                "SPV_AMD_gpu_shader_int16",
-                                "SPV_KHR_post_depth_coverage",
-                                "SPV_KHR_shader_atomic_counter_ops",
-                                "SPV_EXT_shader_stencil_export",
-                                "SPV_EXT_shader_viewport_index_layer",
-                                "SPV_AMD_shader_image_load_store_lod",
-                                "SPV_AMD_shader_fragment_mask",
-                                "SPV_EXT_fragment_fully_covered",
-                                "SPV_AMD_gpu_shader_half_float_fetch",
-                                "SPV_GOOGLE_decorate_string",
-                                "SPV_GOOGLE_hlsl_functionality1",
-                                "SPV_GOOGLE_user_type",
-                                "SPV_NV_shader_subgroup_partitioned",
-                                "SPV_EXT_demote_to_helper_invocation",
-                                "SPV_EXT_descriptor_indexing",
-                                "SPV_NV_fragment_shader_barycentric",
-                                "SPV_NV_compute_shader_derivatives",
-                                "SPV_NV_shader_image_footprint",
-                                "SPV_NV_shading_rate",
-                                "SPV_NV_mesh_shader",
-                                "SPV_NV_ray_tracing",
-                                "SPV_KHR_ray_tracing",
-                                "SPV_KHR_ray_query",
-                                "SPV_EXT_fragment_invocation_density",
-                                "SPV_EXT_physical_storage_buffer",
-                                "SPV_KHR_terminate_invocation",
-                                "SPV_KHR_subgroup_uniform_control_flow",
-                                "SPV_KHR_integer_dot_product",
-                                "SPV_EXT_shader_image_int64",
-                                "SPV_KHR_non_semantic_info",
-                                "SPV_KHR_uniform_group_instructions",
-                                "SPV_KHR_fragment_shader_barycentric",
-                                "SPV_KHR_vulkan_memory_model"});
+  extensions_allowlist_.insert({
+      "SPV_AMD_shader_explicit_vertex_parameter",
+      "SPV_AMD_shader_trinary_minmax",
+      "SPV_AMD_gcn_shader",
+      "SPV_KHR_shader_ballot",
+      "SPV_AMD_shader_ballot",
+      "SPV_AMD_gpu_shader_half_float",
+      "SPV_KHR_shader_draw_parameters",
+      "SPV_KHR_subgroup_vote",
+      "SPV_KHR_8bit_storage",
+      "SPV_KHR_16bit_storage",
+      "SPV_KHR_device_group",
+      "SPV_KHR_multiview",
+      "SPV_NVX_multiview_per_view_attributes",
+      "SPV_NV_viewport_array2",
+      "SPV_NV_stereo_view_rendering",
+      "SPV_NV_sample_mask_override_coverage",
+      "SPV_NV_geometry_shader_passthrough",
+      "SPV_AMD_texture_gather_bias_lod",
+      "SPV_KHR_storage_buffer_storage_class",
+      "SPV_KHR_variable_pointers",
+      "SPV_AMD_gpu_shader_int16",
+      "SPV_KHR_post_depth_coverage",
+      "SPV_KHR_shader_atomic_counter_ops",
+      "SPV_EXT_shader_stencil_export",
+      "SPV_EXT_shader_viewport_index_layer",
+      "SPV_AMD_shader_image_load_store_lod",
+      "SPV_AMD_shader_fragment_mask",
+      "SPV_EXT_fragment_fully_covered",
+      "SPV_AMD_gpu_shader_half_float_fetch",
+      "SPV_GOOGLE_decorate_string",
+      "SPV_GOOGLE_hlsl_functionality1",
+      "SPV_GOOGLE_user_type",
+      "SPV_NV_shader_subgroup_partitioned",
+      "SPV_EXT_demote_to_helper_invocation",
+      "SPV_EXT_descriptor_indexing",
+      "SPV_NV_fragment_shader_barycentric",
+      "SPV_NV_compute_shader_derivatives",
+      "SPV_NV_shader_image_footprint",
+      "SPV_NV_shading_rate",
+      "SPV_NV_mesh_shader",
+      "SPV_NV_ray_tracing",
+      "SPV_KHR_ray_tracing",
+      "SPV_KHR_ray_query",
+      "SPV_EXT_fragment_invocation_density",
+      "SPV_EXT_physical_storage_buffer",
+      "SPV_KHR_terminate_invocation",
+      "SPV_KHR_subgroup_uniform_control_flow",
+      "SPV_KHR_integer_dot_product",
+      "SPV_EXT_shader_image_int64",
+      "SPV_KHR_non_semantic_info",
+      "SPV_KHR_uniform_group_instructions",
+      "SPV_KHR_fragment_shader_barycentric",
+      "SPV_KHR_vulkan_memory_model",
+      "SPV_NV_bindless_texture",
+  });
 }
 
 }  // namespace opt

+ 53 - 50
3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp

@@ -86,56 +86,59 @@ Pass::Status LocalSingleStoreElimPass::Process() {
 }
 
 void LocalSingleStoreElimPass::InitExtensionAllowList() {
-  extensions_allowlist_.insert({"SPV_AMD_shader_explicit_vertex_parameter",
-                                "SPV_AMD_shader_trinary_minmax",
-                                "SPV_AMD_gcn_shader",
-                                "SPV_KHR_shader_ballot",
-                                "SPV_AMD_shader_ballot",
-                                "SPV_AMD_gpu_shader_half_float",
-                                "SPV_KHR_shader_draw_parameters",
-                                "SPV_KHR_subgroup_vote",
-                                "SPV_KHR_8bit_storage",
-                                "SPV_KHR_16bit_storage",
-                                "SPV_KHR_device_group",
-                                "SPV_KHR_multiview",
-                                "SPV_NVX_multiview_per_view_attributes",
-                                "SPV_NV_viewport_array2",
-                                "SPV_NV_stereo_view_rendering",
-                                "SPV_NV_sample_mask_override_coverage",
-                                "SPV_NV_geometry_shader_passthrough",
-                                "SPV_AMD_texture_gather_bias_lod",
-                                "SPV_KHR_storage_buffer_storage_class",
-                                "SPV_KHR_variable_pointers",
-                                "SPV_AMD_gpu_shader_int16",
-                                "SPV_KHR_post_depth_coverage",
-                                "SPV_KHR_shader_atomic_counter_ops",
-                                "SPV_EXT_shader_stencil_export",
-                                "SPV_EXT_shader_viewport_index_layer",
-                                "SPV_AMD_shader_image_load_store_lod",
-                                "SPV_AMD_shader_fragment_mask",
-                                "SPV_EXT_fragment_fully_covered",
-                                "SPV_AMD_gpu_shader_half_float_fetch",
-                                "SPV_GOOGLE_decorate_string",
-                                "SPV_GOOGLE_hlsl_functionality1",
-                                "SPV_NV_shader_subgroup_partitioned",
-                                "SPV_EXT_descriptor_indexing",
-                                "SPV_NV_fragment_shader_barycentric",
-                                "SPV_NV_compute_shader_derivatives",
-                                "SPV_NV_shader_image_footprint",
-                                "SPV_NV_shading_rate",
-                                "SPV_NV_mesh_shader",
-                                "SPV_NV_ray_tracing",
-                                "SPV_KHR_ray_query",
-                                "SPV_EXT_fragment_invocation_density",
-                                "SPV_EXT_physical_storage_buffer",
-                                "SPV_KHR_terminate_invocation",
-                                "SPV_KHR_subgroup_uniform_control_flow",
-                                "SPV_KHR_integer_dot_product",
-                                "SPV_EXT_shader_image_int64",
-                                "SPV_KHR_non_semantic_info",
-                                "SPV_KHR_uniform_group_instructions",
-                                "SPV_KHR_fragment_shader_barycentric",
-                                "SPV_KHR_vulkan_memory_model"});
+  extensions_allowlist_.insert({
+      "SPV_AMD_shader_explicit_vertex_parameter",
+      "SPV_AMD_shader_trinary_minmax",
+      "SPV_AMD_gcn_shader",
+      "SPV_KHR_shader_ballot",
+      "SPV_AMD_shader_ballot",
+      "SPV_AMD_gpu_shader_half_float",
+      "SPV_KHR_shader_draw_parameters",
+      "SPV_KHR_subgroup_vote",
+      "SPV_KHR_8bit_storage",
+      "SPV_KHR_16bit_storage",
+      "SPV_KHR_device_group",
+      "SPV_KHR_multiview",
+      "SPV_NVX_multiview_per_view_attributes",
+      "SPV_NV_viewport_array2",
+      "SPV_NV_stereo_view_rendering",
+      "SPV_NV_sample_mask_override_coverage",
+      "SPV_NV_geometry_shader_passthrough",
+      "SPV_AMD_texture_gather_bias_lod",
+      "SPV_KHR_storage_buffer_storage_class",
+      "SPV_KHR_variable_pointers",
+      "SPV_AMD_gpu_shader_int16",
+      "SPV_KHR_post_depth_coverage",
+      "SPV_KHR_shader_atomic_counter_ops",
+      "SPV_EXT_shader_stencil_export",
+      "SPV_EXT_shader_viewport_index_layer",
+      "SPV_AMD_shader_image_load_store_lod",
+      "SPV_AMD_shader_fragment_mask",
+      "SPV_EXT_fragment_fully_covered",
+      "SPV_AMD_gpu_shader_half_float_fetch",
+      "SPV_GOOGLE_decorate_string",
+      "SPV_GOOGLE_hlsl_functionality1",
+      "SPV_NV_shader_subgroup_partitioned",
+      "SPV_EXT_descriptor_indexing",
+      "SPV_NV_fragment_shader_barycentric",
+      "SPV_NV_compute_shader_derivatives",
+      "SPV_NV_shader_image_footprint",
+      "SPV_NV_shading_rate",
+      "SPV_NV_mesh_shader",
+      "SPV_NV_ray_tracing",
+      "SPV_KHR_ray_query",
+      "SPV_EXT_fragment_invocation_density",
+      "SPV_EXT_physical_storage_buffer",
+      "SPV_KHR_terminate_invocation",
+      "SPV_KHR_subgroup_uniform_control_flow",
+      "SPV_KHR_integer_dot_product",
+      "SPV_EXT_shader_image_int64",
+      "SPV_KHR_non_semantic_info",
+      "SPV_KHR_uniform_group_instructions",
+      "SPV_KHR_fragment_shader_barycentric",
+      "SPV_KHR_vulkan_memory_model",
+      "SPV_NV_bindless_texture",
+  });
 }
 bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
   std::vector<Instruction*> users;

+ 37 - 40
3rdparty/spirv-tools/source/opt/optimizer.cpp

@@ -109,7 +109,7 @@ Optimizer& Optimizer::RegisterPass(PassToken&& p) {
 // The legalization problem is essentially a very general copy propagation
 // problem.  The optimization we use are all used to either do copy propagation
 // or enable more copy propagation.
-Optimizer& Optimizer::RegisterLegalizationPasses() {
+Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
   return
       // Wrap OpKill instructions so all other code can be inlined.
       RegisterPass(CreateWrapOpKillPass())
@@ -129,16 +129,16 @@ Optimizer& Optimizer::RegisterLegalizationPasses() {
           // Propagate the value stored to the loads in very simple cases.
           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
           .RegisterPass(CreateLocalSingleStoreElimPass())
-          .RegisterPass(CreateAggressiveDCEPass())
+          .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
           // Split up aggregates so they are easier to deal with.
           .RegisterPass(CreateScalarReplacementPass(0))
           // Remove loads and stores so everything is in intermediate values.
           // Takes care of copy propagation of non-members.
           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
           .RegisterPass(CreateLocalSingleStoreElimPass())
-          .RegisterPass(CreateAggressiveDCEPass())
+          .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
           .RegisterPass(CreateLocalMultiStoreElimPass())
-          .RegisterPass(CreateAggressiveDCEPass())
+          .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
           // Propagate constants to get as many constant conditions on branches
           // as possible.
           .RegisterPass(CreateCCPPass())
@@ -147,7 +147,7 @@ Optimizer& Optimizer::RegisterLegalizationPasses() {
           // Copy propagate members.  Cleans up code sequences generated by
           // scalar replacement.  Also important for removing OpPhi nodes.
           .RegisterPass(CreateSimplificationPass())
-          .RegisterPass(CreateAggressiveDCEPass())
+          .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
           .RegisterPass(CreateCopyPropagateArraysPass())
           // May need loop unrolling here see
           // https://github.com/Microsoft/DirectXShaderCompiler/pull/930
@@ -156,30 +156,34 @@ Optimizer& Optimizer::RegisterLegalizationPasses() {
           .RegisterPass(CreateVectorDCEPass())
           .RegisterPass(CreateDeadInsertElimPass())
           .RegisterPass(CreateReduceLoadSizePass())
-          .RegisterPass(CreateAggressiveDCEPass())
+          .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
           .RegisterPass(CreateInterpolateFixupPass());
 }
 
-Optimizer& Optimizer::RegisterPerformancePasses() {
+Optimizer& Optimizer::RegisterLegalizationPasses() {
+  return RegisterLegalizationPasses(false);
+}
+
+Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
   return RegisterPass(CreateWrapOpKillPass())
       .RegisterPass(CreateDeadBranchElimPass())
       .RegisterPass(CreateMergeReturnPass())
       .RegisterPass(CreateInlineExhaustivePass())
       .RegisterPass(CreateEliminateDeadFunctionsPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreatePrivateToLocalPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateScalarReplacementPass())
       .RegisterPass(CreateLocalAccessChainConvertPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateLocalMultiStoreElimPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateCCPPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateLoopUnrollPass(true))
       .RegisterPass(CreateDeadBranchElimPass())
       .RegisterPass(CreateRedundancyEliminationPass())
@@ -189,9 +193,9 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
       .RegisterPass(CreateLocalAccessChainConvertPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateSSARewritePass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateVectorDCEPass())
       .RegisterPass(CreateDeadInsertElimPass())
       .RegisterPass(CreateDeadBranchElimPass())
@@ -199,7 +203,7 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
       .RegisterPass(CreateIfConversionPass())
       .RegisterPass(CreateCopyPropagateArraysPass())
       .RegisterPass(CreateReduceLoadSizePass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateBlockMergePass())
       .RegisterPass(CreateRedundancyEliminationPass())
       .RegisterPass(CreateDeadBranchElimPass())
@@ -207,7 +211,11 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
       .RegisterPass(CreateSimplificationPass());
 }
 
-Optimizer& Optimizer::RegisterSizePasses() {
+Optimizer& Optimizer::RegisterPerformancePasses() {
+  return RegisterPerformancePasses(false);
+}
+
+Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) {
   return RegisterPass(CreateWrapOpKillPass())
       .RegisterPass(CreateDeadBranchElimPass())
       .RegisterPass(CreateMergeReturnPass())
@@ -224,12 +232,12 @@ Optimizer& Optimizer::RegisterSizePasses() {
       .RegisterPass(CreateLocalSingleStoreElimPass())
       .RegisterPass(CreateIfConversionPass())
       .RegisterPass(CreateSimplificationPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateDeadBranchElimPass())
       .RegisterPass(CreateBlockMergePass())
       .RegisterPass(CreateLocalAccessChainConvertPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateCopyPropagateArraysPass())
       .RegisterPass(CreateVectorDCEPass())
       .RegisterPass(CreateDeadInsertElimPass())
@@ -239,10 +247,12 @@ Optimizer& Optimizer::RegisterSizePasses() {
       .RegisterPass(CreateLocalMultiStoreElimPass())
       .RegisterPass(CreateRedundancyEliminationPass())
       .RegisterPass(CreateSimplificationPass())
-      .RegisterPass(CreateAggressiveDCEPass())
+      .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
       .RegisterPass(CreateCFGCleanupPass());
 }
 
+Optimizer& Optimizer::RegisterSizePasses() { return RegisterSizePasses(false); }
+
 bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags) {
   for (const auto& flag : flags) {
     if (!RegisterPassFromFlag(flag)) {
@@ -419,20 +429,11 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
     RegisterPass(CreateWorkaround1209Pass());
   } else if (pass_name == "replace-invalid-opcode") {
     RegisterPass(CreateReplaceInvalidOpcodePass());
-  } else if (pass_name == "inst-bindless-check") {
-    RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false));
-    RegisterPass(CreateSimplificationPass());
-    RegisterPass(CreateDeadBranchElimPass());
-    RegisterPass(CreateBlockMergePass());
-    RegisterPass(CreateAggressiveDCEPass(true));
-  } else if (pass_name == "inst-desc-idx-check") {
-    RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true));
-    RegisterPass(CreateSimplificationPass());
-    RegisterPass(CreateDeadBranchElimPass());
-    RegisterPass(CreateBlockMergePass());
-    RegisterPass(CreateAggressiveDCEPass(true));
-  } else if (pass_name == "inst-buff-oob-check") {
-    RegisterPass(CreateInstBindlessCheckPass(7, 23, false, false, true, true));
+  } else if (pass_name == "inst-bindless-check" ||
+             pass_name == "inst-desc-idx-check" ||
+             pass_name == "inst-buff-oob-check") {
+    // preserve legacy names
+    RegisterPass(CreateInstBindlessCheckPass(7, 23));
     RegisterPass(CreateSimplificationPass());
     RegisterPass(CreateDeadBranchElimPass());
     RegisterPass(CreateBlockMergePass());
@@ -945,14 +946,10 @@ Optimizer::PassToken CreateUpgradeMemoryModelPass() {
       MakeUnique<opt::UpgradeMemoryModel>());
 }
 
-Optimizer::PassToken CreateInstBindlessCheckPass(
-    uint32_t desc_set, uint32_t shader_id, bool desc_length_enable,
-    bool desc_init_enable, bool buff_oob_enable, bool texbuff_oob_enable) {
+Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set,
+                                                 uint32_t shader_id) {
   return MakeUnique<Optimizer::PassToken::Impl>(
-      MakeUnique<opt::InstBindlessCheckPass>(
-          desc_set, shader_id, desc_length_enable, desc_init_enable,
-          buff_oob_enable, texbuff_oob_enable,
-          desc_length_enable || desc_init_enable || buff_oob_enable));
+      MakeUnique<opt::InstBindlessCheckPass>(desc_set, shader_id));
 }
 
 Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set,

+ 31 - 0
3rdparty/spirv-tools/source/opt/type_manager.cpp

@@ -423,6 +423,23 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
               {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
       break;
     }
+    case Type::kCooperativeMatrixKHR: {
+      auto coop_mat = type->AsCooperativeMatrixKHR();
+      uint32_t const component_type =
+          GetTypeInstruction(coop_mat->component_type());
+      if (component_type == 0) {
+        return 0;
+      }
+      typeInst = MakeUnique<Instruction>(
+          context(), spv::Op::OpTypeCooperativeMatrixKHR, 0, id,
+          std::initializer_list<Operand>{
+              {SPV_OPERAND_TYPE_ID, {component_type}},
+              {SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
+              {SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
+              {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}},
+              {SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
+      break;
+    }
     default:
       assert(false && "Unexpected type");
       break;
@@ -628,6 +645,14 @@ Type* TypeManager::RebuildType(const Type& type) {
           cm_type->columns_id());
       break;
     }
+    case Type::kCooperativeMatrixKHR: {
+      const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
+      const Type* component_type = cm_type->component_type();
+      rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
+          RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
+          cm_type->columns_id(), cm_type->use_id());
+      break;
+    }
     default:
       assert(false && "Unhandled type");
       return nullptr;
@@ -863,6 +888,12 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
                                      inst.GetSingleWordInOperand(2),
                                      inst.GetSingleWordInOperand(3));
       break;
+    case spv::Op::OpTypeCooperativeMatrixKHR:
+      type = new CooperativeMatrixKHR(
+          GetType(inst.GetSingleWordInOperand(0)),
+          inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
+          inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
+      break;
     case spv::Op::OpTypeRayQueryKHR:
       type = new RayQueryKHR();
       break;

+ 5 - 6
3rdparty/spirv-tools/source/opt/type_manager.h

@@ -144,18 +144,17 @@ class TypeManager {
   // |type| (e.g. should be called in loop of |type|'s decorations).
   void AttachDecoration(const Instruction& inst, Type* type);
 
-  Type* GetUIntType() {
-    Integer int_type(32, false);
-    return GetRegisteredType(&int_type);
-  }
+  Type* GetUIntType() { return GetIntType(32, false); }
 
   uint32_t GetUIntTypeId() { return GetTypeInstruction(GetUIntType()); }
 
-  Type* GetSIntType() {
-    Integer int_type(32, true);
+  Type* GetIntType(int32_t bitWidth, bool isSigned) {
+    Integer int_type(bitWidth, isSigned);
     return GetRegisteredType(&int_type);
   }
 
+  Type* GetSIntType() { return GetIntType(32, true); }
+
   uint32_t GetSIntTypeId() { return GetTypeInstruction(GetSIntType()); }
 
   Type* GetFloatType() {

+ 42 - 0
3rdparty/spirv-tools/source/opt/types.cpp

@@ -128,6 +128,7 @@ std::unique_ptr<Type> Type::Clone() const {
     DeclareKindCase(NamedBarrier);
     DeclareKindCase(AccelerationStructureNV);
     DeclareKindCase(CooperativeMatrixNV);
+    DeclareKindCase(CooperativeMatrixKHR);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
 #undef DeclareKindCase
@@ -175,6 +176,7 @@ bool Type::operator==(const Type& other) const {
     DeclareKindCase(NamedBarrier);
     DeclareKindCase(AccelerationStructureNV);
     DeclareKindCase(CooperativeMatrixNV);
+    DeclareKindCase(CooperativeMatrixKHR);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
 #undef DeclareKindCase
@@ -230,6 +232,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
     DeclareKindCase(NamedBarrier);
     DeclareKindCase(AccelerationStructureNV);
     DeclareKindCase(CooperativeMatrixNV);
+    DeclareKindCase(CooperativeMatrixKHR);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
 #undef DeclareKindCase
@@ -708,6 +711,45 @@ bool CooperativeMatrixNV::IsSameImpl(const Type* that,
          columns_id_ == mt->columns_id_ && HasSameDecorations(that);
 }
 
+CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
+                                           const uint32_t scope,
+                                           const uint32_t rows,
+                                           const uint32_t columns,
+                                           const uint32_t use)
+    : Type(kCooperativeMatrixKHR),
+      component_type_(type),
+      scope_id_(scope),
+      rows_id_(rows),
+      columns_id_(columns),
+      use_id_(use) {
+  assert(type != nullptr);
+  assert(scope != 0);
+  assert(rows != 0);
+  assert(columns != 0);
+}
+
+std::string CooperativeMatrixKHR::str() const {
+  std::ostringstream oss;
+  oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
+      << ", " << columns_id_ << ", " << use_id_ << ">";
+  return oss.str();
+}
+
+size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
+                                                   SeenTypes* seen) const {
+  hash = hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
+  return component_type_->ComputeHashValue(hash, seen);
+}
+
+bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
+                                      IsSameCache* seen) const {
+  const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
+  if (!mt) return false;
+  return component_type_->IsSameImpl(mt->component_type_, seen) &&
+         scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
+         columns_id_ == mt->columns_id_ && HasSameDecorations(that);
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools

+ 35 - 0
3rdparty/spirv-tools/source/opt/types.h

@@ -60,6 +60,7 @@ class PipeStorage;
 class NamedBarrier;
 class AccelerationStructureNV;
 class CooperativeMatrixNV;
+class CooperativeMatrixKHR;
 class RayQueryKHR;
 class HitObjectNV;
 
@@ -100,6 +101,7 @@ class Type {
     kNamedBarrier,
     kAccelerationStructureNV,
     kCooperativeMatrixNV,
+    kCooperativeMatrixKHR,
     kRayQueryKHR,
     kHitObjectNV,
     kLast
@@ -201,6 +203,7 @@ class Type {
   DeclareCastMethod(NamedBarrier)
   DeclareCastMethod(AccelerationStructureNV)
   DeclareCastMethod(CooperativeMatrixNV)
+  DeclareCastMethod(CooperativeMatrixKHR)
   DeclareCastMethod(RayQueryKHR)
   DeclareCastMethod(HitObjectNV)
 #undef DeclareCastMethod
@@ -624,6 +627,38 @@ class CooperativeMatrixNV : public Type {
   const uint32_t columns_id_;
 };
 
+class CooperativeMatrixKHR : public Type {
+ public:
+  CooperativeMatrixKHR(const Type* type, const uint32_t scope,
+                       const uint32_t rows, const uint32_t columns,
+                       const uint32_t use);
+  CooperativeMatrixKHR(const CooperativeMatrixKHR&) = default;
+
+  std::string str() const override;
+
+  CooperativeMatrixKHR* AsCooperativeMatrixKHR() override { return this; }
+  const CooperativeMatrixKHR* AsCooperativeMatrixKHR() const override {
+    return this;
+  }
+
+  size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+  const Type* component_type() const { return component_type_; }
+  uint32_t scope_id() const { return scope_id_; }
+  uint32_t rows_id() const { return rows_id_; }
+  uint32_t columns_id() const { return columns_id_; }
+  uint32_t use_id() const { return use_id_; }
+
+ private:
+  bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+  const Type* component_type_;
+  const uint32_t scope_id_;
+  const uint32_t rows_id_;
+  const uint32_t columns_id_;
+  const uint32_t use_id_;
+};
+
 #define DefineParameterlessType(type, name)                                \
   class type : public Type {                                               \
    public:                                                                 \

+ 2 - 1
3rdparty/spirv-tools/source/text.cpp

@@ -402,7 +402,8 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar,
     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
     case SPV_OPERAND_TYPE_SELECTION_CONTROL:
     case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
-    case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: {
+    case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
       uint32_t value;
       if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
         return context->diagnostic(error)

+ 143 - 10
3rdparty/spirv-tools/source/val/validate_arithmetics.cpp

@@ -42,14 +42,29 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
            opcode != spv::Op::OpFMod);
       if (!_.IsFloatScalarType(result_type) &&
           !_.IsFloatVectorType(result_type) &&
-          !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
+          !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
+          !(opcode == spv::Op::OpFMul &&
+            _.IsCooperativeMatrixKHRType(result_type) &&
+            _.IsFloatCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected floating scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       for (size_t operand_index = 2; operand_index < inst->operands().size();
            ++operand_index) {
-        if (_.GetOperandTypeId(inst, operand_index) != result_type)
+        if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+          const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+          if (!_.IsCooperativeMatrixKHRType(type_id) ||
+              !_.IsFloatCooperativeMatrixType(type_id)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "Expected arithmetic operands to be of Result Type: "
+                   << spvOpcodeString(opcode) << " operand index "
+                   << operand_index;
+          }
+          spv_result_t ret =
+              _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+          if (ret != SPV_SUCCESS) return ret;
+        } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected arithmetic operands to be of Result Type: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -71,7 +86,19 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
 
       for (size_t operand_index = 2; operand_index < inst->operands().size();
            ++operand_index) {
-        if (_.GetOperandTypeId(inst, operand_index) != result_type)
+        if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+          const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+          if (!_.IsCooperativeMatrixKHRType(type_id) ||
+              !_.IsUnsignedIntCooperativeMatrixType(type_id)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "Expected arithmetic operands to be of Result Type: "
+                   << spvOpcodeString(opcode) << " operand index "
+                   << operand_index;
+          }
+          spv_result_t ret =
+              _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+          if (ret != SPV_SUCCESS) return ret;
+        } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected arithmetic operands to be of Result Type: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -91,7 +118,10 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
           (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
            opcode != spv::Op::OpSMod);
       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
-          !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
+          !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+          !(opcode == spv::Op::OpIMul &&
+            _.IsCooperativeMatrixKHRType(result_type) &&
+            _.IsIntCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
@@ -102,9 +132,26 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
       for (size_t operand_index = 2; operand_index < inst->operands().size();
            ++operand_index) {
         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+
+        if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+          if (!_.IsCooperativeMatrixKHRType(type_id) ||
+              !_.IsIntCooperativeMatrixType(type_id)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "Expected arithmetic operands to be of Result Type: "
+                   << spvOpcodeString(opcode) << " operand index "
+                   << operand_index;
+          }
+          spv_result_t ret =
+              _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+          if (ret != SPV_SUCCESS) return ret;
+        }
+
         if (!type_id ||
             (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
-             !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
+             !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+             !(opcode == spv::Op::OpIMul &&
+               _.IsCooperativeMatrixKHRType(result_type) &&
+               _.IsIntCooperativeMatrixType(result_type))))
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected int scalar or vector type as operand: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -187,7 +234,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
 
     case spv::Op::OpMatrixTimesScalar: {
       if (!_.IsFloatMatrixType(result_type) &&
-          !_.IsCooperativeMatrixType(result_type))
+          !(_.IsCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected float matrix type as Result Type: "
                << spvOpcodeString(opcode);
@@ -459,22 +506,108 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
       const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
       const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
 
-      if (!_.IsCooperativeMatrixType(A_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(A_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as A Type: "
                << spvOpcodeString(opcode);
       }
-      if (!_.IsCooperativeMatrixType(B_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(B_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as B Type: "
                << spvOpcodeString(opcode);
       }
-      if (!_.IsCooperativeMatrixType(C_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(C_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as C Type: "
                << spvOpcodeString(opcode);
       }
-      if (!_.IsCooperativeMatrixType(D_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(D_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected cooperative matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+      }
+
+      const auto A = _.FindDef(A_type_id);
+      const auto B = _.FindDef(B_type_id);
+      const auto C = _.FindDef(C_type_id);
+      const auto D = _.FindDef(D_type_id);
+
+      std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
+          A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
+
+      A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
+      B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
+      C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
+      D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
+
+      A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
+      B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
+      C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
+      D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
+
+      A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
+      B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
+      C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
+      D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
+
+      const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
+                               std::tuple<bool, bool, uint32_t> Y) {
+        return (std::get<1>(X) && std::get<1>(Y) &&
+                std::get<2>(X) != std::get<2>(Y));
+      };
+
+      if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
+          notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
+          notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix scopes must match: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
+          notEqual(C_rows, D_rows)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'M' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
+          notEqual(C_cols, D_cols)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'N' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(A_cols, B_rows)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'K' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+      break;
+    }
+
+    case spv::Op::OpCooperativeMatrixMulAddKHR: {
+      const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
+      const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
+      const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
+
+      if (!_.IsCooperativeMatrixAType(A_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix type must be A Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixBType(B_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix type must be B Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixAccType(C_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix type must be Accumulator Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as Result Type: "
                << spvOpcodeString(opcode);

+ 20 - 0
3rdparty/spirv-tools/source/val/validate_composites.cpp

@@ -122,6 +122,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
         *member_type = type_inst->word(component_index + 2);
         break;
       }
+      case spv::Op::OpTypeCooperativeMatrixKHR:
       case spv::Op::OpTypeCooperativeMatrixNV: {
         *member_type = type_inst->word(2);
         break;
@@ -335,6 +336,25 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
 
       break;
     }
+    case spv::Op::OpTypeCooperativeMatrixKHR: {
+      const auto result_type_inst = _.FindDef(result_type);
+      assert(result_type_inst);
+      const auto component_type_id =
+          result_type_inst->GetOperandAs<uint32_t>(1);
+
+      if (3 != num_operands) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Must be only one constituent";
+      }
+
+      const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
+
+      if (operand_type_id != component_type_id) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Constituent type to be equal to the component type";
+      }
+      break;
+    }
     case spv::Op::OpTypeCooperativeMatrixNV: {
       const auto result_type_inst = _.FindDef(result_type);
       assert(result_type_inst);

+ 2 - 0
3rdparty/spirv-tools/source/val/validate_constants.cpp

@@ -243,6 +243,7 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _,
         }
       }
     } break;
+    case spv::Op::OpTypeCooperativeMatrixKHR:
     case spv::Op::OpTypeCooperativeMatrixNV: {
       if (1 != constituent_count) {
         return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -310,6 +311,7 @@ bool IsTypeNullable(const std::vector<uint32_t>& instruction,
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeMatrix:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
     case spv::Op::OpTypeVector: {
       auto base_type = _.FindDef(instruction[2]);
       return base_type && IsTypeNullable(base_type->words(), _);

+ 16 - 2
3rdparty/spirv-tools/source/val/validate_conversion.cpp

@@ -473,7 +473,10 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
       const bool input_is_pointer = _.IsPointerType(input_type);
       const bool input_is_int_scalar = _.IsIntScalarType(input_type);
 
-      if (!result_is_pointer && !result_is_int_scalar &&
+      const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type);
+      const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type);
+
+      if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat &&
           !_.IsIntVectorType(result_type) &&
           !_.IsFloatScalarType(result_type) &&
           !_.IsFloatVectorType(result_type))
@@ -481,13 +484,24 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
                << "Expected Result Type to be a pointer or int or float vector "
                << "or scalar type: " << spvOpcodeString(opcode);
 
-      if (!input_is_pointer && !input_is_int_scalar &&
+      if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat &&
           !_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
           !_.IsFloatVectorType(input_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be a pointer or int or float vector "
                << "or scalar: " << spvOpcodeString(opcode);
 
+      if (result_is_coopmat != input_is_coopmat)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix can only be cast to another cooperative "
+               << "matrix: " << spvOpcodeString(opcode);
+
+      if (result_is_coopmat) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      }
+
       if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
           _.HasExtension(kSPV_KHR_physical_storage_buffer)) {
         const bool result_is_int_vector = _.IsIntVectorType(result_type);

+ 58 - 23
3rdparty/spirv-tools/source/val/validate_decorations.cpp

@@ -452,7 +452,16 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
     return ds;
   };
 
-  const auto& members = getStructMembers(struct_id, vstate);
+  // If we are checking physical storage buffer pointers, we may not actually
+  // have a struct here. Instead, pretend we have a struct with a single member
+  // at offset 0.
+  const auto& struct_type = vstate.FindDef(struct_id);
+  std::vector<uint32_t> members;
+  if (struct_type->opcode() == spv::Op::OpTypeStruct) {
+    members = getStructMembers(struct_id, vstate);
+  } else {
+    members.push_back(struct_id);
+  }
 
   // To check for member overlaps, we want to traverse the members in
   // offset order.
@@ -461,31 +470,38 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
     uint32_t offset;
   };
   std::vector<MemberOffsetPair> member_offsets;
-  member_offsets.reserve(members.size());
-  for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
-       memberIdx < numMembers; memberIdx++) {
-    uint32_t offset = 0xffffffff;
-    auto member_decorations =
-        vstate.id_member_decorations(struct_id, memberIdx);
-    for (auto decoration = member_decorations.begin;
-         decoration != member_decorations.end; ++decoration) {
-      assert(decoration->struct_member_index() == (int)memberIdx);
-      switch (decoration->dec_type()) {
-        case spv::Decoration::Offset:
-          offset = decoration->params()[0];
-          break;
-        default:
-          break;
+
+  // With physical storage buffers, we might be checking layouts that do not
+  // originate from a structure.
+  if (struct_type->opcode() == spv::Op::OpTypeStruct) {
+    member_offsets.reserve(members.size());
+    for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
+         memberIdx < numMembers; memberIdx++) {
+      uint32_t offset = 0xffffffff;
+      auto member_decorations =
+          vstate.id_member_decorations(struct_id, memberIdx);
+      for (auto decoration = member_decorations.begin;
+           decoration != member_decorations.end; ++decoration) {
+        assert(decoration->struct_member_index() == (int)memberIdx);
+        switch (decoration->dec_type()) {
+          case spv::Decoration::Offset:
+            offset = decoration->params()[0];
+            break;
+          default:
+            break;
+        }
       }
+      member_offsets.push_back(
+          MemberOffsetPair{memberIdx, incoming_offset + offset});
     }
-    member_offsets.push_back(
-        MemberOffsetPair{memberIdx, incoming_offset + offset});
+    std::stable_sort(
+        member_offsets.begin(), member_offsets.end(),
+        [](const MemberOffsetPair& lhs, const MemberOffsetPair& rhs) {
+          return lhs.offset < rhs.offset;
+        });
+  } else {
+    member_offsets.push_back({0, 0});
   }
-  std::stable_sort(
-      member_offsets.begin(), member_offsets.end(),
-      [](const MemberOffsetPair& lhs, const MemberOffsetPair& rhs) {
-        return lhs.offset < rhs.offset;
-      });
 
   // Now scan from lowest offset to highest offset.
   uint32_t nextValidOffset = 0;
@@ -1023,6 +1039,8 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
   std::unordered_set<uint32_t> uses_push_constant;
   for (const auto& inst : vstate.ordered_instructions()) {
     const auto& words = inst.words();
+    auto type_id = inst.type_id();
+    const Instruction* type_inst = vstate.FindDef(type_id);
     if (spv::Op::OpVariable == inst.opcode()) {
       const auto var_id = inst.id();
       // For storage class / decoration combinations, see Vulkan 14.5.4 "Offset
@@ -1276,6 +1294,23 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
           }
         }
       }
+    } else if (type_inst && type_inst->opcode() == spv::Op::OpTypePointer &&
+               type_inst->GetOperandAs<spv::StorageClass>(1u) ==
+                   spv::StorageClass::PhysicalStorageBuffer) {
+      const bool scalar_block_layout = vstate.options()->scalar_block_layout;
+      MemberConstraints constraints;
+      const bool buffer = true;
+      const auto data_type_id = type_inst->GetOperandAs<uint32_t>(2u);
+      const auto* data_type_inst = vstate.FindDef(data_type_id);
+      if (data_type_inst->opcode() == spv::Op::OpTypeStruct) {
+        ComputeMemberConstraintsForStruct(&constraints, data_type_id,
+                                          LayoutConstraints(), vstate);
+      }
+      if (auto res = checkLayout(data_type_id, "PhysicalStorageBuffer", "Block",
+                                 !buffer, scalar_block_layout, 0, constraints,
+                                 vstate)) {
+        return res;
+      }
     }
   }
   return SPV_SUCCESS;

+ 10 - 4
3rdparty/spirv-tools/source/val/validate_id.cpp

@@ -163,9 +163,12 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
               !inst->IsDebugInfo() && !inst->IsNonSemantic() &&
               !spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
               opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+              opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
               !(opcode == spv::Op::OpSpecConstantOp &&
-                spv::Op(inst->word(3)) ==
-                    spv::Op::OpCooperativeMatrixLengthNV)) {
+                (spv::Op(inst->word(3)) ==
+                     spv::Op::OpCooperativeMatrixLengthNV ||
+                 spv::Op(inst->word(3)) ==
+                     spv::Op::OpCooperativeMatrixLengthKHR))) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " cannot be a type";
@@ -179,9 +182,12 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
                      opcode != spv::Op::OpLoopMerge &&
                      opcode != spv::Op::OpFunction &&
                      opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+                     opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
                      !(opcode == spv::Op::OpSpecConstantOp &&
-                       spv::Op(inst->word(3)) ==
-                           spv::Op::OpCooperativeMatrixLengthNV)) {
+                       (spv::Op(inst->word(3)) ==
+                            spv::Op::OpCooperativeMatrixLengthNV ||
+                        spv::Op(inst->word(3)) ==
+                            spv::Op::OpCooperativeMatrixLengthKHR))) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " requires a type";

+ 0 - 1
3rdparty/spirv-tools/source/val/validate_image.cpp

@@ -297,7 +297,6 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
                           spv::ImageOperandsMask::ConstOffsets |
                           spv::ImageOperandsMask::Offsets)) > 1) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << _.VkErrorID(4662)
            << "Image Operands Offset, ConstOffset, ConstOffsets, Offsets "
               "cannot be used together";
   }

+ 25 - 6
3rdparty/spirv-tools/source/val/validate_interfaces.cpp

@@ -173,8 +173,19 @@ spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
       }
       break;
     }
+    case spv::Op::OpTypePointer: {
+      if (_.addressing_model() ==
+              spv::AddressingModel::PhysicalStorageBuffer64 &&
+          type->GetOperandAs<spv::StorageClass>(1) ==
+              spv::StorageClass::PhysicalStorageBuffer) {
+        *num_locations = 1;
+        break;
+      }
+      [[fallthrough]];
+    }
     default:
-      break;
+      return _.diag(SPV_ERROR_INVALID_DATA, type)
+             << "Invalid type to assign a location";
   }
 
   return SPV_SUCCESS;
@@ -206,6 +217,14 @@ uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
       // Skip the array.
       return NumConsumedComponents(_,
                                    _.FindDef(type->GetOperandAs<uint32_t>(1)));
+    case spv::Op::OpTypePointer:
+      if (_.addressing_model() ==
+              spv::AddressingModel::PhysicalStorageBuffer64 &&
+          type->GetOperandAs<spv::StorageClass>(1) ==
+              spv::StorageClass::PhysicalStorageBuffer) {
+        return 2;
+      }
+      break;
     default:
       // This is an error that is validated elsewhere.
       break;
@@ -363,12 +382,12 @@ spv_result_t GetLocationsForVariable(
       sub_type = _.FindDef(sub_type_id);
     }
 
-    for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
-      uint32_t num_locations = 0;
-      if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
-        return error;
+    uint32_t num_locations = 0;
+    if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
+      return error;
+    uint32_t num_components = NumConsumedComponents(_, sub_type);
 
-      uint32_t num_components = NumConsumedComponents(_, sub_type);
+    for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
       uint32_t array_location = location + (num_locations * array_idx);
       uint32_t start = array_location * 4;
       if (kMaxLocations <= start) {

+ 128 - 4
3rdparty/spirv-tools/source/val/validate_memory.cpp

@@ -204,6 +204,7 @@ bool ContainsCooperativeMatrix(ValidationState_t& _,
 
   switch (storage->opcode()) {
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return true;
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeRuntimeArray:
@@ -232,6 +233,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
   spv::StorageClass src_sc = spv::StorageClass::Max;
   switch (inst->opcode()) {
     case spv::Op::OpCooperativeMatrixLoadNV:
+    case spv::Op::OpCooperativeMatrixLoadKHR:
     case spv::Op::OpLoad: {
       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
       auto load_pointer_type = _.FindDef(load_pointer->type_id());
@@ -239,6 +241,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
       break;
     }
     case spv::Op::OpCooperativeMatrixStoreNV:
+    case spv::Op::OpCooperativeMatrixStoreKHR:
     case spv::Op::OpStore: {
       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
       auto store_pointer_type = _.FindDef(store_pointer->type_id());
@@ -326,7 +329,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
   const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
     if (inst->opcode() == spv::Op::OpLoad ||
-        inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
+        inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
+        inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "MakePointerAvailableKHR cannot be used with OpLoad.";
     }
@@ -442,6 +446,7 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
       storage_class != spv::StorageClass::CrossWorkgroup &&
       storage_class != spv::StorageClass::Private &&
       storage_class != spv::StorageClass::Function &&
+      storage_class != spv::StorageClass::UniformConstant &&
       storage_class != spv::StorageClass::RayPayloadKHR &&
       storage_class != spv::StorageClass::IncomingRayPayloadKHR &&
       storage_class != spv::StorageClass::HitAttributeKHR &&
@@ -475,8 +480,8 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
                   "can only be used with non-externally visible shader Storage "
                   "Classes: Workgroup, CrossWorkgroup, Private, Function, "
                   "Input, Output, RayPayloadKHR, IncomingRayPayloadKHR, "
-                  "HitAttributeKHR, CallableDataKHR, or "
-                  "IncomingCallableDataKHR";
+                  "HitAttributeKHR, CallableDataKHR, "
+                  "IncomingCallableDataKHR, or UniformConstant";
       }
     }
   }
@@ -1356,6 +1361,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeVector:
       case spv::Op::OpTypeCooperativeMatrixNV:
+      case spv::Op::OpTypeCooperativeMatrixKHR:
       case spv::Op::OpTypeArray:
       case spv::Op::OpTypeRuntimeArray: {
         // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
@@ -1553,9 +1559,15 @@ spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
            << " must be OpTypeInt with width 32 and signedness 0.";
   }
 
+  bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
   auto type_id = inst->GetOperandAs<uint32_t>(2);
   auto type = state.FindDef(type_id);
-  if (type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
+  if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The type in " << instr_name << " <id> "
+           << state.getIdName(type_id)
+           << " must be OpTypeCooperativeMatrixKHR.";
+  } else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
     return state.diag(SPV_ERROR_INVALID_ID, inst)
            << "The type in " << instr_name << " <id> "
            << state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
@@ -1667,6 +1679,112 @@ spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
+                                                   const Instruction* inst) {
+  uint32_t type_id;
+  const char* opname;
+  if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
+    type_id = inst->type_id();
+    opname = "spv::Op::OpCooperativeMatrixLoadKHR";
+  } else {
+    // get Object operand's type
+    type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
+    opname = "spv::Op::OpCooperativeMatrixStoreKHR";
+  }
+
+  auto matrix_type = _.FindDef(type_id);
+
+  if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
+             << _.getIdName(type_id) << " is not a cooperative matrix type.";
+    } else {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
+             << _.getIdName(type_id) << " is not a cooperative matrix type.";
+    }
+  }
+
+  const auto pointer_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer = _.FindDef(pointer_id);
+  if (!pointer ||
+      ((_.addressing_model() == spv::AddressingModel::Logical) &&
+       ((!_.features().variable_pointers &&
+         !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+        (_.features().variable_pointers &&
+         !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> " << _.getIdName(pointer_id)
+           << " is not a logical pointer.";
+  }
+
+  const auto pointer_type_id = pointer->type_id();
+  const auto pointer_type = _.FindDef(pointer_type_id);
+  if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " type for pointer <id> " << _.getIdName(pointer_id)
+           << " is not a pointer type.";
+  }
+
+  const auto storage_class_index = 1u;
+  const auto storage_class =
+      pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
+
+  if (storage_class != spv::StorageClass::Workgroup &&
+      storage_class != spv::StorageClass::StorageBuffer &&
+      storage_class != spv::StorageClass::PhysicalStorageBuffer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " storage class for pointer type <id> "
+           << _.getIdName(pointer_type_id)
+           << " is not Workgroup or StorageBuffer.";
+  }
+
+  const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_id);
+  if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
+                         _.IsFloatScalarOrVectorType(pointee_id))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> " << _.getIdName(pointer->id())
+           << "s Type must be a scalar or vector type.";
+  }
+
+  const auto layout_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
+  const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
+  const auto colmajor = _.FindDef(colmajor_id);
+  if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
+      !(spvOpcodeIsConstant(colmajor->opcode()) ||
+        spvOpcodeIsSpecConstant(colmajor->opcode()))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
+           << " must be a 32-bit integer constant instruction.";
+  }
+
+  const auto stride_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
+  if (inst->operands().size() > stride_index) {
+    const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
+    const auto stride = _.FindDef(stride_id);
+    if (!stride || !_.IsIntScalarType(stride->type_id())) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "Stride operand <id> " << _.getIdName(stride_id)
+             << " must be a scalar integer type.";
+    }
+  }
+
+  const auto memory_access_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
+  if (inst->operands().size() > memory_access_index) {
+    if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
+      return error;
+  }
+
+  return SPV_SUCCESS;
+}
+
 spv_result_t ValidatePtrComparison(ValidationState_t& _,
                                    const Instruction* inst) {
   if (_.addressing_model() == spv::AddressingModel::Logical &&
@@ -1756,9 +1874,15 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
       if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
         return error;
       break;
+    case spv::Op::OpCooperativeMatrixLengthKHR:
     case spv::Op::OpCooperativeMatrixLengthNV:
       if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
       break;
+    case spv::Op::OpCooperativeMatrixLoadKHR:
+    case spv::Op::OpCooperativeMatrixStoreKHR:
+      if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
+        return error;
+      break;
     case spv::Op::OpPtrEqual:
     case spv::Op::OpPtrNotEqual:
     case spv::Op::OpPtrDiff:

+ 20 - 7
3rdparty/spirv-tools/source/val/validate_type.cpp

@@ -552,8 +552,8 @@ spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
-spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
-                                             const Instruction* inst) {
+spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _,
+                                           const Instruction* inst) {
   const auto component_type_index = 1;
   const auto component_type_id =
       inst->GetOperandAs<uint32_t>(component_type_index);
@@ -561,7 +561,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
   if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
                           spv::Op::OpTypeInt != component_type->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Component Type <id> "
+           << "OpTypeCooperativeMatrix Component Type <id> "
            << _.getIdName(component_type_id)
            << " is not a scalar numerical type.";
   }
@@ -572,7 +572,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
   if (!scope || !_.IsIntScalarType(scope->type_id()) ||
       !spvOpcodeIsConstant(scope->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Scope <id> " << _.getIdName(scope_id)
+           << "OpTypeCooperativeMatrix Scope <id> " << _.getIdName(scope_id)
            << " is not a constant instruction with scalar integer type.";
   }
 
@@ -582,7 +582,7 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
   if (!rows || !_.IsIntScalarType(rows->type_id()) ||
       !spvOpcodeIsConstant(rows->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Rows <id> " << _.getIdName(rows_id)
+           << "OpTypeCooperativeMatrix Rows <id> " << _.getIdName(rows_id)
            << " is not a constant instruction with scalar integer type.";
   }
 
@@ -592,10 +592,22 @@ spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
   if (!cols || !_.IsIntScalarType(cols->type_id()) ||
       !spvOpcodeIsConstant(cols->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Cols <id> " << _.getIdName(cols_id)
+           << "OpTypeCooperativeMatrix Cols <id> " << _.getIdName(cols_id)
            << " is not a constant instruction with scalar integer type.";
   }
 
+  if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+    const auto use_index = 5;
+    const auto use_id = inst->GetOperandAs<uint32_t>(use_index);
+    const auto use = _.FindDef(use_id);
+    if (!use || !_.IsIntScalarType(use->type_id()) ||
+        !spvOpcodeIsConstant(use->opcode())) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "OpTypeCooperativeMatrixKHR Use <id> " << _.getIdName(use_id)
+             << " is not a constant instruction with scalar integer type.";
+    }
+  }
+
   return SPV_SUCCESS;
 }
 }  // namespace
@@ -640,7 +652,8 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
       if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
       break;
     case spv::Op::OpTypeCooperativeMatrixNV:
-      if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
+    case spv::Op::OpTypeCooperativeMatrixKHR:
+      if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error;
       break;
     default:
       break;

+ 68 - 7
3rdparty/spirv-tools/source/val/validation_state.cpp

@@ -859,6 +859,7 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
       return GetComponentType(inst->word(2));
 
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return inst->word(2);
 
     default:
@@ -886,6 +887,7 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const {
       return inst->word(3);
 
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       // Actual dimension isn't known, return 0
       return 0;
 
@@ -1142,22 +1144,68 @@ bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
 }
 
 bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  return inst && (inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV ||
+                  inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR);
+}
+
+bool ValidationState_t::IsCooperativeMatrixNVType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV;
 }
 
+bool ValidationState_t::IsCooperativeMatrixKHRType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR;
+}
+
+bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  const Instruction* inst = FindDef(id);
+  uint64_t matrixUse = 0;
+  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+    return matrixUse ==
+           static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
+  }
+  return false;
+}
+
+bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  const Instruction* inst = FindDef(id);
+  uint64_t matrixUse = 0;
+  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+    return matrixUse ==
+           static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
+  }
+  return false;
+}
+bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  const Instruction* inst = FindDef(id);
+  uint64_t matrixUse = 0;
+  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+    return matrixUse == static_cast<uint64_t>(
+                            spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
+  }
+  return false;
+}
+
 bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
-  if (!IsCooperativeMatrixType(id)) return false;
+  if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
+    return false;
   return IsFloatScalarType(FindDef(id)->word(2));
 }
 
 bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
-  if (!IsCooperativeMatrixType(id)) return false;
+  if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
+    return false;
   return IsIntScalarType(FindDef(id)->word(2));
 }
 
 bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
-  if (!IsCooperativeMatrixType(id)) return false;
+  if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
+    return false;
   return IsUnsignedIntScalarType(FindDef(id)->word(2));
 }
 
@@ -1173,8 +1221,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
   const auto m1_type = FindDef(m1);
   const auto m2_type = FindDef(m2);
 
-  if (m1_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV ||
-      m2_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
+  if (m1_type->opcode() != m2_type->opcode()) {
     return diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected cooperative matrix types";
   }
@@ -1224,6 +1271,21 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
            << "identical";
   }
 
+  if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+    uint32_t m1_use_id = m1_type->GetOperandAs<uint32_t>(5);
+    uint32_t m2_use_id = m2_type->GetOperandAs<uint32_t>(5);
+    std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+        EvalInt32IfConst(m1_use_id);
+    std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+        EvalInt32IfConst(m2_use_id);
+
+    if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+      return diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Expected Use of Matrix type and Result Type to be "
+             << "identical";
+    }
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -1489,6 +1551,7 @@ bool ValidationState_t::ContainsType(
     case spv::Op::OpTypeImage:
     case spv::Op::OpTypeSampledImage:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
                           traverse_all_types);
     case spv::Op::OpTypePointer:
@@ -2048,8 +2111,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
       return VUID_WRAP(VUID-StandaloneSpirv-OpImageTexelPointer-04658);
     case 4659:
       return VUID_WRAP(VUID-StandaloneSpirv-OpImageQuerySizeLod-04659);
-    case 4662:
-      return VUID_WRAP(VUID-StandaloneSpirv-Offset-04662);
     case 4663:
       return VUID_WRAP(VUID-StandaloneSpirv-Offset-04663);
     case 4664:

+ 5 - 0
3rdparty/spirv-tools/source/val/validation_state.h

@@ -610,6 +610,11 @@ class ValidationState_t {
   bool IsPointerType(uint32_t id) const;
   bool IsAccelerationStructureType(uint32_t id) const;
   bool IsCooperativeMatrixType(uint32_t id) const;
+  bool IsCooperativeMatrixNVType(uint32_t id) const;
+  bool IsCooperativeMatrixKHRType(uint32_t id) const;
+  bool IsCooperativeMatrixAType(uint32_t id) const;
+  bool IsCooperativeMatrixBType(uint32_t id) const;
+  bool IsCooperativeMatrixAccType(uint32_t id) const;
   bool IsFloatCooperativeMatrixType(uint32_t id) const;
   bool IsIntCooperativeMatrixType(uint32_t id) const;
   bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;

Some files were not shown because too many files changed in this diff