Explorar el Código

Updated spirv-cross.

Бранимир Караџић hace 1 año
padre
commit
ec4220ae44

+ 5 - 0
3rdparty/spirv-cross/spirv_cross.cpp

@@ -1850,6 +1850,11 @@ const SmallVector<SPIRBlock::Case> &Compiler::get_case_list(const SPIRBlock &blo
 		const auto &type = get<SPIRType>(constant->constant_type);
 		width = type.width;
 	}
+	else if (const auto *op = maybe_get<SPIRConstantOp>(block.condition))
+	{
+		const auto &type = get<SPIRType>(op->basetype);
+		width = type.width;
+	}
 	else if (const auto *var = maybe_get<SPIRVariable>(block.condition))
 	{
 		const auto &type = get<SPIRType>(var->basetype);

+ 48 - 0
3rdparty/spirv-cross/spirv_cross_c.cpp

@@ -516,6 +516,10 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c
 	case SPVC_COMPILER_OPTION_HLSL_FLATTEN_MATRIX_VERTEX_INPUT_SEMANTICS:
 		options->hlsl.flatten_matrix_vertex_input_semantics = value != 0;
 		break;
+
+	case SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME:
+		options->hlsl.use_entry_point_name = value != 0;
+		break;
 #endif
 
 #if SPIRV_CROSS_C_API_MSL
@@ -1355,6 +1359,34 @@ spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler,
 #endif
 }
 
+// Registers one MSL resource binding, described by the C struct *binding, with the MSL compiler.
+// Returns SPVC_ERROR_INVALID_ARGUMENT (after reporting an error on the compiler's context) when the
+// compiler's backend is not SPVC_BACKEND_MSL or when SPIRV_CROSS_C_API_MSL is not enabled;
+// returns SPVC_SUCCESS otherwise.
+spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler,
+                                                     const spvc_msl_resource_binding_2 *binding)
+{
+#if SPIRV_CROSS_C_API_MSL
+	if (compiler->backend != SPVC_BACKEND_MSL)
+	{
+		compiler->context->report_error("MSL function used on a non-MSL backend.");
+		return SPVC_ERROR_INVALID_ARGUMENT;
+	}
+
+	auto &msl = *static_cast<CompilerMSL *>(compiler->compiler.get());
+	// Copy the C struct field by field into the C++ MSLResourceBinding.
+	MSLResourceBinding bind;
+	bind.binding = binding->binding;
+	bind.desc_set = binding->desc_set;
+	bind.stage = static_cast<spv::ExecutionModel>(binding->stage);
+	bind.msl_buffer = binding->msl_buffer;
+	bind.msl_texture = binding->msl_texture;
+	bind.msl_sampler = binding->msl_sampler;
+	bind.count = binding->count;
+	msl.add_msl_resource_binding(bind);
+	return SPVC_SUCCESS;
+#else
+	// MSL support is compiled out: binding is unused here, so cast it to void.
+	(void)binding;
+	compiler->context->report_error("MSL function used on a non-MSL backend.");
+	return SPVC_ERROR_INVALID_ARGUMENT;
+#endif
+}
+
 spvc_result spvc_compiler_msl_add_dynamic_buffer(spvc_compiler compiler, unsigned desc_set, unsigned binding, unsigned index)
 {
 #if SPIRV_CROSS_C_API_MSL
@@ -2811,6 +2843,22 @@ void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding)
 #endif
 }
 
