
Update spirv-cross

Alex Szpakowski committed 5 years ago
commit f5ecaf2e92

+ 26 - 26
src/libraries/spirv_cross/spirv.h

@@ -54,11 +54,11 @@
 typedef unsigned int SpvId;
 
 #define SPV_VERSION 0x10500
-#define SPV_REVISION 1
+#define SPV_REVISION 3
 
 static const unsigned int SpvMagicNumber = 0x07230203;
 static const unsigned int SpvVersion = 0x00010500;
-static const unsigned int SpvRevision = 1;
+static const unsigned int SpvRevision = 3;
 static const unsigned int SpvOpCodeMask = 0xffff;
 static const unsigned int SpvWordCountShift = 16;
 
@@ -1899,6 +1899,13 @@ inline void SpvHasResultAndType(SpvOp opcode, bool *hasResult, bool *hasResultTy
     case SpvOpSubgroupAnyKHR: *hasResult = true; *hasResultType = true; break;
     case SpvOpSubgroupAllEqualKHR: *hasResult = true; *hasResultType = true; break;
     case SpvOpSubgroupReadInvocationKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpTypeRayQueryProvisionalKHR: *hasResult = true; *hasResultType = false; break;
+    case SpvOpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break;
+    case SpvOpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break;
+    case SpvOpRayQueryGenerateIntersectionKHR: *hasResult = false; *hasResultType = false; break;
+    case SpvOpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break;
+    case SpvOpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break;
     case SpvOpGroupIAddNonUniformAMD: *hasResult = true; *hasResultType = true; break;
     case SpvOpGroupFAddNonUniformAMD: *hasResult = true; *hasResultType = true; break;
     case SpvOpGroupFMinNonUniformAMD: *hasResult = true; *hasResultType = true; break;
@@ -1918,30 +1925,6 @@ inline void SpvHasResultAndType(SpvOp opcode, bool *hasResult, bool *hasResultTy
     case SpvOpTerminateRayNV: *hasResult = false; *hasResultType = false; break;
     case SpvOpTraceNV: *hasResult = false; *hasResultType = false; break;
     case SpvOpTypeAccelerationStructureNV: *hasResult = true; *hasResultType = false; break;
-    case SpvOpTypeRayQueryProvisionalKHR: *hasResult = true; *hasResultType = false; break;
-    case SpvOpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break;
-    case SpvOpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break;
-    case SpvOpRayQueryGenerateIntersectionKHR: *hasResult = false; *hasResultType = false; break;
-    case SpvOpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break;
-    case SpvOpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetRayTMinKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetRayFlagsKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionTKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionInstanceIdKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionGeometryIndexKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionBarycentricsKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionFrontFaceKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionObjectRayOriginKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetWorldRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetWorldRayOriginKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionObjectToWorldKHR: *hasResult = true; *hasResultType = true; break;
-    case SpvOpRayQueryGetIntersectionWorldToObjectKHR: *hasResult = true; *hasResultType = true; break;
     case SpvOpExecuteCallableNV: *hasResult = false; *hasResultType = false; break;
     case SpvOpTypeCooperativeMatrixNV: *hasResult = true; *hasResultType = false; break;
     case SpvOpCooperativeMatrixLoadNV: *hasResult = true; *hasResultType = true; break;
@@ -2096,6 +2079,23 @@ inline void SpvHasResultAndType(SpvOp opcode, bool *hasResult, bool *hasResultTy
     case SpvOpSubgroupAvcSicGetPackedSkcLumaCountThresholdINTEL: *hasResult = true; *hasResultType = true; break;
     case SpvOpSubgroupAvcSicGetPackedSkcLumaSumThresholdINTEL: *hasResult = true; *hasResultType = true; break;
     case SpvOpSubgroupAvcSicGetInterRawSadsINTEL: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetRayTMinKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetRayFlagsKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionTKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionInstanceIdKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionGeometryIndexKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionBarycentricsKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionFrontFaceKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionObjectRayOriginKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetWorldRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetWorldRayOriginKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionObjectToWorldKHR: *hasResult = true; *hasResultType = true; break;
+    case SpvOpRayQueryGetIntersectionWorldToObjectKHR: *hasResult = true; *hasResultType = true; break;
     }
 }
 #endif /* SPV_ENABLE_UTILITY_CODE */

+ 26 - 26
src/libraries/spirv_cross/spirv.hpp

@@ -50,11 +50,11 @@ namespace spv {
 typedef unsigned int Id;
 
 #define SPV_VERSION 0x10500
-#define SPV_REVISION 1
+#define SPV_REVISION 3
 
 static const unsigned int MagicNumber = 0x07230203;
 static const unsigned int Version = 0x00010500;
-static const unsigned int Revision = 1;
+static const unsigned int Revision = 3;
 static const unsigned int OpCodeMask = 0xffff;
 static const unsigned int WordCountShift = 16;
 
@@ -1895,6 +1895,13 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
     case OpSubgroupAnyKHR: *hasResult = true; *hasResultType = true; break;
     case OpSubgroupAllEqualKHR: *hasResult = true; *hasResultType = true; break;
     case OpSubgroupReadInvocationKHR: *hasResult = true; *hasResultType = true; break;
+    case OpTypeRayQueryProvisionalKHR: *hasResult = true; *hasResultType = false; break;
+    case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break;
+    case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break;
+    case OpRayQueryGenerateIntersectionKHR: *hasResult = false; *hasResultType = false; break;
+    case OpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break;
+    case OpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break;
     case OpGroupIAddNonUniformAMD: *hasResult = true; *hasResultType = true; break;
     case OpGroupFAddNonUniformAMD: *hasResult = true; *hasResultType = true; break;
     case OpGroupFMinNonUniformAMD: *hasResult = true; *hasResultType = true; break;
@@ -1914,30 +1921,6 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
     case OpTerminateRayNV: *hasResult = false; *hasResultType = false; break;
     case OpTraceNV: *hasResult = false; *hasResultType = false; break;
     case OpTypeAccelerationStructureNV: *hasResult = true; *hasResultType = false; break;
-    case OpTypeRayQueryProvisionalKHR: *hasResult = true; *hasResultType = false; break;
-    case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break;
-    case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break;
-    case OpRayQueryGenerateIntersectionKHR: *hasResult = false; *hasResultType = false; break;
-    case OpRayQueryConfirmIntersectionKHR: *hasResult = false; *hasResultType = false; break;
-    case OpRayQueryProceedKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionTypeKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetRayTMinKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetRayFlagsKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionTKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionInstanceCustomIndexKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionInstanceIdKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionGeometryIndexKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionPrimitiveIndexKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionBarycentricsKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionFrontFaceKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionObjectRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionObjectRayOriginKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetWorldRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetWorldRayOriginKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionObjectToWorldKHR: *hasResult = true; *hasResultType = true; break;
-    case OpRayQueryGetIntersectionWorldToObjectKHR: *hasResult = true; *hasResultType = true; break;
     case OpExecuteCallableNV: *hasResult = false; *hasResultType = false; break;
     case OpTypeCooperativeMatrixNV: *hasResult = true; *hasResultType = false; break;
     case OpCooperativeMatrixLoadNV: *hasResult = true; *hasResultType = true; break;
@@ -2092,6 +2075,23 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
     case OpSubgroupAvcSicGetPackedSkcLumaCountThresholdINTEL: *hasResult = true; *hasResultType = true; break;
     case OpSubgroupAvcSicGetPackedSkcLumaSumThresholdINTEL: *hasResult = true; *hasResultType = true; break;
     case OpSubgroupAvcSicGetInterRawSadsINTEL: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetRayTMinKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetRayFlagsKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionTKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionInstanceCustomIndexKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionInstanceIdKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionGeometryIndexKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionPrimitiveIndexKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionBarycentricsKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionFrontFaceKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionObjectRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionObjectRayOriginKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetWorldRayDirectionKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetWorldRayOriginKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionObjectToWorldKHR: *hasResult = true; *hasResultType = true; break;
+    case OpRayQueryGetIntersectionWorldToObjectKHR: *hasResult = true; *hasResultType = true; break;
     }
 }
 #endif /* SPV_ENABLE_UTILITY_CODE */
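
For context on how this utility is consumed, here is a minimal sketch of a SPIR-V instruction walker (not part of this patch; assumes `SPV_ENABLE_UTILITY_CODE` is defined before including the header):

```cpp
#define SPV_ENABLE_UTILITY_CODE
#include "spirv.hpp"

#include <cstdint>
#include <vector>

// Walk a SPIR-V word stream and visit each instruction's result id, if any.
// HasResultAndType() tells us whether operand 0 is a result type, a result
// id, or neither.
void visit_result_ids(const std::vector<uint32_t> &words)
{
	// Words 0..4 are the module header (magic, version, generator, bound, schema).
	for (size_t i = 5; i < words.size();)
	{
		uint32_t word_count = words[i] >> spv::WordCountShift;
		auto opcode = static_cast<spv::Op>(words[i] & spv::OpCodeMask);

		bool has_result = false, has_result_type = false;
		spv::HasResultAndType(opcode, &has_result, &has_result_type);

		if (has_result)
		{
			// The result id follows the result type id when one is present.
			uint32_t result_id = words[i + (has_result_type ? 2 : 1)];
			(void)result_id; // e.g. record the defining instruction here
		}

		i += word_count ? word_count : 1; // guard against a malformed zero count
	}
}
```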

+ 6 - 1
src/libraries/spirv_cross/spirv_common.hpp

@@ -558,11 +558,16 @@ struct SPIRType : IVariant
 	// Keep track of how many pointer layers we have.
 	uint32_t pointer_depth = 0;
 	bool pointer = false;
+	bool forward_pointer = false;
 
 	spv::StorageClass storage = spv::StorageClassGeneric;
 
 	SmallVector<TypeID> member_types;
 
+	// If member order has been rewritten to handle certain scenarios with Offset,
+	// allow codegen to rewrite the index.
+	SmallVector<uint32_t> member_type_index_redirection;
+
 	struct ImageType
 	{
 		TypeID type;
@@ -776,7 +781,7 @@ struct SPIRBlock : IVariant
 		ComplexLoop
 	};
 
-	enum
+	enum : uint32_t
 	{
 		NoDominator = 0xffffffffu
 	};
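
The fixed underlying type above is a small hardening: for a plain unscoped enum, the type of an enumerator like `0xffffffffu` is implementation-defined (any integer type that can hold it), so comparisons against 32-bit block ids can raise conversion warnings on some compilers. A standalone sketch of the pattern:

```cpp
#include <cstdint>

// With ": uint32_t", the sentinel has exactly the width of a block id.
enum : uint32_t { NoDominator = 0xffffffffu };

bool has_dominator(uint32_t dominator_block_id)
{
	return dominator_block_id != NoDominator; // uint32_t vs uint32_t, no promotion surprises
}
```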

+ 24 - 8
src/libraries/spirv_cross/spirv_cross.cpp

@@ -1659,6 +1659,13 @@ size_t Compiler::get_declared_struct_member_size(const SPIRType &struct_type, ui
 		break;
 	}
 
+	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
+	{
+		// Check if this is a top-level pointer type, and not an array of pointers.
+		if (type.pointer_depth > get<SPIRType>(type.parent_type).pointer_depth)
+			return 8;
+	}
+
 	if (!type.array.empty())
 	{
 		// For arrays, we can use ArrayStride to get an easy check.
@@ -4204,19 +4211,22 @@ Bitset Compiler::combined_decoration_for_member(const SPIRType &type, uint32_t i
 
 	if (type_meta)
 	{
-		auto &memb = type_meta->members;
-		if (index >= memb.size())
+		auto &members = type_meta->members;
+		if (index >= members.size())
 			return flags;
-		auto &dec = memb[index];
+		auto &dec = members[index];
 
-		// If our type is a struct, traverse all the members as well recursively.
 		flags.merge_or(dec.decoration_flags);
 
-		for (uint32_t i = 0; i < type.member_types.size(); i++)
+		auto &member_type = get<SPIRType>(type.member_types[index]);
+
+		// If our member type is a struct, traverse all the child members as well recursively.
+		auto &member_childs = member_type.member_types;
+		for (uint32_t i = 0; i < member_childs.size(); i++)
 		{
-			auto &memb_type = get<SPIRType>(type.member_types[i]);
-			if (!memb_type.pointer)
-				flags.merge_or(combined_decoration_for_member(memb_type, i));
+			auto &child_member_type = get<SPIRType>(member_childs[i]);
+			if (!child_member_type.pointer)
+				flags.merge_or(combined_decoration_for_member(member_type, i));
 		}
 	}
 
@@ -4634,6 +4644,12 @@ bool Compiler::type_is_array_of_pointers(const SPIRType &type) const
 	return type.pointer_depth == get<SPIRType>(type.parent_type).pointer_depth;
 }
 
+bool Compiler::type_is_top_level_physical_pointer(const SPIRType &type) const
+{
+	return type.pointer && type.storage == StorageClassPhysicalStorageBuffer &&
+	       type.pointer_depth > get<SPIRType>(type.parent_type).pointer_depth;
+}
+
 bool Compiler::flush_phi_required(BlockID from, BlockID to) const
 {
 	auto &child = get<SPIRBlock>(to);
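
The early-out above reflects that a `PhysicalStorageBuffer` pointer (e.g. one declared through GL_EXT_buffer_reference) is a 64-bit device address, so a top-level pointer member always occupies 8 bytes regardless of the pointee type. A hypothetical caller via the public reflection API:

```cpp
#include "spirv_cross.hpp"

#include <cstddef>
#include <cstdint>

// Assumes member `index` of `struct_type` was declared as a buffer_reference
// pointer in the source shader; this now reports 8, not the pointee's size.
size_t pointer_member_size(const spirv_cross::Compiler &compiler,
                           const spirv_cross::SPIRType &struct_type, uint32_t index)
{
	return compiler.get_declared_struct_member_size(struct_type, index);
}
```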

+ 1 - 0
src/libraries/spirv_cross/spirv_cross.hpp

@@ -1037,6 +1037,7 @@ protected:
 	void unset_extended_member_decoration(uint32_t type, uint32_t index, ExtendedDecorations decoration);
 
 	bool type_is_array_of_pointers(const SPIRType &type) const;
+	bool type_is_top_level_physical_pointer(const SPIRType &type) const;
 	bool type_is_block_like(const SPIRType &type) const;
 	bool type_is_opaque_value(const SPIRType &type) const;
 

+ 4 - 8
src/libraries/spirv_cross/spirv_cross_c.cpp

@@ -485,6 +485,10 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c
 	case SPVC_COMPILER_OPTION_HLSL_NONWRITABLE_UAV_TEXTURE_AS_SRV:
 		options->hlsl.nonwritable_uav_texture_as_srv = value != 0;
 		break;
+
+	case SPVC_COMPILER_OPTION_HLSL_ENABLE_16BIT_TYPES:
+		options->hlsl.enable_16bit_types = value != 0;
+		break;
 #endif
 
 #if SPIRV_CROSS_C_API_MSL
@@ -1017,12 +1021,8 @@ spvc_result spvc_compiler_msl_add_vertex_attribute(spvc_compiler compiler, const
 	auto &msl = *static_cast<CompilerMSL *>(compiler->compiler.get());
 	MSLVertexAttr attr;
 	attr.location = va->location;
-	attr.msl_buffer = va->msl_buffer;
-	attr.msl_offset = va->msl_offset;
-	attr.msl_stride = va->msl_stride;
 	attr.format = static_cast<MSLVertexFormat>(va->format);
 	attr.builtin = static_cast<spv::BuiltIn>(va->builtin);
-	attr.per_instance = va->per_instance != 0;
 	msl.add_msl_vertex_attribute(attr);
 	return SPVC_SUCCESS;
 #else
@@ -2260,12 +2260,8 @@ void spvc_msl_vertex_attribute_init(spvc_msl_vertex_attribute *attr)
 	// Crude, but works.
 	MSLVertexAttr attr_default;
 	attr->location = attr_default.location;
-	attr->per_instance = attr_default.per_instance ? SPVC_TRUE : SPVC_FALSE;
 	attr->format = static_cast<spvc_msl_vertex_format>(attr_default.format);
 	attr->builtin = static_cast<SpvBuiltIn>(attr_default.builtin);
-	attr->msl_buffer = attr_default.msl_buffer;
-	attr->msl_offset = attr_default.msl_offset;
-	attr->msl_stride = attr_default.msl_stride;
 #else
 	memset(attr, 0, sizeof(*attr));
 #endif
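
A sketch of driving the new option through the C API (error checking elided; `SPVC_COMPILER_OPTION_HLSL_SHADER_MODEL` is the pre-existing shader-model knob):

```cpp
#include <spirv_cross_c.h>

// Enable native 16-bit types; validate_shader_model() in spirv_hlsl.cpp
// rejects this below shader model 6.2, so bump the model as well.
void enable_native_16bit_types(spvc_compiler compiler)
{
	spvc_compiler_options options = nullptr;
	spvc_compiler_create_compiler_options(compiler, &options);
	spvc_compiler_options_set_uint(options, SPVC_COMPILER_OPTION_HLSL_SHADER_MODEL, 62);
	spvc_compiler_options_set_uint(options, SPVC_COMPILER_OPTION_HLSL_ENABLE_16BIT_TYPES, 1);
	spvc_compiler_install_compiler_options(compiler, options);
}
```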

+ 9 - 1
src/libraries/spirv_cross/spirv_cross_c.h

@@ -33,7 +33,7 @@ extern "C" {
 /* Bumped if ABI or API breaks backwards compatibility. */
 #define SPVC_C_API_VERSION_MAJOR 0
 /* Bumped if APIs or enumerations are added in a backwards compatible way. */
-#define SPVC_C_API_VERSION_MINOR 33
+#define SPVC_C_API_VERSION_MINOR 34
 /* Bumped if internal implementation details change. */
 #define SPVC_C_API_VERSION_PATCH 0
 
@@ -270,10 +270,16 @@ typedef enum spvc_msl_vertex_format
 typedef struct spvc_msl_vertex_attribute
 {
 	unsigned location;
+
+	/* Obsolete, do not use. Only lingers on for ABI compatibility. */
 	unsigned msl_buffer;
+	/* Obsolete, do not use. Only lingers on for ABI compatibility. */
 	unsigned msl_offset;
+	/* Obsolete, do not use. Only lingers on for ABI compatibility. */
 	unsigned msl_stride;
+	/* Obsolete, do not use. Only lingers on for ABI compatibility. */
 	spvc_bool per_instance;
+
 	spvc_msl_vertex_format format;
 	SpvBuiltIn builtin;
 } spvc_msl_vertex_attribute;
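
Given the trimmed-down mapping above, only `location`, `format`, and `builtin` are still consumed; a minimal usage sketch of the C-side initializer:

```cpp
#include <spirv_cross_c.h>

// The msl_buffer/msl_offset/msl_stride/per_instance members linger purely
// for ABI compatibility; spvc_compiler_msl_add_vertex_attribute ignores them now.
spvc_msl_vertex_attribute make_vertex_attribute(unsigned location)
{
	spvc_msl_vertex_attribute attr;
	spvc_msl_vertex_attribute_init(&attr); // fill in the current defaults
	attr.location = location;
	return attr;
}
```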
@@ -588,6 +594,8 @@ typedef enum spvc_compiler_option
 	SPVC_COMPILER_OPTION_MSL_ENABLE_FRAG_STENCIL_REF_BUILTIN = 58 | SPVC_COMPILER_OPTION_MSL_BIT,
 	SPVC_COMPILER_OPTION_MSL_ENABLE_CLIP_DISTANCE_USER_VARYING = 59 | SPVC_COMPILER_OPTION_MSL_BIT,
 
+	SPVC_COMPILER_OPTION_HLSL_ENABLE_16BIT_TYPES = 60 | SPVC_COMPILER_OPTION_HLSL_BIT,
+
 	SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff
 } spvc_compiler_option;
 

+ 158 - 25
src/libraries/spirv_cross/spirv_glsl.cpp

@@ -731,6 +731,13 @@ void CompilerGLSL::emit_header()
 				statement("#endif");
 			}
 		}
+		else if (!options.vulkan_semantics && ext == "GL_ARB_shader_draw_parameters")
+		{
+			// Soft-enable this extension on plain GLSL.
+			statement("#ifdef ", ext);
+			statement("#extension ", ext, " : enable");
+			statement("#endif");
+		}
 		else
 			statement("#extension ", ext, " : require");
 	}
@@ -2943,16 +2950,28 @@ void CompilerGLSL::emit_resources()
 			}
 			else if (id.get_type() == TypeType)
 			{
-				auto &type = id.get<SPIRType>();
-				if (type.basetype == SPIRType::Struct && type.array.empty() && !type.pointer &&
-				    (!ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) &&
-				     !ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock)))
+				auto *type = &id.get<SPIRType>();
+
+				bool is_natural_struct =
+					type->basetype == SPIRType::Struct && type->array.empty() && !type->pointer &&
+					(!has_decoration(type->self, DecorationBlock) && !has_decoration(type->self, DecorationBufferBlock));
+
+				// Special case, ray payload and hit attribute blocks are not really blocks, just regular structs.
+				if (type->basetype == SPIRType::Struct && type->pointer && has_decoration(type->self, DecorationBlock) &&
+				    (type->storage == StorageClassRayPayloadNV || type->storage == StorageClassIncomingRayPayloadNV ||
+				     type->storage == StorageClassHitAttributeNV))
+				{
+					type = &get<SPIRType>(type->parent_type);
+					is_natural_struct = true;
+				}
+
+				if (is_natural_struct)
 				{
 					if (emitted)
 						statement("");
 					emitted = false;
 
-					emit_struct(type);
+					emit_struct(*type);
 				}
 			}
 		}
@@ -3070,6 +3089,8 @@ void CompilerGLSL::emit_resources()
 		statement("");
 	emitted = false;
 
+	bool emitted_base_instance = false;
+
 	// Output in/out interfaces.
 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
 		auto &type = this->get<SPIRType>(var.basetype);
@@ -3092,13 +3113,42 @@ void CompilerGLSL::emit_resources()
 		}
 		else if (is_builtin_variable(var))
 		{
+			auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
 			// For gl_InstanceIndex emulation on GLES, the API user needs to
 			// supply this uniform.
-			if (options.vertex.support_nonzero_base_instance &&
-			    ir.meta[var.self].decoration.builtin_type == BuiltInInstanceIndex && !options.vulkan_semantics)
+
+			// The draw parameter extension is soft-enabled on GL with some fallbacks.
+			if (!options.vulkan_semantics)
 			{
-				statement("uniform int SPIRV_Cross_BaseInstance;");
-				emitted = true;
+				if (!emitted_base_instance &&
+				    ((options.vertex.support_nonzero_base_instance && builtin == BuiltInInstanceIndex) ||
+				     (builtin == BuiltInBaseInstance)))
+				{
+					statement("#ifdef GL_ARB_shader_draw_parameters");
+					statement("#define SPIRV_Cross_BaseInstance gl_BaseInstanceARB");
+					statement("#else");
+					// A crude, but simple workaround which should be good enough for non-indirect draws.
+					statement("uniform int SPIRV_Cross_BaseInstance;");
+					statement("#endif");
+					emitted = true;
+					emitted_base_instance = true;
+				}
+				else if (builtin == BuiltInBaseVertex)
+				{
+					statement("#ifdef GL_ARB_shader_draw_parameters");
+					statement("#define SPIRV_Cross_BaseVertex gl_BaseVertexARB");
+					statement("#else");
+					// A crude, but simple workaround which should be good enough for non-indirect draws.
+					statement("uniform int SPIRV_Cross_BaseVertex;");
+					statement("#endif");
+				}
+				else if (builtin == BuiltInDrawIndex)
+				{
+					statement("#ifndef GL_ARB_shader_draw_parameters");
+					// Cannot really be worked around.
+					statement("#error GL_ARB_shader_draw_parameters is not supported.");
+					statement("#endif");
+				}
 			}
 		}
 	});
@@ -6792,8 +6842,21 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
 		return "gl_VertexID";
 	case BuiltInInstanceId:
 		if (options.vulkan_semantics)
-			SPIRV_CROSS_THROW(
-			    "Cannot implement gl_InstanceID in Vulkan GLSL. This shader was created with GL semantics.");
+		{
+			auto model = get_entry_point().model;
+			switch (model)
+			{
+			case spv::ExecutionModelIntersectionKHR:
+			case spv::ExecutionModelAnyHitKHR:
+			case spv::ExecutionModelClosestHitKHR:
+				// gl_InstanceID is allowed in these shaders.
+				break;
+
+			default:
+				SPIRV_CROSS_THROW(
+					"Cannot implement gl_InstanceID in Vulkan GLSL. This shader was created with GL semantics.");
+			}
+		}
 		return "gl_InstanceID";
 	case BuiltInVertexIndex:
 		if (options.vulkan_semantics)
@@ -6804,7 +6867,14 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
 		if (options.vulkan_semantics)
 			return "gl_InstanceIndex";
 		else if (options.vertex.support_nonzero_base_instance)
+		{
+			if (!options.vulkan_semantics)
+			{
+				// This is a soft-enable. We will opt-in to using gl_BaseInstanceARB if supported.
+				require_extension_internal("GL_ARB_shader_draw_parameters");
+			}
 			return "(gl_InstanceID + SPIRV_Cross_BaseInstance)"; // ... but not gl_InstanceID.
+		}
 		else
 			return "gl_InstanceID";
 	case BuiltInPrimitiveId:
@@ -6846,33 +6916,69 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
 		return "gl_LocalInvocationIndex";
 	case BuiltInHelperInvocation:
 		return "gl_HelperInvocation";
+
 	case BuiltInBaseVertex:
 		if (options.es)
 			SPIRV_CROSS_THROW("BaseVertex not supported in ES profile.");
-		if (options.version < 460)
+
+		if (options.vulkan_semantics)
+		{
+			if (options.version < 460)
+			{
+				require_extension_internal("GL_ARB_shader_draw_parameters");
+				return "gl_BaseVertexARB";
+			}
+			return "gl_BaseVertex";
+		}
+		else
 		{
+			// On regular GL, this is soft-enabled and we emit ifdefs in code.
 			require_extension_internal("GL_ARB_shader_draw_parameters");
-			return "gl_BaseVertexARB";
+			return "SPIRV_Cross_BaseVertex";
 		}
-		return "gl_BaseVertex";
+		break;
+
 	case BuiltInBaseInstance:
 		if (options.es)
 			SPIRV_CROSS_THROW("BaseInstance not supported in ES profile.");
-		if (options.version < 460)
+
+		if (options.vulkan_semantics)
 		{
+			if (options.version < 460)
+			{
+				require_extension_internal("GL_ARB_shader_draw_parameters");
+				return "gl_BaseInstanceARB";
+			}
+			return "gl_BaseInstance";
+		}
+		else
+		{
+			// On regular GL, this is soft-enabled and we emit ifdefs in code.
 			require_extension_internal("GL_ARB_shader_draw_parameters");
-			return "gl_BaseInstanceARB";
+			return "SPIRV_Cross_BaseInstance";
 		}
-		return "gl_BaseInstance";
+		break;
+
 	case BuiltInDrawIndex:
 		if (options.es)
 			SPIRV_CROSS_THROW("DrawIndex not supported in ES profile.");
-		if (options.version < 460)
+
+		if (options.vulkan_semantics)
 		{
+			if (options.version < 460)
+			{
+				require_extension_internal("GL_ARB_shader_draw_parameters");
+				return "gl_DrawIDARB";
+			}
+			return "gl_DrawID";
+		}
+		else
+		{
+			// On regular GL, this is soft-enabled and we emit ifdefs in code.
 			require_extension_internal("GL_ARB_shader_draw_parameters");
 			return "gl_DrawIDARB";
 		}
-		return "gl_DrawID";
+		break;
 
 	case BuiltInSampleId:
 		if (options.es && options.version < 320)
@@ -8591,7 +8697,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 			var->static_expression = ops[1];
 		else if (var && var->loop_variable && !var->loop_variable_enable)
 			var->static_expression = ops[1];
-		else if (var && var->remapped_variable)
+		else if (var && var->remapped_variable && var->static_expression)
 		{
 			// Skip the write.
 		}
@@ -9794,15 +9900,33 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 	}
 
 	case OpAtomicLoad:
-		flush_all_atomic_capable_variables();
-		// FIXME: Image?
-		// OpAtomicLoad seems to only be relevant for atomic counters.
+	{
+		// In plain GLSL, we have no atomic loads, so emulate this by fetch adding by 0 and hope compiler figures it out.
+		// Alternatively, we could rely on KHR_memory_model, but that's not very helpful for GL.
+		auto &type = expression_type(ops[2]);
 		forced_temporaries.insert(ops[1]);
-		GLSL_UFOP(atomicCounter);
+		bool atomic_image = check_atomic_image(ops[2]);
+		bool unsigned_type = (type.basetype == SPIRType::UInt) ||
+		                     (atomic_image && get<SPIRType>(type.image.type).basetype == SPIRType::UInt);
+		const char *op = atomic_image ? "imageAtomicAdd" : "atomicAdd";
+		const char *increment = unsigned_type ? "0u" : "0";
+		emit_op(ops[0], ops[1], join(op, "(", to_expression(ops[2]), ", ", increment, ")"), false);
+		flush_all_atomic_capable_variables();
 		break;
+	}
 
 	case OpAtomicStore:
-		SPIRV_CROSS_THROW("Unsupported opcode OpAtomicStore.");
+	{
+		// In plain GLSL, we have no atomic stores, so emulate this with an atomic exchange where we don't consume the result.
+		// Alternatively, we could rely on KHR_memory_model, but that's not very helpful for GL.
+		uint32_t ptr = ops[0];
+		// Ignore semantics for now, probably only relevant to CL.
+		uint32_t val = ops[3];
+		const char *op = check_atomic_image(ptr) ? "imageAtomicExchange" : "atomicExchange";
+		statement(op, "(", to_expression(ptr), ", ", to_expression(val), ");");
+		flush_all_atomic_capable_variables();
+		break;
+	}
 
 	case OpAtomicIIncrement:
 	case OpAtomicIDecrement:
@@ -10761,21 +10885,26 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 
 	case OpReportIntersectionNV:
 		statement("reportIntersectionNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");");
+		flush_control_dependent_expressions(current_emitting_block->self);
 		break;
 	case OpIgnoreIntersectionNV:
 		statement("ignoreIntersectionNV();");
+		flush_control_dependent_expressions(current_emitting_block->self);
 		break;
 	case OpTerminateRayNV:
 		statement("terminateRayNV();");
+		flush_control_dependent_expressions(current_emitting_block->self);
 		break;
 	case OpTraceNV:
 		statement("traceNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ",
 		          to_expression(ops[3]), ", ", to_expression(ops[4]), ", ", to_expression(ops[5]), ", ",
 		          to_expression(ops[6]), ", ", to_expression(ops[7]), ", ", to_expression(ops[8]), ", ",
 		          to_expression(ops[9]), ", ", to_expression(ops[10]), ");");
+		flush_control_dependent_expressions(current_emitting_block->self);
 		break;
 	case OpExecuteCallableNV:
 		statement("executeCallableNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");");
+		flush_control_dependent_expressions(current_emitting_block->self);
 		break;
 
 	case OpConvertUToPtr:
@@ -13376,6 +13505,7 @@ void CompilerGLSL::bitcast_from_builtin_load(uint32_t source_id, std::string &ex
 	case BuiltInBaseInstance:
 	case BuiltInDrawIndex:
 	case BuiltInFragStencilRefEXT:
+	case BuiltInInstanceCustomIndexNV:
 		expected_type = SPIRType::Int;
 		break;
 
@@ -13385,6 +13515,9 @@ void CompilerGLSL::bitcast_from_builtin_load(uint32_t source_id, std::string &ex
 	case BuiltInLocalInvocationIndex:
 	case BuiltInWorkgroupSize:
 	case BuiltInNumWorkgroups:
+	case BuiltInIncomingRayFlagsNV:
+	case BuiltInLaunchIdNV:
+	case BuiltInLaunchSizeNV:
 		expected_type = SPIRType::UInt;
 		break;
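
Tying the GLSL-side changes together, here is a sketch of assumed host code (not part of the patch) that compiles for plain GL with nonzero base instance support, exercising the soft-enabled GL_ARB_shader_draw_parameters path and the SPIRV_Cross_BaseInstance fallback above:

```cpp
#include "spirv_glsl.hpp"

#include <cstdint>
#include <string>
#include <vector>

std::string compile_plain_gl(std::vector<uint32_t> spirv_words)
{
	spirv_cross::CompilerGLSL glsl(std::move(spirv_words));

	spirv_cross::CompilerGLSL::Options opts = glsl.get_common_options();
	opts.vulkan_semantics = false; // plain GL: soft-enable the draw-parameters path
	opts.vertex.support_nonzero_base_instance = true;
	glsl.set_common_options(opts);

	// gl_InstanceIndex now lowers to (gl_InstanceID + SPIRV_Cross_BaseInstance),
	// with SPIRV_Cross_BaseInstance #define'd to gl_BaseInstanceARB when available.
	return glsl.compile();
}
```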
 

+ 385 - 122
src/libraries/spirv_cross/spirv_hlsl.cpp

@@ -23,6 +23,41 @@ using namespace spv;
 using namespace SPIRV_CROSS_NAMESPACE;
 using namespace std;
 
+enum class ImageFormatNormalizedState
+{
+	None = 0,
+	Unorm = 1,
+	Snorm = 2
+};
+
+static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
+{
+	switch (fmt)
+	{
+	case ImageFormatR8:
+	case ImageFormatR16:
+	case ImageFormatRg8:
+	case ImageFormatRg16:
+	case ImageFormatRgba8:
+	case ImageFormatRgba16:
+	case ImageFormatRgb10A2:
+		return ImageFormatNormalizedState::Unorm;
+
+	case ImageFormatR8Snorm:
+	case ImageFormatR16Snorm:
+	case ImageFormatRg8Snorm:
+	case ImageFormatRg16Snorm:
+	case ImageFormatRgba8Snorm:
+	case ImageFormatRgba16Snorm:
+		return ImageFormatNormalizedState::Snorm;
+
+	default:
+		break;
+	}
+
+	return ImageFormatNormalizedState::None;
+}
+
 static unsigned image_format_to_components(ImageFormat fmt)
 {
 	switch (fmt)
@@ -395,7 +430,20 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
 		case SPIRType::AtomicCounter:
 			return "atomic_uint";
 		case SPIRType::Half:
-			return "min16float";
+			if (hlsl_options.enable_16bit_types)
+				return "half";
+			else
+				return "min16float";
+		case SPIRType::Short:
+			if (hlsl_options.enable_16bit_types)
+				return "int16_t";
+			else
+				return "min16int";
+		case SPIRType::UShort:
+			if (hlsl_options.enable_16bit_types)
+				return "uint16_t";
+			else
+				return "min16uint";
 		case SPIRType::Float:
 			return "float";
 		case SPIRType::Double:
@@ -423,7 +471,11 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
 		case SPIRType::UInt:
 			return join("uint", type.vecsize);
 		case SPIRType::Half:
-			return join("min16float", type.vecsize);
+			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
+		case SPIRType::Short:
+			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
+		case SPIRType::UShort:
+			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
 		case SPIRType::Float:
 			return join("float", type.vecsize);
 		case SPIRType::Double:
@@ -447,7 +499,11 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
 		case SPIRType::UInt:
 			return join("uint", type.columns, "x", type.vecsize);
 		case SPIRType::Half:
-			return join("min16float", type.columns, "x", type.vecsize);
+			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
+		case SPIRType::Short:
+			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
+		case SPIRType::UShort:
+			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
 		case SPIRType::Float:
 			return join("float", type.columns, "x", type.vecsize);
 		case SPIRType::Double:
@@ -1423,66 +1479,14 @@ void CompilerHLSL::emit_resources()
 		}
 	}
 
-	if (required_textureSizeVariants != 0)
+	emit_texture_size_variants(required_texture_size_variants.srv, "4", false, "");
+	for (uint32_t norm = 0; norm < 3; norm++)
 	{
-		static const char *types[QueryTypeCount] = { "float4", "int4", "uint4" };
-		static const char *dims[QueryDimCount] = { "Texture1D",   "Texture1DArray",  "Texture2D",   "Texture2DArray",
-			                                       "Texture3D",   "Buffer",          "TextureCube", "TextureCubeArray",
-			                                       "Texture2DMS", "Texture2DMSArray" };
-
-		static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
-
-		static const char *ret_types[QueryDimCount] = {
-			"uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
-		};
-
-		static const uint32_t return_arguments[QueryDimCount] = {
-			1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
-		};
-
-		for (uint32_t index = 0; index < QueryDimCount; index++)
+		for (uint32_t comp = 0; comp < 4; comp++)
 		{
-			for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
-			{
-				uint32_t bit = 16 * type_index + index;
-				uint64_t mask = 1ull << bit;
-
-				if ((required_textureSizeVariants & mask) == 0)
-					continue;
-
-				statement(ret_types[index], " SPIRV_Cross_textureSize(", dims[index], "<", types[type_index],
-				          "> Tex, uint Level, out uint Param)");
-				begin_scope();
-				statement(ret_types[index], " ret;");
-				switch (return_arguments[index])
-				{
-				case 1:
-					if (has_lod[index])
-						statement("Tex.GetDimensions(Level, ret.x, Param);");
-					else
-					{
-						statement("Tex.GetDimensions(ret.x);");
-						statement("Param = 0u;");
-					}
-					break;
-				case 2:
-					if (has_lod[index])
-						statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
-					else
-						statement("Tex.GetDimensions(ret.x, ret.y, Param);");
-					break;
-				case 3:
-					if (has_lod[index])
-						statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
-					else
-						statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
-					break;
-				}
-
-				statement("return ret;");
-				end_scope();
-				statement("");
-			}
+			static const char *qualifiers[] = { "", "unorm ", "snorm " };
+			static const char *vecsizes[] = { "", "2", "3", "4" };
+			emit_texture_size_variants(required_texture_size_variants.uav[norm][comp], vecsizes[comp], true, qualifiers[norm]);
 		}
 	}
 
@@ -1845,6 +1849,83 @@ void CompilerHLSL::emit_resources()
 	}
 }
 
+void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav, const char *type_qualifier)
+{
+	if (variant_mask == 0)
+		return;
+
+	static const char *types[QueryTypeCount] = { "float", "int", "uint" };
+	static const char *dims[QueryDimCount] = { "Texture1D",   "Texture1DArray",  "Texture2D",   "Texture2DArray",
+	                                           "Texture3D",   "Buffer",          "TextureCube", "TextureCubeArray",
+	                                           "Texture2DMS", "Texture2DMSArray" };
+
+	static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
+
+	static const char *ret_types[QueryDimCount] = {
+		"uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
+	};
+
+	static const uint32_t return_arguments[QueryDimCount] = {
+		1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
+	};
+
+	for (uint32_t index = 0; index < QueryDimCount; index++)
+	{
+		for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
+		{
+			uint32_t bit = 16 * type_index + index;
+			uint64_t mask = 1ull << bit;
+
+			if ((variant_mask & mask) == 0)
+				continue;
+
+			statement(ret_types[index], " SPIRV_Cross_", (uav ? "image" : "texture"), "Size(", (uav ? "RW" : ""),
+			          dims[index], "<", type_qualifier, types[type_index], vecsize_qualifier,
+					  "> Tex, ", (uav ? "" : "uint Level, "), "out uint Param)");
+			begin_scope();
+			statement(ret_types[index], " ret;");
+			switch (return_arguments[index])
+			{
+			case 1:
+				if (has_lod[index] && !uav)
+					statement("Tex.GetDimensions(Level, ret.x, Param);");
+				else
+				{
+					statement("Tex.GetDimensions(ret.x);");
+					statement("Param = 0u;");
+				}
+				break;
+			case 2:
+				if (has_lod[index] && !uav)
+					statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
+				else if (!uav)
+					statement("Tex.GetDimensions(ret.x, ret.y, Param);");
+				else
+				{
+					statement("Tex.GetDimensions(ret.x, ret.y);");
+					statement("Param = 0u;");
+				}
+				break;
+			case 3:
+				if (has_lod[index] && !uav)
+					statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
+				else if (!uav)
+					statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
+				else
+				{
+					statement("Tex.GetDimensions(ret.x, ret.y, ret.z);");
+					statement("Param = 0u;");
+				}
+				break;
+			}
+
+			statement("return ret;");
+			end_scope();
+			statement("");
+		}
+	}
+}
+
 string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
 {
 	auto &flags = get_member_decoration_bitset(type.self, index);
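
The `variant_mask` parameter keeps the same bit layout as the old `required_textureSizeVariants` field: one bit per (query type, dimension) pair, with 16 dimension slots reserved per type. A standalone sketch of that encoding:

```cpp
#include <cstdint>

// bit = 16 * type_index + dim_index, matching the loop in
// emit_texture_size_variants() and require_texture_query_variant().
constexpr uint64_t texture_size_variant_bit(uint32_t type_index, uint32_t dim_index)
{
	return 1ull << (16 * type_index + dim_index);
}

// uint (type_index 2) Texture2DMSArray (dim_index 9) lands on bit 41.
static_assert(texture_size_variant_bit(2, 9) == (1ull << 41), "encoding check");
```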
@@ -1906,7 +1987,7 @@ void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
 	if (is_uav)
 	{
 		Bitset flags = ir.get_buffer_block_flags(var);
-		bool is_readonly = flags.get(DecorationNonWritable) && !hlsl_options.force_storage_buffer_as_uav;
+		bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
 		bool is_coherent = flags.get(DecorationCoherent) && !is_readonly;
 		bool is_interlocked = interlocked_resources.count(var.self) > 0;
 		const char *type_name = "ByteAddressBuffer ";
@@ -3038,7 +3119,7 @@ string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
 			if (has_decoration(type.self, DecorationBufferBlock))
 			{
 				Bitset flags = ir.get_buffer_block_flags(var);
-				bool is_readonly = flags.get(DecorationNonWritable) && !hlsl_options.force_storage_buffer_as_uav;
+				bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
 				space = is_readonly ? 't' : 'u'; // UAV
 				resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
 			}
@@ -3057,7 +3138,7 @@ string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
 		{
 			// UAV or SRV depending on readonly flag.
 			Bitset flags = ir.get_buffer_block_flags(var);
-			bool is_readonly = flags.get(DecorationNonWritable) && !hlsl_options.force_storage_buffer_as_uav;
+			bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
 			space = is_readonly ? 't' : 'u';
 			resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
 		}
@@ -3587,11 +3668,16 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
 		read_access_chain_struct(lhs, chain);
 		return;
 	}
-	else if (type.width != 32)
-		SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported.");
+	else if (type.width != 32 && !hlsl_options.enable_16bit_types)
+		SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and native 16-bit types are enabled.");
 
+	bool templated_load = hlsl_options.shader_model >= 62;
 	string load_expr;
 
+	string template_expr;
+	if (templated_load)
+		template_expr = join("<", type_to_glsl(type), ">");
+
 	// Load a vector or scalar.
 	if (type.columns == 1 && !chain.row_major_matrix)
 	{
@@ -3614,12 +3700,24 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
 			SPIRV_CROSS_THROW("Unknown vector size.");
 		}
 
-		load_expr = join(chain.base, ".", load_op, "(", chain.dynamic_index, chain.static_index, ")");
+		if (templated_load)
+			load_op = "Load";
+
+		load_expr = join(chain.base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
 	}
 	else if (type.columns == 1)
 	{
 		// Strided load since we are loading a column from a row-major matrix.
-		if (type.vecsize > 1)
+		if (templated_load)
+		{
+			auto scalar_type = type;
+			scalar_type.vecsize = 1;
+			scalar_type.columns = 1;
+			template_expr = join("<", type_to_glsl(scalar_type), ">");
+			if (type.vecsize > 1)
+				load_expr += type_to_glsl(type) + "(";
+		}
+		else if (type.vecsize > 1)
 		{
 			load_expr = type_to_glsl(target_type);
 			load_expr += "(";
@@ -3628,7 +3726,7 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
 		for (uint32_t r = 0; r < type.vecsize; r++)
 		{
 			load_expr +=
-			    join(chain.base, ".Load(", chain.dynamic_index, chain.static_index + r * chain.matrix_stride, ")");
+			    join(chain.base, ".Load", template_expr, "(", chain.dynamic_index, chain.static_index + r * chain.matrix_stride, ")");
 			if (r + 1 < type.vecsize)
 				load_expr += ", ";
 		}
@@ -3658,13 +3756,25 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
 			SPIRV_CROSS_THROW("Unknown vector size.");
 		}
 
-		// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
-		// so row-major is technically column-major ...
-		load_expr = type_to_glsl(target_type);
+		if (templated_load)
+		{
+			auto vector_type = type;
+			vector_type.columns = 1;
+			template_expr = join("<", type_to_glsl(vector_type), ">");
+			load_expr = type_to_glsl(type);
+			load_op = "Load";
+		}
+		else
+		{
+			// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
+			// so row-major is technically column-major ...
+			load_expr = type_to_glsl(target_type);
+		}
 		load_expr += "(";
+
 		for (uint32_t c = 0; c < type.columns; c++)
 		{
-			load_expr += join(chain.base, ".", load_op, "(", chain.dynamic_index,
+			load_expr += join(chain.base, ".", load_op, template_expr, "(", chain.dynamic_index,
 			                  chain.static_index + c * chain.matrix_stride, ")");
 			if (c + 1 < type.columns)
 				load_expr += ", ";
@@ -3676,13 +3786,24 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
 		// Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
 		// considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
 
-		load_expr = type_to_glsl(target_type);
+		if (templated_load)
+		{
+			load_expr = type_to_glsl(type);
+			auto scalar_type = type;
+			scalar_type.vecsize = 1;
+			scalar_type.columns = 1;
+			template_expr = join("<", type_to_glsl(scalar_type), ">");
+		}
+		else
+			load_expr = type_to_glsl(target_type);
+
 		load_expr += "(";
+
 		for (uint32_t c = 0; c < type.columns; c++)
 		{
 			for (uint32_t r = 0; r < type.vecsize; r++)
 			{
-				load_expr += join(chain.base, ".Load(", chain.dynamic_index,
+				load_expr += join(chain.base, ".Load", template_expr, "(", chain.dynamic_index,
 				                  chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
 
 				if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
@@ -3692,9 +3813,12 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
 		load_expr += ")";
 	}
 
-	auto bitcast_op = bitcast_glsl_op(type, target_type);
-	if (!bitcast_op.empty())
-		load_expr = join(bitcast_op, "(", load_expr, ")");
+	if (!templated_load)
+	{
+		auto bitcast_op = bitcast_glsl_op(type, target_type);
+		if (!bitcast_op.empty())
+			load_expr = join(bitcast_op, "(", load_expr, ")");
+	}
 
 	if (lhs.empty())
 	{
@@ -3877,8 +4001,14 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
 		register_write(chain.self);
 		return;
 	}
-	else if (type.width != 32)
-		SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported.");
+	else if (type.width != 32 && !hlsl_options.enable_16bit_types)
+		SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and native 16-bit types are enabled.");
+
+	bool templated_store = hlsl_options.shader_model >= 62;
+
+	string template_expr;
+	if (templated_store)
+		template_expr = join("<", type_to_glsl(type), ">");
 
 	if (type.columns == 1 && !chain.row_major_matrix)
 	{
@@ -3902,13 +4032,27 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
 		}
 
 		auto store_expr = write_access_chain_value(value, composite_chain, false);
-		auto bitcast_op = bitcast_glsl_op(target_type, type);
-		if (!bitcast_op.empty())
-			store_expr = join(bitcast_op, "(", store_expr, ")");
-		statement(chain.base, ".", store_op, "(", chain.dynamic_index, chain.static_index, ", ", store_expr, ");");
+
+		if (!templated_store)
+		{
+			auto bitcast_op = bitcast_glsl_op(target_type, type);
+			if (!bitcast_op.empty())
+				store_expr = join(bitcast_op, "(", store_expr, ")");
+		}
+		else
+			store_op = "Store";
+		statement(chain.base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ", store_expr, ");");
 	}
 	else if (type.columns == 1)
 	{
+		if (templated_store)
+		{
+			auto scalar_type = type;
+			scalar_type.vecsize = 1;
+			scalar_type.columns = 1;
+			template_expr = join("<", type_to_glsl(scalar_type), ">");
+		}
+
 		// Strided store.
 		for (uint32_t r = 0; r < type.vecsize; r++)
 		{
@@ -3920,10 +4064,14 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
 			}
 			remove_duplicate_swizzle(store_expr);
 
-			auto bitcast_op = bitcast_glsl_op(target_type, type);
-			if (!bitcast_op.empty())
-				store_expr = join(bitcast_op, "(", store_expr, ")");
-			statement(chain.base, ".Store(", chain.dynamic_index, chain.static_index + chain.matrix_stride * r, ", ",
+			if (!templated_store)
+			{
+				auto bitcast_op = bitcast_glsl_op(target_type, type);
+				if (!bitcast_op.empty())
+					store_expr = join(bitcast_op, "(", store_expr, ")");
+			}
+
+			statement(chain.base, ".Store", template_expr, "(", chain.dynamic_index, chain.static_index + chain.matrix_stride * r, ", ",
 			          store_expr, ");");
 		}
 	}
@@ -3948,18 +4096,39 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
 			SPIRV_CROSS_THROW("Unknown vector size.");
 		}
 
+		if (templated_store)
+		{
+			store_op = "Store";
+			auto vector_type = type;
+			vector_type.columns = 1;
+			template_expr = join("<", type_to_glsl(vector_type), ">");
+		}
+
 		for (uint32_t c = 0; c < type.columns; c++)
 		{
 			auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
-			auto bitcast_op = bitcast_glsl_op(target_type, type);
-			if (!bitcast_op.empty())
-				store_expr = join(bitcast_op, "(", store_expr, ")");
-			statement(chain.base, ".", store_op, "(", chain.dynamic_index, chain.static_index + c * chain.matrix_stride,
+
+			if (!templated_store)
+			{
+				auto bitcast_op = bitcast_glsl_op(target_type, type);
+				if (!bitcast_op.empty())
+					store_expr = join(bitcast_op, "(", store_expr, ")");
+			}
+
+			statement(chain.base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index + c * chain.matrix_stride,
 			          ", ", store_expr, ");");
 		}
 	}
 	else
 	{
+		if (templated_store)
+		{
+			auto scalar_type = type;
+			scalar_type.vecsize = 1;
+			scalar_type.columns = 1;
+			template_expr = join("<", type_to_glsl(scalar_type), ">");
+		}
+
 		for (uint32_t r = 0; r < type.vecsize; r++)
 		{
 			for (uint32_t c = 0; c < type.columns; c++)
@@ -3970,7 +4139,7 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
 				auto bitcast_op = bitcast_glsl_op(target_type, type);
 				if (!bitcast_op.empty())
 					store_expr = join(bitcast_op, "(", store_expr, ")");
-				statement(chain.base, ".Store(", chain.dynamic_index,
+				statement(chain.base, ".Store", template_expr, "(", chain.dynamic_index,
 				          chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
 			}
 		}
@@ -4087,9 +4256,11 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
 	const char *atomic_op = nullptr;
 
 	string value_expr;
-	if (op != OpAtomicIDecrement && op != OpAtomicIIncrement)
+	if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
 		value_expr = to_expression(ops[op == OpAtomicCompareExchange ? 6 : 5]);
 
+	bool is_atomic_store = false;
+
 	switch (op)
 	{
 	case OpAtomicIIncrement:
@@ -4102,6 +4273,11 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
 		value_expr = "-1";
 		break;
 
+	case OpAtomicLoad:
+		atomic_op = "InterlockedAdd";
+		value_expr = "0";
+		break;
+
 	case OpAtomicISub:
 		atomic_op = "InterlockedAdd";
 		value_expr = join("-", enclose_expression(value_expr));
@@ -4137,6 +4313,11 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
 		atomic_op = "InterlockedExchange";
 		break;
 
+	case OpAtomicStore:
+		atomic_op = "InterlockedExchange";
+		is_atomic_store = true;
+		break;
+
 	case OpAtomicCompareExchange:
 		if (length < 8)
 			SPIRV_CROSS_THROW("Not enough data for opcode.");
@@ -4148,31 +4329,57 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
 		SPIRV_CROSS_THROW("Unknown atomic opcode.");
 	}
 
-	uint32_t result_type = ops[0];
-	uint32_t id = ops[1];
-	forced_temporaries.insert(ops[1]);
+	if (is_atomic_store)
+	{
+		auto &data_type = expression_type(ops[0]);
+		auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
 
-	auto &type = get<SPIRType>(result_type);
-	statement(variable_decl(type, to_name(id)), ";");
+		auto &tmp_id = extra_sub_expressions[ops[0]];
+		if (!tmp_id)
+		{
+			tmp_id = ir.increase_bound_by(1);
+			emit_uninitialized_temporary_expression(get_pointee_type(data_type).self, tmp_id);
+		}
 
-	auto &data_type = expression_type(ops[2]);
-	auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
-	SPIRType::BaseType expr_type;
-	if (data_type.storage == StorageClassImage || !chain)
-	{
-		statement(atomic_op, "(", to_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
-		expr_type = data_type.basetype;
+		if (data_type.storage == StorageClassImage || !chain)
+		{
+			statement(atomic_op, "(", to_expression(ops[0]), ", ", to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
+		}
+		else
+		{
+			// RWByteAddress buffer is always uint in its underlying type.
+			statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", to_expression(ops[3]),
+			          ", ", to_expression(tmp_id), ");");
+		}
 	}
 	else
 	{
-		// RWByteAddress buffer is always uint in its underlying type.
-		expr_type = SPIRType::UInt;
-		statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr, ", ",
-		          to_name(id), ");");
-	}
+		uint32_t result_type = ops[0];
+		uint32_t id = ops[1];
+		forced_temporaries.insert(ops[1]);
+
+		auto &type = get<SPIRType>(result_type);
+		statement(variable_decl(type, to_name(id)), ";");
+
+		auto &data_type = expression_type(ops[2]);
+		auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
+		SPIRType::BaseType expr_type;
+		if (data_type.storage == StorageClassImage || !chain)
+		{
+			statement(atomic_op, "(", to_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
+			expr_type = data_type.basetype;
+		}
+		else
+		{
+			// RWByteAddress buffer is always uint in its underlying type.
+			expr_type = SPIRType::UInt;
+			statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr,
+			          ", ", to_name(id), ");");
+		}
 
-	auto expr = bitcast_expression(type, expr_type, to_name(id));
-	set<SPIRExpression>(id, expr, result_type, true);
+		auto expr = bitcast_expression(type, expr_type, to_name(id));
+		set<SPIRExpression>(id, expr, result_type, true);
+	}
 	flush_all_atomic_capable_variables();
 }
 
@@ -4796,8 +5003,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 		auto result_type = ops[0];
 		auto id = ops[1];
 
-		require_texture_query_variant(expression_type(ops[2]));
-
+		require_texture_query_variant(ops[2]);
 		auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
 		statement("uint ", dummy_samples_levels, ";");
 
@@ -4815,12 +5021,22 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 		auto result_type = ops[0];
 		auto id = ops[1];
 
-		require_texture_query_variant(expression_type(ops[2]));
+		require_texture_query_variant(ops[2]);
+		bool uav = expression_type(ops[2]).image.sampled == 2;
+
+		if (const auto *var = maybe_get_backing_variable(ops[2]))
+			if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
+				uav = false;
 
 		auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
 		statement("uint ", dummy_samples_levels, ";");
 
-		auto expr = join("SPIRV_Cross_textureSize(", to_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
+		string expr;
+		if (uav)
+			expr = join("SPIRV_Cross_imageSize(", to_expression(ops[2]), ", ", dummy_samples_levels, ")");
+		else
+			expr = join("SPIRV_Cross_textureSize(", to_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
+
 		auto &restype = get<SPIRType>(ops[0]);
 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
 		emit_op(result_type, id, expr, true);
@@ -4833,14 +5049,25 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 		auto result_type = ops[0];
 		auto id = ops[1];
 
-		require_texture_query_variant(expression_type(ops[2]));
+		require_texture_query_variant(ops[2]);
+		bool uav = expression_type(ops[2]).image.sampled == 2;
+		if (opcode == OpImageQueryLevels && uav)
+			SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
+
+		if (const auto *var = maybe_get_backing_variable(ops[2]))
+			if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
+				uav = false;
 
 		// Keep it simple and do not emit special variants to make this look nicer ...
 		// This stuff is barely, if ever, used.
 		forced_temporaries.insert(id);
 		auto &type = get<SPIRType>(result_type);
 		statement(variable_decl(type, to_name(id)), ";");
-		statement("SPIRV_Cross_textureSize(", to_expression(ops[2]), ", 0u, ", to_name(id), ");");
+
+		if (uav)
+			statement("SPIRV_Cross_imageSize(", to_expression(ops[2]), ", ", to_name(id), ");");
+		else
+			statement("SPIRV_Cross_textureSize(", to_expression(ops[2]), ", 0u, ", to_name(id), ");");
 
 		auto &restype = get<SPIRType>(ops[0]);
 		auto expr = bitcast_expression(restype, SPIRType::UInt, to_name(id));
@@ -4967,6 +5194,8 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 	case OpAtomicIAdd:
 	case OpAtomicIIncrement:
 	case OpAtomicIDecrement:
+	case OpAtomicLoad:
+	case OpAtomicStore:
 	{
 		emit_atomic(ops, instruction.length, opcode);
 		break;
@@ -5152,8 +5381,16 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 	}
 }
 
-void CompilerHLSL::require_texture_query_variant(const SPIRType &type)
+void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
 {
+	if (const auto *var = maybe_get_backing_variable(var_id))
+		var_id = var->self;
+
+	auto &type = expression_type(var_id);
+	bool uav = type.image.sampled == 2;
+	if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
+		uav = false;
+
 	uint32_t bit = 0;
 	switch (type.image.dim)
 	{
@@ -5202,11 +5439,15 @@ void CompilerHLSL::require_texture_query_variant(const SPIRType &type)
 		SPIRV_CROSS_THROW("Unsupported query type.");
 	}
 
+	auto norm_state = image_format_to_normalized_state(type.image.format);
+	auto &variant = uav ? required_texture_size_variants.uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
+	                required_texture_size_variants.srv;
+
 	uint64_t mask = 1ull << bit;
-	if ((required_textureSizeVariants & mask) == 0)
+	if ((variant & mask) == 0)
 	{
 		force_recompile();
-		required_textureSizeVariants |= mask;
+		variant |= mask;
 	}
 }
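
The bookkeeping above replaces the single required_textureSizeVariants mask: SRV queries still share one 64-bit mask, while UAV queries get one mask per normalized state (3) and component count (4), matching the uint64_t uav[3][4] array declared in the header. A sketch of the registration flow, with a hypothetical format:

// For, say, a Dim2D UAV image with a 4-component unorm format:
//   norm_state      -> index into uav[3]              (none / unorm / snorm)
//   components - 1  -> index into uav[norm_state][4]
//   variant |= 1ull << bit;  // bit encodes the query dimension
// Any newly set bit calls force_recompile(), and the next pass emits the
// matching SPIRV_Cross_imageSize overload for that format class.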
 
@@ -5291,6 +5532,9 @@ void CompilerHLSL::validate_shader_model()
 
 	if (ir.addressing_model != AddressingModelLogical)
 		SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
+
+	if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
+		SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
 }
 
 string CompilerHLSL::compile()
@@ -5413,3 +5657,22 @@ CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, u
 
 	return BitcastType::TypeNormal;
 }
+
+bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
+{
+	if (hlsl_options.force_storage_buffer_as_uav)
+	{
+		return true;
+	}
+
+	const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
+	const uint32_t binding = get_decoration(id, spv::DecorationBinding);
+	
+	return (force_uav_buffer_bindings.find({desc_set, binding}) != force_uav_buffer_bindings.end());
+}
+
+void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
+{
+	SetBindingPair pair = { desc_set, binding };
+	force_uav_buffer_bindings.insert(pair);
+}
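
A minimal usage sketch of the new per-binding override (the set/binding numbers are hypothetical):

CompilerHLSL compiler(std::move(spirv_words));

// Force only the storage buffer at set 0, binding 3 to compile as a UAV,
// instead of flipping force_storage_buffer_as_uav for every buffer.
compiler.set_hlsl_force_storage_buffer_as_uav(0, 3);

std::string hlsl = compiler.compile();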

+ 30 - 2
src/libraries/spirv_cross/spirv_hlsl.hpp

@@ -112,12 +112,18 @@ public:
 
 		// Forces a storage buffer to always be declared as UAV, even if the readonly decoration is used.
 		// By default, a readonly storage buffer will be declared as ByteAddressBuffer (SRV) instead.
+		// Alternatively, use set_hlsl_force_storage_buffer_as_uav to force this on individual bindings.
 		bool force_storage_buffer_as_uav = false;
 
 		// Forces any storage image type marked as NonWritable to be considered an SRV instead.
 		// For this to work with function call parameters, NonWritable must be considered to be part of the type system
 		// so that NonWritable image arguments are also translated to Texture rather than RWTexture.
 		bool nonwritable_uav_texture_as_srv = false;
+
+		// Enables native 16-bit types. Needs SM 6.2.
+		// Uses half/int16_t/uint16_t instead of min16* types.
+		// Also adds support for 16-bit load-store from (RW)ByteAddressBuffer.
+		bool enable_16bit_types = false;
 	};
 
 	explicit CompilerHLSL(std::vector<uint32_t> spirv_)
@@ -186,6 +192,9 @@ public:
 	void add_hlsl_resource_binding(const HLSLResourceBinding &resource);
 	bool is_hlsl_resource_binding_used(spv::ExecutionModel model, uint32_t set, uint32_t binding) const;
 
+	// Controls which storage buffer bindings will be forced to be declared as UAVs.
+	void set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding);
+
 private:
 	std::string type_to_glsl(const SPIRType &type, uint32_t id = 0) override;
 	std::string image_type_hlsl(const SPIRType &type, uint32_t id);
@@ -245,6 +254,8 @@ private:
 	const char *to_storage_qualifiers_glsl(const SPIRVariable &var) override;
 	void replace_illegal_names() override;
 
+	bool is_hlsl_force_storage_buffer_as_uav(ID id) const;
+
 	Options hlsl_options;
 
 	// TODO: Refactor this to be more similar to MSL, maybe have some common system in place?
@@ -264,8 +275,23 @@ private:
 	bool requires_scalar_reflect = false;
 	bool requires_scalar_refract = false;
 	bool requires_scalar_faceforward = false;
-	uint64_t required_textureSizeVariants = 0;
-	void require_texture_query_variant(const SPIRType &type);
+
+	struct TextureSizeVariants
+	{
+		// MSVC 2013 workaround.
+		TextureSizeVariants()
+		{
+			srv = 0;
+			for (auto &unorm : uav)
+				for (auto &u : unorm)
+					u = 0;
+		}
+		uint64_t srv;
+		uint64_t uav[3][4];
+	} required_texture_size_variants;
+
+	void require_texture_query_variant(uint32_t var_id);
+	void emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav, const char *type_qualifier);
 
 	enum TextureQueryVariantDim
 	{
@@ -323,6 +349,8 @@ private:
 
 	std::unordered_map<StageSetBinding, std::pair<HLSLResourceBinding, bool>, InternalHasher> resource_bindings;
 	void remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding);
+
+	std::unordered_set<SetBindingPair, InternalHasher> force_uav_buffer_bindings;
 };
 } // namespace SPIRV_CROSS_NAMESPACE
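
Putting the new options together, a sketch of enabling native 16-bit types (the shader model check added to validate_shader_model above throws below SM 6.2):

CompilerHLSL::Options opts;
opts.shader_model = 62;                      // SM 6.2 is required for native 16-bit types
opts.enable_16bit_types = true;              // half/int16_t/uint16_t instead of min16*
opts.nonwritable_uav_texture_as_srv = true;  // NonWritable storage images become Texture*
compiler.set_hlsl_options(opts);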
 

+ 64 - 13
src/libraries/spirv_cross/spirv_msl.cpp

@@ -3087,11 +3087,11 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32
 	{
 		// If we have an array type, array stride must match exactly with SPIR-V.
 
-		// An exception to this requirement is if we have one array element and a packed decoration.
+		// An exception to this requirement is if we have one array element.
 		// This comes from DX scalar layout workaround.
 		// If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
-		bool relax_array_stride = has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypePacked) &&
-		                          mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
+		// With the logical memory model, the SPIR-V specification requires OpAccessChain accesses to be in-bounds.
+		bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
 
 		if (!relax_array_stride)
 		{
@@ -3137,7 +3137,9 @@ void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t in
 		SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
 
 	// Perform remapping here.
-	set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
+	// There is nothing to be gained by using packed scalars, so don't attempt it.
+	if (!is_scalar(ib_type))
+		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
 
 	// Try validating again, now with packed.
 	if (validate_member_packing_rules_msl(ib_type, index))
@@ -7038,7 +7040,11 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 			exp += string(", ") + get_memory_order(mem_order_2);
 
 		exp += ")";
-		emit_op(result_type, result_id, exp, false);
+
+		if (strcmp(op, "atomic_store_explicit") != 0)
+			emit_op(result_type, result_id, exp, false);
+		else
+			statement(exp, ";");
 	}
 
 	flush_all_atomic_capable_variables();
@@ -8662,7 +8668,7 @@ string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_
 			td_line += ";";
 			add_typedef_line(td_line);
 		}
-		else
+	else if (!is_scalar(physical_type)) // A scalar type is already packed.
 			pack_pfx = "packed_";
 	}
 	else if (row_major)
@@ -9185,7 +9191,14 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 				addr_space = "constant";
 		}
 		else if (!argument)
+		{
 			addr_space = "constant";
+		}
+		else if (type_is_msl_framebuffer_fetch(type))
+		{
+			// Subpass inputs are passed around by value.
+			addr_space = "";
+		}
 		break;
 
 	case StorageClassFunction:
@@ -9626,8 +9639,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
 
 			// Use Metal's native frame-buffer fetch API for subpass inputs.
 			const auto &basetype = get<SPIRType>(var.basetype);
-			if (basetype.image.dim != DimSubpassData || !msl_options.is_ios() ||
-			    !msl_options.ios_use_framebuffer_fetch_subpasses)
+			if (!type_is_msl_framebuffer_fetch(basetype))
 			{
 				ep_args += image_type_glsl(type, var_id) + " " + r.name;
 				if (r.plane > 0)
@@ -9639,7 +9651,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
 			}
 			else
 			{
-				ep_args += image_type_glsl(type, var_id) + "4 " + r.name;
+				ep_args += image_type_glsl(type, var_id) + " " + r.name;
 				ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
 			}
 
@@ -10125,6 +10137,12 @@ uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::Base
 	return resource_index;
 }
 
+bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
+{
+	return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
+	       msl_options.is_ios() && msl_options.ios_use_framebuffer_fetch_subpasses;
+}
+
 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 {
 	auto &var = get<SPIRVariable>(arg.id);
@@ -10140,6 +10158,9 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 		name_id = var.basevariable;
 
 	bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
+	// Framebuffer fetch arguments are plain values; const looks out of place there, but it is not wrong.
+	if (type_is_msl_framebuffer_fetch(type))
+		constref = false;
 
 	bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
 	                     type.basetype == SPIRType::Sampler;
@@ -10638,6 +10659,9 @@ void CompilerMSL::replace_illegal_names()
 
 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
 {
+	if (index < uint32_t(type.member_type_index_redirection.size()))
+		index = type.member_type_index_redirection[index];
+
 	auto *var = maybe_get<SPIRVariable>(base);
 	// If this is a buffer array, we have to dereference the buffer pointers.
 	// Otherwise, if this is a pointer expression, dereference it.
@@ -10983,10 +11007,11 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
 			}
 
 			// Use Metal's native frame-buffer fetch API for subpass inputs.
-			if (img_type.dim == DimSubpassData && msl_options.is_ios() &&
-			    msl_options.ios_use_framebuffer_fetch_subpasses)
+			if (type_is_msl_framebuffer_fetch(type))
 			{
-				return type_to_glsl(get<SPIRType>(img_type.type));
+				auto img_type_4 = get<SPIRType>(img_type.type);
+				img_type_4.vecsize = 4;
+				return type_to_glsl(img_type_4);
 			}
 			if (img_type.ms && img_type.arrayed)
 			{
@@ -12437,9 +12462,28 @@ void CompilerMSL::MemberSorter::sort()
 	// the members should be reordered, based on builtin and sorting aspect meta info.
 	size_t mbr_cnt = type.member_types.size();
 	SmallVector<uint32_t> mbr_idxs(mbr_cnt);
-	iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
+	std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
 	std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
 
+	bool sort_is_identity = true;
+	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
+	{
+		if (mbr_idx != mbr_idxs[mbr_idx])
+		{
+			sort_is_identity = false;
+			break;
+		}
+	}
+
+	if (sort_is_identity)
+		return;
+
+	if (meta.members.size() < type.member_types.size())
+	{
+		// This should never trigger in normal circumstances, but resize defensively to be safe.
+		meta.members.resize(type.member_types.size());
+	}
+
 	// Move type and meta member info to the order defined by the sorted member indices.
 	// This is done by creating temporary copies of both member types and meta, and then
 	// copying back to the original content at the sorted indices.
@@ -12450,6 +12494,13 @@ void CompilerMSL::MemberSorter::sort()
 		type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
 		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
 	}
+
+	if (sort_aspect == SortAspect::Offset)
+	{
+		// If we're sorting by Offset, this might affect user code which accesses a buffer block.
+		// We will need to redirect member indices from one index to sorted index.
+		// We will need to redirect member indices from the original index to the sorted index.
+	}
 }
 
 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
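
The redirection table recorded above feeds the to_member_reference change earlier in this file: SPIR-V member indices keep their original meaning while the MSL struct is laid out in offset order. An illustration with hypothetical members:

// SPIR-V declaration order:            MSL struct sorted by Offset:
//   member 0: float b; // Offset 8       struct S { float a; float b; };
//   member 1: float a; // Offset 0
//
// member_type_index_redirection == { 1, 0 }, so an access chain selecting
// SPIR-V member 0 ("b") is emitted as s.b, the second MSL member.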

+ 4 - 6
src/libraries/spirv_cross/spirv_msl.hpp

@@ -43,10 +43,6 @@ enum MSLVertexFormat
 struct MSLVertexAttr
 {
 	uint32_t location = 0;
-	uint32_t msl_buffer = 0;
-	uint32_t msl_offset = 0;
-	uint32_t msl_stride = 0;
-	bool per_instance = false;
 	MSLVertexFormat format = MSL_VERTEX_FORMAT_OTHER;
 	spv::BuiltIn builtin = spv::BuiltInMax;
 };
@@ -324,12 +320,12 @@ public:
 		// can be read in subsequent stages.
 		bool enable_clip_distance_user_varying = true;
 
-		bool is_ios()
+		bool is_ios() const
 		{
 			return platform == iOS;
 		}
 
-		bool is_macos()
+		bool is_macos() const
 		{
 			return platform == macOS;
 		}
@@ -898,6 +894,8 @@ protected:
 
 	void activate_argument_buffer_resources();
 
+	bool type_is_msl_framebuffer_fetch(const SPIRType &type) const;
+
 	// OpcodeHandler that handles several MSL preprocessing operations.
 	struct OpCodePreprocessor : OpcodeHandler
 	{
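
With type_is_msl_framebuffer_fetch factored out (and is_ios()/is_macos() made const so it can be called from const members), the iOS subpass-input path consistently passes the attachment by value. A sketch of the entry point this produces, with hypothetical names:

// ios_use_framebuffer_fetch_subpasses: the subpass input arrives as a
// plain float4 tagged [[color(n)]] instead of a texture argument.
fragment float4 frag_main(float4 uSubpassInput [[color(0)]])
{
    return uSubpassInput; // reads the framebuffer value for this pixel
}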

+ 24 - 0
src/libraries/spirv_cross/spirv_parser.cpp

@@ -119,6 +119,16 @@ void Parser::parse()
 	for (auto &i : instructions)
 		parse(i);
 
+	for (auto &fixup : forward_pointer_fixups)
+	{
+		auto &target = get<SPIRType>(fixup.first);
+		auto &source = get<SPIRType>(fixup.second);
+		target.member_types = source.member_types;
+		target.basetype = source.basetype;
+		target.self = source.self;
+	}
+	forward_pointer_fixups.clear();
+
 	if (current_function)
 		SPIRV_CROSS_THROW("Function was not terminated.");
 	if (current_block)
@@ -543,6 +553,11 @@ void Parser::parse(const Instruction &instruction)
 		auto *c = maybe_get<SPIRConstant>(cid);
 		bool literal = c && !c->specialization;
 
+		// We're copying type information into Array types, so we'll need a fixup for any physical pointer
+		// references.
+		if (base.forward_pointer)
+			forward_pointer_fixups.push_back({ id, tid });
+
 		arraybase.array_size_literal.push_back(literal);
 		arraybase.array.push_back(literal ? c->scalar() : cid);
 		// Do NOT set arraybase.self!
@@ -556,6 +571,11 @@ void Parser::parse(const Instruction &instruction)
 		auto &base = get<SPIRType>(ops[1]);
 		auto &arraybase = set<SPIRType>(id);
 
+		// We're copying type information into Array types, so we'll need a fixup for any physical pointer
+		// references.
+		if (base.forward_pointer)
+			forward_pointer_fixups.push_back({ id, ops[1] });
+
 		arraybase = base;
 		arraybase.array.push_back(0);
 		arraybase.array_size_literal.push_back(true);
@@ -614,6 +634,9 @@ void Parser::parse(const Instruction &instruction)
 		if (ptrbase.storage == StorageClassAtomicCounter)
 			ptrbase.basetype = SPIRType::AtomicCounter;
 
+		if (base.forward_pointer)
+			forward_pointer_fixups.push_back({ id, ops[2] });
+
 		ptrbase.parent_type = ops[2];
 
 		// Do NOT set ptrbase.self!
@@ -627,6 +650,7 @@ void Parser::parse(const Instruction &instruction)
 		ptrbase.pointer = true;
 		ptrbase.pointer_depth++;
 		ptrbase.storage = static_cast<StorageClass>(ops[1]);
+		ptrbase.forward_pointer = true;
 
 		if (ptrbase.storage == StorageClassAtomicCounter)
 			ptrbase.basetype = SPIRType::AtomicCounter;
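
The fixup list exists because OpTypeForwardPointer lets a pointer type be referenced before its pointee struct is defined; copying such a type into an array or pointer snapshots an incomplete SPIRType. A sketch of the instruction order that triggers it (IDs hypothetical):

// %fwd is usable before its definition completes:
//   OpTypeForwardPointer %fwd PhysicalStorageBuffer
//   %arr    = OpTypeArray %fwd %len           ; copies %fwd while incomplete
//   %struct = OpTypeStruct %uint %v4float
//   %fwd    = OpTypePointer PhysicalStorageBuffer %struct
// After parsing, the fixup loop copies member_types/basetype/self from the
// now-complete %fwd into %arr.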

+ 1 - 0
src/libraries/spirv_cross/spirv_parser.hpp

@@ -84,6 +84,7 @@ private:
 
 	// This must be an ordered data structure so we always pick the same type aliases.
 	SmallVector<uint32_t> global_struct_cache;
+	SmallVector<std::pair<uint32_t, uint32_t>> forward_pointer_fixups;
 
 	bool types_are_logically_equivalent(const SPIRType &a, const SPIRType &b) const;
 	bool variable_storage_is_aliased(const SPIRVariable &v) const;

+ 78 - 23
src/libraries/spirv_cross/spirv_reflect.cpp

@@ -277,23 +277,54 @@ string CompilerReflection::compile()
 	return json_stream->str();
 }
 
+static bool naturally_emit_type(const SPIRType &type)
+{
+	return type.basetype == SPIRType::Struct && !type.pointer && type.array.empty();
+}
+
+bool CompilerReflection::type_is_reference(const SPIRType &type) const
+{
+	// Physical pointers and arrays of physical pointers need to refer to the pointee's type.
+	return type_is_top_level_physical_pointer(type) ||
+	       (!type.array.empty() && type_is_top_level_physical_pointer(get<SPIRType>(type.parent_type)));
+}
+
 void CompilerReflection::emit_types()
 {
 	bool emitted_open_tag = false;
 
-	ir.for_each_typed_id<SPIRType>([&](uint32_t, SPIRType &type) {
-		if (type.basetype == SPIRType::Struct && !type.pointer && type.array.empty())
-			emit_type(type, emitted_open_tag);
+	SmallVector<uint32_t> physical_pointee_types;
+
+	// If we have physical pointers or arrays of physical pointers, it's also helpful to emit the pointee type
+	// and chain the type hierarchy. For POD, arrays can emit the entire type in-place.
+	ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &type) {
+		if (naturally_emit_type(type))
+		{
+			emit_type(self, emitted_open_tag);
+		}
+		else if (type_is_reference(type))
+		{
+			if (!naturally_emit_type(this->get<SPIRType>(type.parent_type)) &&
+			    find(physical_pointee_types.begin(), physical_pointee_types.end(),
+			         type.parent_type) == physical_pointee_types.end())
+			{
+				physical_pointee_types.push_back(type.parent_type);
+			}
+		}
 	});
 
+	for (uint32_t pointee_type : physical_pointee_types)
+		emit_type(pointee_type, emitted_open_tag);
+
 	if (emitted_open_tag)
 	{
 		json_stream->end_json_object();
 	}
 }
 
-void CompilerReflection::emit_type(const SPIRType &type, bool &emitted_open_tag)
+void CompilerReflection::emit_type(uint32_t type_id, bool &emitted_open_tag)
 {
+	auto &type = get<SPIRType>(type_id);
 	auto name = type_to_glsl(type);
 
 	if (type.type_alias != TypeID(0))
@@ -304,26 +335,42 @@ void CompilerReflection::emit_type(const SPIRType &type, bool &emitted_open_tag)
 		json_stream->emit_json_key_object("types");
 		emitted_open_tag = true;
 	}
-	json_stream->emit_json_key_object("_" + std::to_string(type.self));
+	json_stream->emit_json_key_object("_" + std::to_string(type_id));
 	json_stream->emit_json_key_value("name", name);
-	json_stream->emit_json_key_array("members");
-	// FIXME ideally we'd like to emit the size of a structure as a
-	// convenience to people parsing the reflected JSON.  The problem
-	// is that there's no implicit size for a type.  It's final size
-	// will be determined by the top level declaration in which it's
-	// included.  So there might be one size for the struct if it's
-	// included in a std140 uniform block and another if it's included
-	// in a std430 uniform block.
-	// The solution is to include *all* potential sizes as a map of
-	// layout type name to integer, but that will probably require
-	// some additional logic being written in this class, or in the
-	// parent CompilerGLSL class.
-	auto size = type.member_types.size();
-	for (uint32_t i = 0; i < size; ++i)
+
+	if (type_is_top_level_physical_pointer(type))
 	{
-		emit_type_member(type, i);
+		json_stream->emit_json_key_value("type", "_" + std::to_string(type.parent_type));
+		json_stream->emit_json_key_value("physical_pointer", true);
 	}
-	json_stream->end_json_array();
+	else if (!type.array.empty())
+	{
+		emit_type_array(type);
+		json_stream->emit_json_key_value("type", "_" + std::to_string(type.parent_type));
+		json_stream->emit_json_key_value("array_stride", get_decoration(type_id, DecorationArrayStride));
+	}
+	else
+	{
+		json_stream->emit_json_key_array("members");
+		// FIXME ideally we'd like to emit the size of a structure as a
+		// convenience to people parsing the reflected JSON.  The problem
+	// is that there's no implicit size for a type.  Its final size
+		// will be determined by the top level declaration in which it's
+		// included.  So there might be one size for the struct if it's
+		// included in a std140 uniform block and another if it's included
+		// in a std430 uniform block.
+		// The solution is to include *all* potential sizes as a map of
+		// layout type name to integer, but that will probably require
+		// some additional logic being written in this class, or in the
+		// parent CompilerGLSL class.
+		auto size = type.member_types.size();
+		for (uint32_t i = 0; i < size; ++i)
+		{
+			emit_type_member(type, i);
+		}
+		json_stream->end_json_array();
+	}
+
 	json_stream->end_json_object();
 }
 
@@ -335,7 +382,12 @@ void CompilerReflection::emit_type_member(const SPIRType &type, uint32_t index)
 	// FIXME we'd like to emit the offset of each member, but such offsets are
 	// context dependent.  See the comment above regarding structure sizes
 	json_stream->emit_json_key_value("name", name);
-	if (membertype.basetype == SPIRType::Struct)
+
+	if (type_is_reference(membertype))
+	{
+		json_stream->emit_json_key_value("type", "_" + std::to_string(membertype.parent_type));
+	}
+	else if (membertype.basetype == SPIRType::Struct)
 	{
 		json_stream->emit_json_key_value("type", "_" + std::to_string(membertype.self));
 	}
@@ -349,7 +401,7 @@ void CompilerReflection::emit_type_member(const SPIRType &type, uint32_t index)
 
 void CompilerReflection::emit_type_array(const SPIRType &type)
 {
-	if (!type.array.empty())
+	if (!type_is_top_level_physical_pointer(type) && !type.array.empty())
 	{
 		json_stream->emit_json_key_array("array");
 		// Note that we emit the zeros here as a means of identifying
@@ -388,6 +440,9 @@ void CompilerReflection::emit_type_member_qualifiers(const SPIRType &type, uint3
 			json_stream->emit_json_key_value("matrix_stride", dec.matrix_stride);
 		if (dec.decoration_flags.get(DecorationRowMajor))
 			json_stream->emit_json_key_value("row_major", true);
+
+		if (type_is_top_level_physical_pointer(membertype))
+			json_stream->emit_json_key_value("physical_pointer", true);
 	}
 }
 

+ 2 - 1
src/libraries/spirv_cross/spirv_reflect.hpp

@@ -67,11 +67,12 @@ private:
 	void emit_resources();
 	void emit_specialization_constants();
 
-	void emit_type(const SPIRType &type, bool &emitted_open_tag);
+	void emit_type(uint32_t type_id, bool &emitted_open_tag);
 	void emit_type_member(const SPIRType &type, uint32_t index);
 	void emit_type_member_qualifiers(const SPIRType &type, uint32_t index);
 	void emit_type_array(const SPIRType &type);
 	void emit_resources(const char *tag, const SmallVector<Resource> &resources);
+	bool type_is_reference(const SPIRType &type) const;
 
 	std::string to_member_name(const SPIRType &type, uint32_t index) const;