Бранимир Караџић 3 лет назад
Родитель
Сommit
e73aa1e0b3

+ 1 - 4
3rdparty/spirv-cross/spirv_cpp.cpp

@@ -338,11 +338,8 @@ string CompilerCPP::compile()
 	uint32_t pass_count = 0;
 	do
 	{
-		if (pass_count >= 3)
-			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
-
 		resource_registrations.clear();
-		reset();
+		reset(pass_count);
 
 		// Move constructor for this type is broken on GCC 4.9 ...
 		buffer.reset();

+ 7 - 0
3rdparty/spirv-cross/spirv_cross.cpp

@@ -4828,6 +4828,12 @@ void Compiler::force_recompile()
 	is_force_recompile = true;
 }
 
+void Compiler::force_recompile_guarantee_forward_progress()
+{
+	force_recompile();
+	is_force_recompile_forward_progress = true;
+}
+
 bool Compiler::is_forcing_recompilation() const
 {
 	return is_force_recompile;
@@ -4836,6 +4842,7 @@ bool Compiler::is_forcing_recompilation() const
 void Compiler::clear_force_recompile()
 {
 	is_force_recompile = false;
+	is_force_recompile_forward_progress = false;
 }
 
 Compiler::PhysicalStorageBufferPointerHandler::PhysicalStorageBufferPointerHandler(Compiler &compiler_)

+ 2 - 0
3rdparty/spirv-cross/spirv_cross.hpp

@@ -735,9 +735,11 @@ protected:
 	SPIRBlock::ContinueBlockType continue_block_type(const SPIRBlock &continue_block) const;
 
 	void force_recompile();
+	void force_recompile_guarantee_forward_progress();
 	void clear_force_recompile();
 	bool is_forcing_recompilation() const;
 	bool is_force_recompile = false;
+	bool is_force_recompile_forward_progress = false;
 
 	bool block_is_loop_candidate(const SPIRBlock &block, SPIRBlock::Method method) const;
 

+ 51 - 17
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -296,8 +296,19 @@ const char *CompilerGLSL::vector_swizzle(int vecsize, int index)
 	return swizzle[vecsize - 1][index];
 }
 