+// Fills *binding with default values.
+// With SPIRV_CROSS_C_API_MSL enabled, every field except count is copied from a
+// default-constructed MSLResourceBinding; otherwise the whole struct is zeroed.
+void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding)
+{
+#if SPIRV_CROSS_C_API_MSL
+	MSLResourceBinding binding_default;
+	binding->desc_set = binding_default.desc_set;
+	binding->binding = binding_default.binding;
+	binding->msl_buffer = binding_default.msl_buffer;
+	binding->msl_texture = binding_default.msl_texture;
+	binding->msl_sampler = binding_default.msl_sampler;
+	binding->stage = static_cast<SpvExecutionModel>(binding_default.stage);
+	// NOTE(review): count is a literal 0 rather than binding_default.count — confirm the two agree.
+	binding->count = 0;
+#else
+	memset(binding, 0, sizeof(*binding));
+#endif
+}
+
 void spvc_hlsl_resource_binding_init(spvc_hlsl_resource_binding *binding)
 {
 #if SPIRV_CROSS_C_API_HLSL

+ 21 - 2
3rdparty/spirv-cross/spirv_cross_c.h

@@ -40,7 +40,7 @@ extern "C" {
 /* Bumped if ABI or API breaks backwards compatibility. */
 #define SPVC_C_API_VERSION_MAJOR 0
 /* Bumped if APIs or enumerations are added in a backwards compatible way. */
-#define SPVC_C_API_VERSION_MINOR 60
+#define SPVC_C_API_VERSION_MINOR 62
 /* Bumped if internal implementation details change. */
 #define SPVC_C_API_VERSION_PATCH 0
 
@@ -380,7 +380,8 @@ typedef struct spvc_msl_shader_interface_var_2
  */
 SPVC_PUBLIC_API void spvc_msl_shader_interface_var_init_2(spvc_msl_shader_interface_var_2 *var);
 
-/* Maps to C++ API. */
+/* Maps to C++ API.
+ * Deprecated. Use spvc_msl_resource_binding_2. */
 typedef struct spvc_msl_resource_binding
 {
 	SpvExecutionModel stage;
@@ -391,11 +392,24 @@ typedef struct spvc_msl_resource_binding
 	unsigned msl_sampler;
 } spvc_msl_resource_binding;
 
+/* Resource binding description taken by spvc_compiler_msl_add_resource_binding_2().
+ * spvc_msl_resource_binding_init_2() fills in the default values. */
+typedef struct spvc_msl_resource_binding_2
+{
+	SpvExecutionModel stage;
+	unsigned desc_set;
+	unsigned binding;
+	unsigned count;
+	unsigned msl_buffer;
+	unsigned msl_texture;
+	unsigned msl_sampler;
+} spvc_msl_resource_binding_2;
+
 /*
  * Initializes the resource binding struct.
  * The defaults are non-zero.
+ * Deprecated: Use spvc_msl_resource_binding_init_2.
  */
 SPVC_PUBLIC_API void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding);
+SPVC_PUBLIC_API void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding);
 
 #define SPVC_MSL_PUSH_CONSTANT_DESC_SET (~(0u))
 #define SPVC_MSL_PUSH_CONSTANT_BINDING (0)
@@ -730,6 +744,8 @@ typedef enum spvc_compiler_option
 	SPVC_COMPILER_OPTION_MSL_AGX_MANUAL_CUBE_GRAD_FIXUP = 88 | SPVC_COMPILER_OPTION_MSL_BIT,
 	SPVC_COMPILER_OPTION_MSL_FORCE_FRAGMENT_WITH_SIDE_EFFECTS_EXECUTION = 89 | SPVC_COMPILER_OPTION_MSL_BIT,
 
+	SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME = 90 | SPVC_COMPILER_OPTION_HLSL_BIT,
+
 	SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff
 } spvc_compiler_option;
 
@@ -836,8 +852,11 @@ SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_patch_output_buffer(spvc_compi
 SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_input_threadgroup_mem(spvc_compiler compiler);
 SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_vertex_attribute(spvc_compiler compiler,
                                                                    const spvc_msl_vertex_attribute *attrs);
+/* Deprecated; use spvc_compiler_msl_add_resource_binding_2(). */
 SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler,
                                                                    const spvc_msl_resource_binding *binding);
+SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler,
+                                                                     const spvc_msl_resource_binding_2 *binding);
 /* Deprecated; use spvc_compiler_msl_add_shader_input_2(). */
 SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_shader_input(spvc_compiler compiler,
                                                                const spvc_msl_shader_interface_var *input);

+ 2 - 0
3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp

@@ -783,6 +783,8 @@ uint32_t ParsedIR::get_member_decoration(TypeID id, uint32_t index, Decoration d
 		return dec.stream;
 	case DecorationSpecId:
 		return dec.spec_id;
+	case DecorationMatrixStride:
+		return dec.matrix_stride;
 	case DecorationIndex:
 		return dec.index;
 	default:

+ 63 - 4
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_
 string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read)
 {
 	auto &type = expression_type(id);
-	if (type.pointer && should_dereference(id))
+
+	if (is_pointer(type) && should_dereference(id))
 		return dereference_expression(type, to_enclosed_expression(id, register_expression_read));
 	else
 		return to_expression(id, register_expression_read);
@@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expre
 string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read)
 {
 	auto &type = expression_type(id);
-	if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
+	if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
 		return address_of_expression(to_enclosed_expression(id, register_expression_read));
 	else
 		return to_unpacked_expression(id, register_expression_read);
@@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression
 string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read)
 {
 	auto &type = expression_type(id);
-	if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
+	if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
 		return address_of_expression(to_enclosed_expression(id, register_expression_read));
 	else
 		return to_enclosed_unpacked_expression(id, register_expression_read);
@@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 			}
 			else
 			{
-				append_index(index, is_literal, true);
+				if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
+				{
+					SPIRType tmp_type(OpTypeInt);
+					tmp_type.basetype = SPIRType::UInt64;
+					tmp_type.width = 64;
+					tmp_type.vecsize = 1;
+					tmp_type.columns = 1;
+
+					TypeID ptr_type_id = expression_type_id(base);
+					const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
+					const SPIRType &pointee_type = get_pointee_type(ptr_type);
+
+					// This only runs in native pointer backends.
+					// Can replace reinterpret_cast with a backend string if ever needed.
+					// We expect this to count as a de-reference.
+					// This leaks some MSL details, but feels slightly overkill to
+					// add yet another virtual interface just for this.
+					auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
+					intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
+					                    get_decoration(ptr_type_id, DecorationArrayStride));
+
+					if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
+					{
+						is_packed = true;
+						expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
+						            " *>(", intptr_expr, ")");
+					}
+					else
+					{
+						expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
+					}
+				}
+				else
+					append_index(index, is_literal, true);
 			}
 
 			if (type->basetype == SPIRType::ControlPointArray)
@@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
 	return ret;
 }
 
+// Base implementation: never returns a stride, it only raises an error via SPIRV_CROSS_THROW.
+// The type argument is unnamed because it is unused here.
+// NOTE(review): backends with native pointer support are presumably expected to override this.
+uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
+{
+	SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
+}
+
 string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
                                   AccessChainMeta *meta, bool ptr_chain)
 {
@@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
 	{
 		AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
 		if (ptr_chain)
+		{
 			flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
+			// PtrAccessChain could get complicated.
+			TypeID type_id = expression_type_id(base);
+			if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
+			{
+				// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
+				// Using packed hacks only gets us so far, and is not designed to deal with pointer to
+				// random values. It works for structs though.
+				auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
+				uint32_t physical_stride = get_physical_type_stride(pointee_type);
+				uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
+				if (physical_stride != requested_stride)
+				{
+					flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
+					if (is_vector(pointee_type))
+						flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
+				}
+			}
+		}
+
 		return access_chain_internal(base, indices, count, flags, meta);
 	}
 }

+ 7 - 1
3rdparty/spirv-cross/spirv_glsl.hpp

@@ -66,7 +66,9 @@ enum AccessChainFlagBits
 	ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
 	ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
 	ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
-	ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
+	ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
+	ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
+	ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
 };
 typedef uint32_t AccessChainFlags;
 
@@ -753,6 +755,10 @@ protected:
 	std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
 	                                  AccessChainMeta *meta);
 
+	// Only meaningful on backends with physical pointer support ala MSL.
+	// Relevant for PtrAccessChain / BDA.
+	virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
+
 	spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
 	virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
 

+ 48 - 9
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -849,9 +849,23 @@ void CompilerHLSL::emit_builtin_inputs_in_struct()
 		case BuiltInSubgroupLeMask:
 		case BuiltInSubgroupGtMask:
 		case BuiltInSubgroupGeMask:
+			// Handled specially.
+			break;
+
 		case BuiltInBaseVertex:
+			if (hlsl_options.shader_model >= 68)
+			{
+				type = "uint";
+				semantic = "SV_StartVertexLocation";
+			}
+			break;
+
 		case BuiltInBaseInstance:
-			// Handled specially.
+			if (hlsl_options.shader_model >= 68)
+			{
+				type = "uint";
+				semantic = "SV_StartInstanceLocation";
+			}
 			break;
 
 		case BuiltInHelperInvocation:
@@ -1231,7 +1245,7 @@ void CompilerHLSL::emit_builtin_variables()
 		case BuiltInVertexIndex:
 		case BuiltInInstanceIndex:
 			type = "int";
-			if (hlsl_options.support_nonzero_base_vertex_base_instance)
+			if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
 				base_vertex_info.used = true;
 			break;
 
@@ -1353,7 +1367,7 @@ void CompilerHLSL::emit_builtin_variables()
 		}
 	});
 
-	if (base_vertex_info.used)
+	if (base_vertex_info.used && hlsl_options.shader_model < 68)
 	{
 		string binding_info;
 		if (base_vertex_info.explicit_binding)
@@ -3136,23 +3150,39 @@ void CompilerHLSL::emit_hlsl_entry_point()
 		case BuiltInVertexIndex:
 		case BuiltInInstanceIndex:
 			// D3D semantics are uint, but shader wants int.
-			if (hlsl_options.support_nonzero_base_vertex_base_instance)
+			if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
 			{
-				if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
-					statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
+				if (hlsl_options.shader_model >= 68)
+				{
+					if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
+						statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseInstanceARB);");
+					else
+						statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseVertexARB);");
+				}
 				else
-					statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
+				{
+					if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
+						statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
+					else
+						statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
+				}
 			}
 			else
 				statement(builtin, " = int(stage_input.", builtin, ");");
 			break;
 
 		case BuiltInBaseVertex:
-			statement(builtin, " = SPIRV_Cross_BaseVertex;");
+			if (hlsl_options.shader_model >= 68)
+				statement(builtin, " = stage_input.gl_BaseVertexARB;");
+			else
+				statement(builtin, " = SPIRV_Cross_BaseVertex;");
 			break;
 
 		case BuiltInBaseInstance:
-			statement(builtin, " = SPIRV_Cross_BaseInstance;");
+			if (hlsl_options.shader_model >= 68)
+				statement(builtin, " = stage_input.gl_BaseInstanceARB;");
+			else
+				statement(builtin, " = SPIRV_Cross_BaseInstance;");
 			break;
 
 		case BuiltInInstanceId:
@@ -6714,6 +6744,15 @@ string CompilerHLSL::compile()
 	if (need_subpass_input)
 		active_input_builtins.set(BuiltInFragCoord);
 
+	// Need to offset by BaseVertex/BaseInstance in SM 6.8+.
+	if (hlsl_options.shader_model >= 68)
+	{
+		if (active_input_builtins.get(BuiltInVertexIndex))
+			active_input_builtins.set(BuiltInBaseVertex);
+		if (active_input_builtins.get(BuiltInInstanceIndex))
+			active_input_builtins.set(BuiltInBaseInstance);
+	}
+
 	uint32_t pass_count = 0;
 	do
 	{

+ 164 - 88
3rdparty/spirv-cross/spirv_msl.cpp

@@ -1361,14 +1361,14 @@ void CompilerMSL::emit_entry_point_declarations()
 
 		if (is_array(type))
 		{
-			if (!type.array[type.array.size() - 1])
-				SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
-
 			is_using_builtin_array = true;
 			statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, true), name,
 			          type_to_array_glsl(type, var_id), " =");
 
-			uint32_t array_size = to_array_size_literal(type);
+			uint32_t array_size = get_resource_array_size(type, var_id);
+			if (array_size == 0)
+				SPIRV_CROSS_THROW("Size of runtime array with dynamic offset could not be determined from resource bindings.");
+
 			begin_scope();
 
 			for (uint32_t i = 0; i < array_size; i++)
@@ -1576,8 +1576,7 @@ string CompilerMSL::compile()
 	preprocess_op_codes();
 	build_implicit_builtins();
 