-void CompilerGLSL::reset()
-{
+void CompilerGLSL::reset(uint32_t iteration_count)
+{
+	// Sanity check the iteration count to be robust against a certain class of bugs where
+	// we keep forcing recompilations without making clear forward progress.
+	// In buggy situations we will loop forever, or loop for an unbounded number of iterations.
+	// Certain types of recompilations are considered to make forward progress,
+	// but in almost all situations, we'll never see more than 3 iterations.
+	// It is highly context-sensitive when we need to force recompilation,
+	// and it is not practical with the current architecture
+	// to resolve everything up front.
+	if (iteration_count >= 3 && !is_force_recompile_forward_progress)
+		SPIRV_CROSS_THROW("Over 3 compilation loops detected and no forward progress was made. Must be a bug!");
+
 	// We do some speculative optimizations which should pretty much always work out,
 	// but just in case the SPIR-V is rather weird, recompile until it's happy.
 	// This typically only means one extra pass.
@@ -664,10 +675,7 @@ string CompilerGLSL::compile()
 	uint32_t pass_count = 0;
 	do
 	{
-		if (pass_count >= 3)
-			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
-
-		reset();
+		reset(pass_count);
 
 		buffer.reset();
 
@@ -2678,6 +2686,7 @@ string CompilerGLSL::constant_value_macro_name(uint32_t id)
 void CompilerGLSL::emit_specialization_constant_op(const SPIRConstantOp &constant)
 {
 	auto &type = get<SPIRType>(constant.basetype);
+	add_resource_name(constant.self);
 	auto name = to_name(constant.self);
 	statement("const ", variable_decl(type, name), " = ", constant_op_expression(constant), ";");
 }
@@ -2705,7 +2714,6 @@ int CompilerGLSL::get_constant_mapping_to_workgroup_component(const SPIRConstant
 void CompilerGLSL::emit_constant(const SPIRConstant &constant)
 {
 	auto &type = get<SPIRType>(constant.constant_type);
-	auto name = to_name(constant.self);
 
 	SpecializationConstant wg_x, wg_y, wg_z;
 	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
@@ -2732,6 +2740,9 @@ void CompilerGLSL::emit_constant(const SPIRConstant &constant)
 		return;
 	}
 
+	add_resource_name(constant.self);
+	auto name = to_name(constant.self);
+
 	// Only scalars have constant IDs.
 	if (has_decoration(constant.self, DecorationSpecId))
 	{
@@ -4291,8 +4302,13 @@ void CompilerGLSL::handle_invalid_expression(uint32_t id)
 {
 	// We tried to read an invalidated expression.
 	// This means we need another pass at compilation, but next time, force temporary variables so that they cannot be invalidated.
-	forced_temporaries.insert(id);
-	force_recompile();
+	auto res = forced_temporaries.insert(id);
+
+	// Forcing new temporaries guarantees forward progress.
+	if (res.second)
+		force_recompile_guarantee_forward_progress();
+	else
+		force_recompile();
 }
 
 // Converts the format of the current expression from packed to unpacked,
@@ -4546,12 +4562,13 @@ string CompilerGLSL::to_rerolled_array_expression(const string &base_expr, const
 	return expr;
 }
 
-string CompilerGLSL::to_composite_constructor_expression(uint32_t id, bool uses_buffer_offset)
+string CompilerGLSL::to_composite_constructor_expression(uint32_t id, bool block_like_type)
 {
 	auto &type = expression_type(id);
 
-	bool reroll_array = !type.array.empty() && (!backend.array_is_value_type ||
-	                                            (uses_buffer_offset && !backend.buffer_offset_array_is_value_type));
+	bool reroll_array = !type.array.empty() &&
+	                    (!backend.array_is_value_type ||
+	                     (block_like_type && !backend.array_is_value_type_in_buffer_blocks));
 
 	if (reroll_array)
 	{
@@ -4953,7 +4970,7 @@ string CompilerGLSL::constant_op_expression(const SPIRConstantOp &cop)
 	}
 }
 
-string CompilerGLSL::constant_expression(const SPIRConstant &c)
+string CompilerGLSL::constant_expression(const SPIRConstant &c, bool inside_block_like_struct_scope)
 {
 	auto &type = get<SPIRType>(c.constant_type);
 
@@ -4966,6 +4983,15 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c)
 		// Handles Arrays and structures.
 		string res;
 
+		// Only consider the decay if we are inside a struct scope.
+		// Outside a struct declaration, we can always bind to a constant array with templated type.
+		bool array_type_decays = inside_block_like_struct_scope &&
+		                         !type.array.empty() && !backend.array_is_value_type_in_buffer_blocks &&
+		                         has_decoration(c.constant_type, DecorationArrayStride);
+
+		if (type.array.empty() && type.basetype == SPIRType::Struct && type_is_block_like(type))
+			inside_block_like_struct_scope = true;
+
 		// Allow Metal to use the array<T> template to make arrays a value type
 		bool needs_trailing_tracket = false;
 		if (backend.use_initializer_list && backend.use_typed_initializer_list && type.basetype == SPIRType::Struct &&
@@ -4974,7 +5000,7 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c)
 			res = type_to_glsl_constructor(type) + "{ ";
 		}
 		else if (backend.use_initializer_list && backend.use_typed_initializer_list && backend.array_is_value_type &&
-		         !type.array.empty())
+		         !type.array.empty() && !array_type_decays)
 		{
 			res = type_to_glsl_constructor(type) + "({ ";
 			needs_trailing_tracket = true;
@@ -4994,7 +5020,7 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c)
 			if (subc.specialization)
 				res += to_name(elem);
 			else
-				res += constant_expression(subc);
+				res += constant_expression(subc, inside_block_like_struct_scope);
 
 			if (&elem != &c.subconstants.back())
 				res += ", ";
@@ -6952,7 +6978,7 @@ bool CompilerGLSL::expression_is_non_value_type_array(uint32_t ptr)
 		return false;
 
 	auto &backed_type = get<SPIRType>(var->basetype);
-	return !backend.buffer_offset_array_is_value_type && backed_type.basetype == SPIRType::Struct &&
+	return !backend.array_is_value_type_in_buffer_blocks && backed_type.basetype == SPIRType::Struct &&
 	       has_member_decoration(backed_type.self, 0, DecorationOffset);
 }
 
@@ -9498,6 +9524,7 @@ bool CompilerGLSL::should_forward(uint32_t id) const
 {
 	// If id is a variable we will try to forward it regardless of force_temporary check below
 	// This is important because otherwise we'll get local sampler copies (highp sampler2D foo = bar) that are invalid in OpenGL GLSL
+
 	auto *var = maybe_get<SPIRVariable>(id);
 	if (var && var->forwardable)
 		return true;
@@ -9506,6 +9533,13 @@ bool CompilerGLSL::should_forward(uint32_t id) const
 	if (options.force_temporary)
 		return false;
 
+	// If an expression carries enough dependencies we need to stop forwarding at some point,
+	// or we explode compilers. There are usually limits to how much we can nest expressions.
+	auto *expr = maybe_get<SPIRExpression>(id);
+	const uint32_t max_expression_dependencies = 64;
+	if (expr && expr->expression_dependencies.size() >= max_expression_dependencies)
+		return false;
+
 	// Immutable expression can always be forwarded.
 	if (is_immutable(id))
 		return true;
@@ -14752,7 +14786,7 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block)
 	// as writes to said loop variables might have been masked out, we need a recompile.
 	if (!emitted_loop_header_variables && !block.loop_variables.empty())
 	{
-		force_recompile();
+		force_recompile_guarantee_forward_progress();
 		for (auto var : block.loop_variables)
 			get<SPIRVariable>(var).loop_variable = false;
 		block.loop_variables.clear();

+ 4 - 4
3rdparty/spirv-cross/spirv_glsl.hpp

@@ -342,7 +342,7 @@ protected:
 	// TODO remove this function when all subgroup ops are supported (or make it always return true)
 	static bool is_supported_subgroup_op_in_opengl(spv::Op op);
 
-	void reset();
+	void reset(uint32_t iteration_count);
 	void emit_function(SPIRFunction &func, const Bitset &return_flags);
 
 	bool has_extension(const std::string &ext) const;
@@ -385,7 +385,7 @@ protected:
 	                                const std::string &qualifier = "", uint32_t base_offset = 0);
 	virtual void emit_struct_padding_target(const SPIRType &type);
 	virtual std::string image_type_glsl(const SPIRType &type, uint32_t id = 0);
-	std::string constant_expression(const SPIRConstant &c);
+	std::string constant_expression(const SPIRConstant &c, bool inside_block_like_struct_scope = false);
 	virtual std::string constant_op_expression(const SPIRConstantOp &cop);
 	virtual std::string constant_expression_vector(const SPIRConstant &c, uint32_t vector);
 	virtual void emit_fixup();
@@ -577,7 +577,7 @@ protected:
 		bool supports_extensions = false;
 		bool supports_empty_struct = false;
 		bool array_is_value_type = true;
-		bool buffer_offset_array_is_value_type = true;
+		bool array_is_value_type_in_buffer_blocks = true;
 		bool comparison_image_samples_scalar = false;
 		bool native_pointers = false;
 		bool support_small_type_sampling_result = false;
@@ -718,7 +718,7 @@ protected:
 	void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist);
 	std::string to_non_uniform_aware_expression(uint32_t id);
 	std::string to_expression(uint32_t id, bool register_expression_read = true);
-	std::string to_composite_constructor_expression(uint32_t id, bool uses_buffer_offset);
+	std::string to_composite_constructor_expression(uint32_t id, bool block_like_type);
 	std::string to_rerolled_array_expression(const std::string &expr, const SPIRType &type);
 	std::string to_enclosed_expression(uint32_t id, bool register_expression_read = true);
 	std::string to_unpacked_expression(uint32_t id, bool register_expression_read = true);

+ 4 - 4
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -1167,6 +1167,7 @@ void CompilerHLSL::emit_composite_constants()
 
 		if (type.basetype == SPIRType::Struct || !type.array.empty())
 		{
+			add_resource_name(c.self);
 			auto name = to_name(c.self);
 			statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
 			emitted = true;
@@ -1213,6 +1214,7 @@ void CompilerHLSL::emit_specialization_constants_and_structs()
 			else if (c.specialization)
 			{
 				auto &type = get<SPIRType>(c.constant_type);
+				add_resource_name(c.self);
 				auto name = to_name(c.self);
 
 				if (has_decoration(c.self, DecorationSpecId))
@@ -1236,6 +1238,7 @@ void CompilerHLSL::emit_specialization_constants_and_structs()
 		{
 			auto &c = id.get<SPIRConstantOp>();
 			auto &type = get<SPIRType>(c.basetype);
+			add_resource_name(c.self);
 			auto name = to_name(c.self);
 			statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
 			emitted = true;
@@ -5801,10 +5804,7 @@ string CompilerHLSL::compile()
 	uint32_t pass_count = 0;
 	do
 	{
-		if (pass_count >= 3)
-			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
-
-		reset();
+		reset(pass_count);
 
 		// Move constructor for this type is broken on GCC 4.9 ...
 		buffer.reset();

+ 112 - 49
3rdparty/spirv-cross/spirv_msl.cpp

@@ -1360,8 +1360,8 @@ string CompilerMSL::compile()
 	// Allow Metal to use the array<T> template unless we force it off.
 	backend.can_return_array = !msl_options.force_native_arrays;
 	backend.array_is_value_type = !msl_options.force_native_arrays;
-	// Arrays which are part of buffer objects are never considered to be native arrays.
-	backend.buffer_offset_array_is_value_type = false;
+	// Arrays which are part of buffer objects are never considered to be value types (just plain C-style).
+	backend.array_is_value_type_in_buffer_blocks = false;
 	backend.support_pointer_to_pointer = true;
 
 	capture_output_to_buffer = msl_options.capture_output_to_buffer;
@@ -1446,10 +1446,7 @@ string CompilerMSL::compile()
 	uint32_t pass_count = 0;
 	do
 	{
-		if (pass_count >= 3)
-			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
-
-		reset();
+		reset(pass_count);
 
 		// Start bindings at zero.
 		next_metal_resource_index_buffer = 0;
@@ -6568,6 +6565,7 @@ void CompilerMSL::declare_constant_arrays()
 		// link into Metal libraries. This is hacky.
 		if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
 		{
+			add_resource_name(c.self);
 			auto name = to_name(c.self);
 			statement(inject_top_level_storage_qualifier(variable_decl(type, name), "constant"),
 			          " = ", constant_expression(c), ";");
@@ -6599,6 +6597,7 @@ void CompilerMSL::declare_complex_constant_arrays()
 		auto &type = this->get<SPIRType>(c.constant_type);
 		if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
 		{
+			add_resource_name(c.self);
 			auto name = to_name(c.self);
 			statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
 			emitted = true;
@@ -6677,6 +6676,7 @@ void CompilerMSL::emit_specialization_constants_and_structs()
 			{
 				auto &type = get<SPIRType>(c.constant_type);
 				string sc_type_name = type_to_glsl(type);
+				add_resource_name(c.self);
 				string sc_name = to_name(c.self);
 				string sc_tmp_name = sc_name + "_tmp";
 
@@ -6719,6 +6719,7 @@ void CompilerMSL::emit_specialization_constants_and_structs()
 		{
 			auto &c = id.get<SPIRConstantOp>();
 			auto &type = get<SPIRType>(c.basetype);
+			add_resource_name(c.self);
 			auto name = to_name(c.self);
 			statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
 			emitted = true;
@@ -7763,7 +7764,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t ptr = ops[2];
 		uint32_t mem_sem = ops[4];
 		uint32_t val = ops[5];
-		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
+		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", opcode, mem_sem, mem_sem, false, ptr, val);
 		break;
 	}
 
@@ -7776,7 +7777,8 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t mem_sem_fail = ops[5];
 		uint32_t val = ops[6];
 		uint32_t comp = ops[7];
-		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
+		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", opcode,
+		                    mem_sem_pass, mem_sem_fail, true,
 		                    ptr, comp, true, false, val);
 		break;
 	}
@@ -7790,7 +7792,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t id = ops[1];
 		uint32_t ptr = ops[2];
 		uint32_t mem_sem = ops[4];
-		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
+		emit_atomic_func_op(result_type, id, "atomic_load_explicit", opcode, mem_sem, mem_sem, false, ptr, 0);
 		break;
 	}
 
@@ -7801,7 +7803,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t ptr = ops[0];
 		uint32_t mem_sem = ops[2];
 		uint32_t val = ops[3];
-		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
+		emit_atomic_func_op(result_type, id, "atomic_store_explicit", opcode, mem_sem, mem_sem, false, ptr, val);
 		break;
 	}
 
@@ -7813,7 +7815,8 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t ptr = ops[2];                                                                                   \
 		uint32_t mem_sem = ops[4];                                                                               \
 		uint32_t val = valsrc;                                                                                   \
-		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
+		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", opcode,                            \
+		                    mem_sem, mem_sem, false, ptr, val,                                                   \
 		                    false, valconst);                                                                    \
 	} while (false)
 
@@ -8799,13 +8802,22 @@ bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
 }
 
 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers
-void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
-                                      uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
+void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, Op opcode,
+                                      uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
                                       bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
 {
 	string exp = string(op) + "(";
 
 	auto &type = get_pointee_type(expression_type(obj));
+	auto expected_type = type.basetype;
+	if (opcode == OpAtomicUMax || opcode == OpAtomicUMin)
+		expected_type = to_unsigned_basetype(type.width);
+	else if (opcode == OpAtomicSMax || opcode == OpAtomicSMin)
+		expected_type = to_signed_basetype(type.width);
+
+	auto remapped_type = type;
+	remapped_type.basetype = expected_type;
+
 	exp += "(";
 	auto *var = maybe_get_backing_variable(obj);
 	if (!var)
@@ -8823,7 +8835,9 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 	}
 
 	exp += " atomic_";
-	exp += type_to_glsl(type);
+	// For signed and unsigned min/max, we can signal this through the pointer type.
+	// There is no other way, since C++ does not have explicit signage for atomics.
+	exp += type_to_glsl(remapped_type);
 	exp += "*)";
 
 	exp += "&";
@@ -8866,7 +8880,7 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 			if (op1_is_literal)
 				exp += join(", ", op1);
 			else
-				exp += ", " + to_expression(op1);
+				exp += ", " + bitcast_expression(expected_type, op1);
 		}
 		if (op2)
 			exp += ", " + to_expression(op2);
@@ -8877,6 +8891,9 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 
 		exp += ")";
 
+		if (expected_type != type.basetype)
+			exp = bitcast_expression(type, expected_type, exp);
+
 		if (strcmp(op, "atomic_store_explicit") != 0)
 			emit_op(result_type, result_id, exp, false);
 		else
@@ -9364,7 +9381,20 @@ void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
 
 			// Manufacture automatic sampler arg for SampledImage texture
 			if (arg_type.image.dim != DimBuffer)
-				decl += join(", thread const ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
+			{
+				if (arg_type.array.empty())
+				{
+					decl += join(", ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
+				}
+				else
+				{
+					const char *sampler_address_space =
+							descriptor_address_space(name_id,
+							                         StorageClassUniformConstant,
+							                         "thread const");
+					decl += join(", ", sampler_address_space, " ", sampler_type(arg_type, arg.id), "& ", to_sampler_expression(arg.id));
+				}
+			}
 		}
 
 		// Manufacture automatic swizzle arg.
@@ -12652,6 +12682,39 @@ bool CompilerMSL::type_is_pointer_to_pointer(const SPIRType &type) const
 	return type.pointer_depth > parent_type.pointer_depth && type_is_pointer(parent_type);
 }
 
+const char *CompilerMSL::descriptor_address_space(uint32_t id, StorageClass storage, const char *plain_address_space) const
+{
+	if (msl_options.argument_buffers)
+	{
+		bool storage_class_is_descriptor = storage == StorageClassUniform ||
+		                                   storage == StorageClassStorageBuffer ||
+		                                   storage == StorageClassUniformConstant;
+
+		uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
+		if (storage_class_is_descriptor && descriptor_set_is_argument_buffer(desc_set))
+		{
+			// An awkward case where we need to emit *more* address space declarations (yay!).
+			// An example is where we pass down an array of buffer pointers to leaf functions.
+			// It's a constant array containing pointers to constants.
+			// The pointer array is always constant however. E.g.
+			// device SSBO * constant (&array)[N].
+			// const device SSBO * constant (&array)[N].
+			// constant SSBO * constant (&array)[N].
+			// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
+			// we emit the buffer array on stack instead, and that seems to work just fine apparently.
+
+			// If the argument was marked as being in device address space, any pointer to member would
+			// be const device, not constant.
+			if (argument_buffer_device_storage_mask & (1u << desc_set))
+				return "const device";
+			else
+				return "constant";
+		}
+	}
+
+	return plain_address_space;
+}
+
 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 {
 	auto &var = get<SPIRVariable>(arg.id);
@@ -12670,15 +12733,14 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 	// Framebuffer fetch is plain value, const looks out of place, but it is not wrong.
 	if (type_is_msl_framebuffer_fetch(type))
 		constref = false;
+	else if (type_storage == StorageClassUniformConstant)
+		constref = true;
 
 	bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
 	                     type.basetype == SPIRType::Sampler;
 
-	// Arrays of images/samplers in MSL are always const.
-	if (!type.array.empty() && type_is_image)
-		constref = true;
-
-	const char *cv_qualifier = constref ? "const " : "";
+	// For opaque types we handle const later due to descriptor address spaces.
+	const char *cv_qualifier = (constref && !type_is_image) ? "const " : "";
 	string decl;
 
 	// If this is a combined image-sampler for a 2D image with floating-point type,
@@ -12750,9 +12812,7 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 			decl = join(cv_qualifier, type_to_glsl(type, arg.id));
 	}
 
-	bool opaque_handle = type_storage == StorageClassUniformConstant;
-
-	if (!builtin && !opaque_handle && !is_pointer &&
+	if (!builtin && !is_pointer &&
 	    (type_storage == StorageClassFunction || type_storage == StorageClassGeneric))
 	{
 		// If the argument is a pure value and not an opaque type, we will pass by value.
@@ -12787,33 +12847,15 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 	}
 	else if (is_array(type) && !type_is_image)
 	{
-		// Arrays of images and samplers are special cased.
+		// Arrays of opaque types are special cased.
 		if (!address_space.empty())
 			decl = join(address_space, " ", decl);
 
-		if (msl_options.argument_buffers)
+		const char *argument_buffer_space = descriptor_address_space(name_id, type_storage, nullptr);
+		if (argument_buffer_space)
 		{
-			uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
-			if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) &&
-			    descriptor_set_is_argument_buffer(desc_set))
-			{
-				// An awkward case where we need to emit *more* address space declarations (yay!).
-				// An example is where we pass down an array of buffer pointers to leaf functions.
-				// It's a constant array containing pointers to constants.
-				// The pointer array is always constant however. E.g.
-				// device SSBO * constant (&array)[N].
-				// const device SSBO * constant (&array)[N].
-				// constant SSBO * constant (&array)[N].
-				// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
-				// we emit the buffer array on stack instead, and that seems to work just fine apparently.
-
-				// If the argument was marked as being in device address space, any pointer to member would
-				// be const device, not constant.
-				if (argument_buffer_device_storage_mask & (1u << desc_set))
-					decl += " const device";
-				else
-					decl += " constant";
-			}
+			decl += " ";
+			decl += argument_buffer_space;
 		}
 
 		// Special case, need to override the array size here if we're using tess level as an argument.
@@ -12857,7 +12899,7 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 			}
 		}
 	}