-	if (needs_manual_helper_invocation_updates() &&
-	    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
+	if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
 	{
 		string builtin_helper_invocation = builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput);
 		string discard_expr = join(builtin_helper_invocation, " = true, discard_fragment()");
@@ -1721,7 +1720,7 @@ void CompilerMSL::preprocess_op_codes()
 	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
 	                          (need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses))))
 		needs_sample_id = true;
-	if (preproc.needs_helper_invocation)
+	if (preproc.needs_helper_invocation || active_input_builtins.get(BuiltInHelperInvocation))
 		needs_helper_invocation = true;
 
 	// OpKill is removed by the parser, so we need to identify those by inspecting
@@ -2058,8 +2057,7 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 			}
 
 			case OpDemoteToHelperInvocation:
-				if (needs_manual_helper_invocation_updates() &&
-				    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
+				if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
 					added_arg_ids.insert(builtin_helper_invocation_id);
 				break;
 
@@ -2112,7 +2110,7 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 			}
 
 			if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill &&
-			    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
+			    needs_helper_invocation)
 				added_arg_ids.insert(builtin_helper_invocation_id);
 
 			// TODO: Add all other operations which can affect memory.
@@ -4803,7 +4801,7 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32
 			return false;
 	}
 
-	if (!mbr_type.array.empty())
+	if (is_array(mbr_type))
 	{
 		// If we have an array type, array stride must match exactly with SPIR-V.
 
@@ -5615,6 +5613,10 @@ void CompilerMSL::emit_custom_templates()
 // otherwise they will cause problems when linked together in a single Metallib.
 void CompilerMSL::emit_custom_functions()
 {
+	// Use when outputting overloaded functions to cover different address spaces.
+	static const char *texture_addr_spaces[] = { "device", "constant", "thread" };
+	static uint32_t texture_addr_space_count = sizeof(texture_addr_spaces) / sizeof(char*);
+
 	if (spv_function_implementations.count(SPVFuncImplArrayCopyMultidim))
 		spv_function_implementations.insert(SPVFuncImplArrayCopy);
 
@@ -6264,54 +6266,62 @@ void CompilerMSL::emit_custom_functions()
 			break;
 
 		case SPVFuncImplGatherConstOffsets:
-			statement("// Wrapper function that processes a texture gather with a constant offset array.");
-			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
-			          "typename Toff, typename... Tp>");
-			statement("inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, "
-			          "Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
-			begin_scope();
-			statement("vec<T, 4> rslts[4];");
-			statement("for (uint i = 0; i < 4; i++)");
-			begin_scope();
-			statement("switch (c)");
-			begin_scope();
-			// Work around texture::gather() requiring its component parameter to be a constant expression
-			statement("case component::x:");
-			statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
-			statement("    break;");
-			statement("case component::y:");
-			statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
-			statement("    break;");
-			statement("case component::z:");
-			statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
-			statement("    break;");
-			statement("case component::w:");
-			statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
-			statement("    break;");
-			end_scope();
-			end_scope();
-			// Pull all values from the i0j0 component of each gather footprint
-			statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
-			end_scope();
-			statement("");
+			// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
+			for (uint32_t i = 0; i < texture_addr_space_count; i++)
+			{
+				statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
+				statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
+						  "typename Toff, typename... Tp>");
+				statement("inline vec<T, 4> spvGatherConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
+						  "Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
+				begin_scope();
+				statement("vec<T, 4> rslts[4];");
+				statement("for (uint i = 0; i < 4; i++)");
+				begin_scope();
+				statement("switch (c)");
+				begin_scope();
+				// Work around texture::gather() requiring its component parameter to be a constant expression
+				statement("case component::x:");
+				statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
+				statement("    break;");
+				statement("case component::y:");
+				statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
+				statement("    break;");
+				statement("case component::z:");
+				statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
+				statement("    break;");
+				statement("case component::w:");
+				statement("    rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
+				statement("    break;");
+				end_scope();
+				end_scope();
+				// Pull all values from the i0j0 component of each gather footprint
+				statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
+				end_scope();
+				statement("");
+			}
 			break;
 
 		case SPVFuncImplGatherCompareConstOffsets:
-			statement("// Wrapper function that processes a texture gather with a constant offset array.");
-			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
-			          "typename Toff, typename... Tp>");
-			statement("inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, "
-			          "Toff coffsets, Tp... params)");
-			begin_scope();
-			statement("vec<T, 4> rslts[4];");
-			statement("for (uint i = 0; i < 4; i++)");
-			begin_scope();
-			statement("    rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
-			end_scope();
-			// Pull all values from the i0j0 component of each gather footprint
-			statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
-			end_scope();
-			statement("");
+			// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
+			for (uint32_t i = 0; i < texture_addr_space_count; i++)
+			{
+				statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
+				statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
+						  "typename Toff, typename... Tp>");
+				statement("inline vec<T, 4> spvGatherCompareConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
+						  "Toff coffsets, Tp... params)");
+				begin_scope();
+				statement("vec<T, 4> rslts[4];");
+				statement("for (uint i = 0; i < 4; i++)");
+				begin_scope();
+				statement("    rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
+				end_scope();
+				// Pull all values from the i0j0 component of each gather footprint
+				statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
+				end_scope();
+				statement("");
+			}
 			break;
 
 		case SPVFuncImplSubgroupBroadcast:
@@ -9246,18 +9256,40 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t coord_id = ops[3];
 		emit_uninitialized_temporary_expression(result_type, id);
 
+		std::string coord_expr = to_expression(coord_id);
 		auto sampler_expr = to_sampler_expression(image_id);
 		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
 		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
+		const SPIRType &image_type = expression_type(image_id);
+		const SPIRType &coord_type = expression_type(coord_id);
+
+		switch (image_type.image.dim)
+		{
+		case Dim1D:
+			if (!msl_options.texture_1D_as_2D)
+				SPIRV_CROSS_THROW("ImageQueryLod is not supported on 1D textures.");
+			[[fallthrough]];
+		case Dim2D:
+			if (coord_type.vecsize > 2)
+				coord_expr = enclose_expression(coord_expr) + ".xy";
+			break;
+		case DimCube:
+		case Dim3D:
+			if (coord_type.vecsize > 3)
+				coord_expr = enclose_expression(coord_expr) + ".xyz";
+			break;
+		default:
+			SPIRV_CROSS_THROW("Bad image type given to OpImageQueryLod");
+		}
 
 		// TODO: It is unclear if calculcate_clamped_lod also conditionally rounds
 		// the reported LOD based on the sampler. NEAREST miplevel should
 		// round the LOD, but LINEAR miplevel should not round.
 		// Let's hope this does not become an issue ...
 		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
-		          to_expression(coord_id), ");");
+		          coord_expr, ");");
 		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
-		          to_expression(coord_id), ");");
+		          coord_expr, ");");
 		register_control_dependent_expression(id);
 		break;
 	}
@@ -12167,21 +12199,26 @@ string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_
 string CompilerMSL::to_sampler_expression(uint32_t id)
 {
 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
-	auto expr = to_expression(combined ? combined->image : VariableID(id));
-	auto index = expr.find_first_of('[');
+	// A combined image-sampler that carries an explicit sampler ID resolves
+	// directly to that sampler's expression.
+	if (combined && combined->sampler)
+		return to_expression(combined->sampler);
 
-	uint32_t samp_id = 0;
-	if (combined)
-		samp_id = combined->sampler;
+	// Otherwise the sampler name is derived from the image: use the combined
+	// object's image ID when there is one, else the ID that was passed in.
+	uint32_t expr_id = combined ? uint32_t(combined->image) : id;
 
-	if (index == string::npos)
-		return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
-	else
+	// Constexpr samplers are declared as local variables,
+	// so exclude any qualifier names on the image expression.
+	if (auto *var = maybe_get_backing_variable(expr_id))
 	{
-		auto image_expr = expr.substr(0, index);
-		auto array_expr = expr.substr(index);
-		return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
+		uint32_t img_id =  var->basevariable ? var->basevariable : VariableID(var->self);
+		if (find_constexpr_sampler(img_id))
+			return Compiler::to_name(img_id) + sampler_name_suffix;
 	}
+
+	// General case: append the sampler suffix to the image expression. If the
+	// expression contains a '[', the suffix is inserted before the first one.
+	auto img_expr = to_expression(expr_id);
+	auto index = img_expr.find_first_of('[');
+	if (index == string::npos)
+		return img_expr + sampler_name_suffix;
+	else
+		return img_expr.substr(0, index) + sampler_name_suffix + img_expr.substr(index);
 }
 
 string CompilerMSL::to_swizzle_expression(uint32_t id)