-	else if (!opaque_handle && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
+	else if (!type_is_image && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
 	{
 		// If this is going to be a reference to a variable pointer, the address space
 		// for the reference has to go before the '&', but after the '*'.
@@ -12877,6 +12919,27 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 		decl += to_restrict(name_id);
 		decl += to_expression(name_id);
 	}
+	else if (type_is_image)
+	{
+		if (type.array.empty())
+		{
+			// For non-arrayed types we can just pass opaque descriptors by value.
+			// This fixes problems if descriptors are passed by value from argument buffers and plain descriptors
+			// in same shader.
+			// There is no address space we can actually use, but value will work.
+			// This will break if applications attempt to pass down descriptor arrays as arguments, but
+			// fortunately that is extremely unlikely ...
+			decl += " ";
+			decl += to_expression(name_id);
+		}
+		else
+		{
+			const char *img_address_space = descriptor_address_space(name_id, type_storage, "thread const");
+			decl = join(img_address_space, " ", decl);
+			decl += "& ";
+			decl += to_expression(name_id);
+		}
+	}
 	else
 	{
 		if (!address_space.empty())
@@ -13565,7 +13628,7 @@ std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
 	if (variable_decl_is_remapped_storage(variable, StorageClassWorkgroup))
 		is_using_builtin_array = true;
 
-	std::string expr = CompilerGLSL::variable_decl(variable);
+	auto expr = CompilerGLSL::variable_decl(variable);
 	is_using_builtin_array = old_is_using_builtin_array;
 	return expr;
 }

+ 3 - 2
3rdparty/spirv-cross/spirv_msl.hpp

@@ -872,6 +872,7 @@ protected:
 	std::string member_attribute_qualifier(const SPIRType &type, uint32_t index);
 	std::string member_location_attribute_qualifier(const SPIRType &type, uint32_t index);
 	std::string argument_decl(const SPIRFunction::Parameter &arg);
+	const char *descriptor_address_space(uint32_t id, spv::StorageClass storage, const char *plain_address_space) const;
 	std::string round_fp_tex_coords(std::string tex_coords, bool coord_is_fp);
 	uint32_t get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane = 0);
 	uint32_t get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp = nullptr) const;
@@ -920,8 +921,8 @@ protected:
 	std::string get_tess_factor_struct_name();
 	SPIRType &get_uint_type();
 	uint32_t get_uint_type_id();
-	void emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
-	                         uint32_t mem_order_2, bool has_mem_order_2, uint32_t op0, uint32_t op1 = 0,
+	void emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, spv::Op opcode,
+	                         uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t op0, uint32_t op1 = 0,
 	                         bool op1_is_pointer = false, bool op1_is_literal = false, uint32_t op2 = 0);
 	const char *get_memory_order(uint32_t spv_mem_sem);
 	void add_pragma_line(const std::string &line);