@@ -13176,7 +13213,10 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 		addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
 	}
 
-	return join(decoration_flags_signal_volatile(flags) ? "volatile " : "", addr_space);
+	if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread"))
+		return join("volatile ", addr_space);
+	else
+		return addr_space;
 }
 
 const char *CompilerMSL::to_restrict(uint32_t id, bool space)
@@ -13602,7 +13642,13 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
 
 		claimed_bindings.set(buffer_binding);
 
-		ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
+		ep_args += get_argument_address_space(var) + " ";
+
+		if (recursive_inputs.count(type.self))
+			ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp";
+		else
+			ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
+
 		ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
 
 		next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
@@ -14040,7 +14086,7 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
 						    statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
 						              is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
 						              ".spvBufferSizeConstants", "[",
-						              convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
+						              convert_to_string(get_metal_resource_index(var, SPIRType::UInt)), "];");
 					    }
 					    else
 					    {
@@ -14053,7 +14099,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
 			}
 		}
 
-		if (msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
+		if (!msl_options.argument_buffers &&
+		     msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
 		    (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
 		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer))
 		{
@@ -17026,13 +17073,21 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type,
 	return msl_size;
 }
 
+// MSL override: reports the physical stride of a type as its declared MSL size,
+// queried with is_packed = false and row_major = false.
+uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
+{
+	// This should only be relevant for plain types such as scalars and vectors?
+	// If we're pointing to a struct, it will recursively pick up packed/row-major state.
+	return get_declared_type_size_msl(type, false, false);
+}
+
 // Returns the byte size of a struct member.
 uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
 {
 	// Pointers take 8 bytes each
+	// Match both pointer and array-of-pointer here.
 	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
 	{
-		uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize);
+		uint32_t type_size = 8;
 
 		// Work our way through potentially layered arrays,
 		// stopping when we hit a pointer that is not also an array.
@@ -17107,9 +17162,10 @@ uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t
 // Returns the byte alignment of a type.
 uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
 {
-	// Pointers aligns on multiples of 8 bytes
+	// Pointers align on multiples of 8 bytes.
+	// Deliberately ignore array-ness here. It's not relevant for alignment.
 	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
-		return 8 * (type.vecsize == 3 ? 4 : type.vecsize);
+		return 8;
 
 	switch (type.basetype)
 	{
@@ -18134,6 +18190,13 @@ void CompilerMSL::emit_argument_buffer_aliased_descriptor(const SPIRVariable &al
 	}
 	else
 	{
+		// This alias may have already been used to emit an entry point declaration. If there is a mismatch, we need a recompile.
+		// Moving this code to be run earlier will also conflict,
+		// because we need the qualified alias for the base resource,
+		// so forcing recompile until things sync up is the least invasive method for now.
+		if (ir.meta[aliased_var.self].decoration.qualified_alias != name)
+			force_recompile();
+
 		// This will get wrapped in a separate temporary when a spvDescriptorArray wrapper is emitted.
 		set_qualified_name(aliased_var.self, name);
 	}
@@ -18158,6 +18221,7 @@ void CompilerMSL::analyze_argument_buffers()
 		string name;
 		SPIRType::BaseType basetype;
 		uint32_t index;
+		uint32_t plane_count;
 		uint32_t plane;
 		uint32_t overlapping_var_id;
 	};
@@ -18208,14 +18272,14 @@ void CompilerMSL::analyze_argument_buffers()
 				{
 					uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
 					resources_in_set[desc_set].push_back(
-					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, i, 0 });
+					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, plane_count, i, 0 });
 				}
 
 				if (type.image.dim != DimBuffer && !constexpr_sampler)
 				{
 					uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
 					resources_in_set[desc_set].push_back(
-					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0, 0 });
+					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 1, 0, 0 });
 				}
 			}
 			else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
@@ -18231,14 +18295,14 @@ void CompilerMSL::analyze_argument_buffers()
 				uint32_t resource_index = get_metal_resource_index(var, type.basetype);
 
 				resources_in_set[desc_set].push_back(
-					{ &var, to_name(var_id), type.basetype, resource_index, 0, 0 });
+					{ &var, to_name(var_id), type.basetype, resource_index, 1, 0, 0 });
 
 				// Emulate texture2D atomic operations
 				if (atomic_image_vars_emulated.count(var.self))
 				{
 					uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
 					resources_in_set[desc_set].push_back(
-						{ &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0, 0 });
+						{ &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 1, 0, 0 });
 				}
 			}
 
@@ -18286,7 +18350,7 @@ void CompilerMSL::analyze_argument_buffers()
 				set_decoration(var_id, DecorationDescriptorSet, desc_set);
 				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
 				resources_in_set[desc_set].push_back(
-				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 });
+				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 });
 			}
 
 			if (set_needs_buffer_sizes[desc_set])
@@ -18297,7 +18361,7 @@ void CompilerMSL::analyze_argument_buffers()
 				set_decoration(var_id, DecorationDescriptorSet, desc_set);
 				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
 				resources_in_set[desc_set].push_back(
-				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 });
+				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 });
 			}
 		}
 	}
@@ -18309,7 +18373,7 @@ void CompilerMSL::analyze_argument_buffers()
 		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
 		add_resource_name(var_id);
 		resources_in_set[desc_set].push_back(
-		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0, 0 });
+		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 1, 0, 0 });
 	}
 
 	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
@@ -18340,7 +18404,8 @@ void CompilerMSL::analyze_argument_buffers()
 		else
 			buffer_type.storage = StorageClassUniform;
 
-		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
+		auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set);
+		set_name(type_id, buffer_type_name);
 
 		auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
 		ptr_type = buffer_type;
@@ -18350,8 +18415,9 @@ void CompilerMSL::analyze_argument_buffers()
 		ptr_type.parent_type = type_id;
 
 		uint32_t buffer_variable_id = next_id;
-		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
-		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
+		auto &buffer_var = set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
+		auto buffer_name = join("spvDescriptorSet", desc_set);
+		set_name(buffer_variable_id, buffer_name);
 
 		// Ids must be emitted in ID order.
 		stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
@@ -18386,7 +18452,7 @@ void CompilerMSL::analyze_argument_buffers()
 
 			// If needed, synthesize and add padding members.
 			// member_index and next_arg_buff_index are incremented when padding members are added.
-			if (msl_options.pad_argument_buffer_resources && resource.overlapping_var_id == 0)
+			if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0)
 			{
 				auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
 				while (resource.index > next_arg_buff_index)
@@ -18432,7 +18498,7 @@ void CompilerMSL::analyze_argument_buffers()
 				// Adjust the number of slots consumed by current member itself.
 				// Use the count value from the app, instead of the shader, in case the
 				// shader is only accessing part, or even one element, of the array.
-				next_arg_buff_index += rez_bind.count;
+				next_arg_buff_index += resource.plane_count * rez_bind.count;
 			}
 
 			string mbr_name = ensure_valid_name(resource.name, "m");
@@ -18559,6 +18625,16 @@ void CompilerMSL::analyze_argument_buffers()
 				set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
 			member_index++;
 		}
+		
+		if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type))
+		{
+			recursive_inputs.insert(type_id);
+			auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
+			auto addr_space = get_argument_address_space(buffer_var);
+			entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() {
+				statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;");
+			});
+		}
 	}
 }
 

+ 2 - 0
3rdparty/spirv-cross/spirv_msl.hpp

@@ -1028,6 +1028,8 @@ protected:
 
 	uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const;
 
+	uint32_t get_physical_type_stride(const SPIRType &type) const override;
+
 	// MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output.
 	// These values can change depending on various extended decorations which control packing rules.
 	// We need to make these rules match up with SPIR-V declared rules.