Sfoglia il codice sorgente

Updated spirv-cross.

Бранимир Караџић 5 anni fa
parent
commit
6b39b61e27

+ 1 - 0
3rdparty/spirv-cross/spirv_common.hpp

@@ -1554,6 +1554,7 @@ struct AccessChainMeta
 	bool need_transpose = false;
 	bool storage_is_packed = false;
 	bool storage_is_invariant = false;
+	bool flattened_struct = false;
 };
 
 enum ExtendedDecorations

+ 179 - 44
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -2124,6 +2124,65 @@ const char *CompilerGLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
 	return "";
 }
 
+void CompilerGLSL::emit_flattened_io_block_member(const std::string &basename, const SPIRType &type, const char *qual,
+                                                  const SmallVector<uint32_t> &indices)
+{
+	uint32_t member_type_id = type.self;
+	const SPIRType *member_type = &type;
+	const SPIRType *parent_type = nullptr;
+	auto flattened_name = basename;
+	for (auto &index : indices)
+	{
+		flattened_name += "_";
+		flattened_name += to_member_name(*member_type, index);
+		parent_type = member_type;
+		member_type_id = member_type->member_types[index];
+		member_type = &get<SPIRType>(member_type_id);
+	}
+
+	assert(member_type->basetype != SPIRType::Struct);
+
+	// Sanitize underscores because joining the two identifiers might create more than 1 underscore in a row,
+	// which is not allowed.
+	flattened_name = sanitize_underscores(flattened_name);
+
+	uint32_t last_index = indices.back();
+
+	// Pass in the varying qualifier here so it will appear in the correct declaration order.
+	// Replace member name while emitting it so it encodes both struct name and member name.
+	auto backup_name = get_member_name(parent_type->self, last_index);
+	auto member_name = to_member_name(*parent_type, last_index);
+	set_member_name(parent_type->self, last_index, flattened_name);
+	emit_struct_member(*parent_type, member_type_id, last_index, qual);
+	// Restore member name.
+	set_member_name(parent_type->self, last_index, member_name);
+}
+
+void CompilerGLSL::emit_flattened_io_block_struct(const std::string &basename, const SPIRType &type, const char *qual,
+                                                  const SmallVector<uint32_t> &indices)
+{
+	auto sub_indices = indices;
+	sub_indices.push_back(0);
+
+	const SPIRType *member_type = &type;
+	for (auto &index : indices)
+		member_type = &get<SPIRType>(member_type->member_types[index]);
+
+	assert(member_type->basetype == SPIRType::Struct);
+
+	if (!member_type->array.empty())
+		SPIRV_CROSS_THROW("Cannot flatten array of structs in I/O blocks.");
+
+	for (uint32_t i = 0; i < uint32_t(member_type->member_types.size()); i++)
+	{
+		sub_indices.back() = i;
+		if (get<SPIRType>(member_type->member_types[i]).basetype == SPIRType::Struct)
+			emit_flattened_io_block_struct(basename, type, qual, sub_indices);
+		else
+			emit_flattened_io_block_member(basename, type, qual, sub_indices);
+	}
+}
+
 void CompilerGLSL::emit_flattened_io_block(const SPIRVariable &var, const char *qual)
 {
 	auto &type = get<SPIRType>(var.basetype);
@@ -2136,32 +2195,28 @@ void CompilerGLSL::emit_flattened_io_block(const SPIRVariable &var, const char *
 
 	type.member_name_cache.clear();
 
+	SmallVector<uint32_t> member_indices;
+	member_indices.push_back(0);
+	auto basename = to_name(var.self);
+
 	uint32_t i = 0;
 	for (auto &member : type.member_types)
 	{
 		add_member_name(type, i);
 		auto &membertype = get<SPIRType>(member);
 
+		member_indices.back() = i;
 		if (membertype.basetype == SPIRType::Struct)
-			SPIRV_CROSS_THROW("Cannot flatten struct inside structs in I/O variables.");
-
-		// Pass in the varying qualifier here so it will appear in the correct declaration order.
-		// Replace member name while emitting it so it encodes both struct name and member name.
-		// Sanitize underscores because joining the two identifiers might create more than 1 underscore in a row,
-		// which is not allowed.
-		auto backup_name = get_member_name(type.self, i);
-		auto member_name = to_member_name(type, i);
-		set_member_name(type.self, i, sanitize_underscores(join(to_name(var.self), "_", member_name)));
-		emit_struct_member(type, member, i, qual);
-		// Restore member name.
-		set_member_name(type.self, i, member_name);
+			emit_flattened_io_block_struct(basename, type, qual, member_indices);
+		else
+			emit_flattened_io_block_member(basename, type, qual, member_indices);
 		i++;
 	}
 
 	ir.meta[type.self].decoration.decoration_flags = old_flags;
 
-	// Treat this variable as flattened from now on.
-	flattened_structs.insert(var.self);
+	// Treat this variable as fully flattened from now on.
+	flattened_structs[var.self] = true;
 }
 
 void CompilerGLSL::emit_interface_block(const SPIRVariable &var)
@@ -3502,6 +3557,10 @@ string CompilerGLSL::to_expression(uint32_t id, bool register_expression_read)
 			return convert_row_major_matrix(e.expression, get<SPIRType>(e.expression_type), physical_type_id,
 			                                is_packed);
 		}
+		else if (flattened_structs.count(id))
+		{
+			return load_flattened_struct(e.expression, get<SPIRType>(e.expression_type));
+		}
 		else
 		{
 			if (is_forcing_recompilation())
@@ -3554,7 +3613,7 @@ string CompilerGLSL::to_expression(uint32_t id, bool register_expression_read)
 		}
 		else if (flattened_structs.count(id))
 		{
-			return load_flattened_struct(var);
+			return load_flattened_struct(to_name(id), get<SPIRType>(var.basetype));
 		}
 		else
 		{
@@ -7367,6 +7426,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 	bool chain_only = (flags & ACCESS_CHAIN_CHAIN_ONLY_BIT) != 0;
 	bool ptr_chain = (flags & ACCESS_CHAIN_PTR_CHAIN_BIT) != 0;
 	bool register_expression_read = (flags & ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT) == 0;
+	bool flatten_member_reference = (flags & ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT) != 0;
 
 	if (!chain_only)
 	{
@@ -7581,6 +7641,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 				string qual_mbr_name = get_member_qualified_name(type_id, index);
 				if (!qual_mbr_name.empty())
 					expr = qual_mbr_name;
+				else if (flatten_member_reference)
+					expr += join("_", to_member_name(*type, index));
 				else
 					expr += to_member_reference(base, *type, index, ptr_chain);
 			}
@@ -7629,6 +7691,23 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 				}
 			}
 
+			// Internally, access chain implementation can also be used on composites,
+			// ignore scalar access workarounds in this case.
+			StorageClass effective_storage;
+			if (expression_type(base).pointer)
+				effective_storage = get_expression_effective_storage_class(base);
+			else
+				effective_storage = StorageClassGeneric;
+
+			if (!row_major_matrix_needs_conversion)
+			{
+				// On some backends, we might not be able to safely access individual scalars in a vector.
+				// To work around this, we might have to cast the access chain reference to something which can,
+				// like a pointer to scalar, which we can then index into.
+				prepare_access_chain_for_scalar_access(expr, get<SPIRType>(type->parent_type), effective_storage,
+				                                       is_packed);
+			}
+
 			if (is_literal && !is_packed && !row_major_matrix_needs_conversion)
 			{
 				expr += ".";
@@ -7660,6 +7739,12 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 				expr += "]";
 			}
 
+			if (row_major_matrix_needs_conversion)
+			{
+				prepare_access_chain_for_scalar_access(expr, get<SPIRType>(type->parent_type), effective_storage,
+				                                       is_packed);
+			}
+
 			expr += deferred_index;
 			row_major_matrix_needs_conversion = false;
 
@@ -7690,10 +7775,13 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 	return expr;
 }
 
-string CompilerGLSL::to_flattened_struct_member(const SPIRVariable &var, uint32_t index)
+void CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
 {
-	auto &type = get<SPIRType>(var.basetype);
-	return sanitize_underscores(join(to_name(var.self), "_", to_member_name(type, index)));
+}
+
+string CompilerGLSL::to_flattened_struct_member(const string &basename, const SPIRType &type, uint32_t index)
+{
+	return sanitize_underscores(join(basename, "_", to_member_name(type, index)));
 }
 
 string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
@@ -7722,13 +7810,22 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
 		if (ptr_chain)
 			flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
 
+		if (flattened_structs[base])
+		{
+			flags |= ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT;
+			if (meta)
+				meta->flattened_struct = target_type.basetype == SPIRType::Struct;
+		}
+
 		auto chain = access_chain_internal(base, indices, count, flags, nullptr).substr(1);
 		if (meta)
 		{
 			meta->need_transpose = false;
 			meta->storage_is_packed = false;
 		}
-		return sanitize_underscores(join(to_name(base), "_", chain));
+
+		auto basename = to_flattened_access_chain_expression(base);
+		return sanitize_underscores(join(basename, "_", chain));
 	}
 	else
 	{
@@ -7739,48 +7836,72 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
 	}
 }
 
-string CompilerGLSL::load_flattened_struct(SPIRVariable &var)
+string CompilerGLSL::load_flattened_struct(const string &basename, const SPIRType &type)
 {
-	auto expr = type_to_glsl_constructor(get<SPIRType>(var.basetype));
+	auto expr = type_to_glsl_constructor(type);
 	expr += '(';
 
-	auto &type = get<SPIRType>(var.basetype);
 	for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
 	{
 		if (i)
 			expr += ", ";
 
-		// Flatten the varyings.
-		// Apply name transformation for flattened I/O blocks.
-		expr += to_flattened_struct_member(var, i);
+		auto &member_type = get<SPIRType>(type.member_types[i]);
+		if (member_type.basetype == SPIRType::Struct)
+			expr += load_flattened_struct(to_flattened_struct_member(basename, type, i), member_type);
+		else
+			expr += to_flattened_struct_member(basename, type, i);
 	}
 	expr += ')';
 	return expr;
 }
 
-void CompilerGLSL::store_flattened_struct(SPIRVariable &var, uint32_t value)
+std::string CompilerGLSL::to_flattened_access_chain_expression(uint32_t id)
 {
-	// We're trying to store a structure which has been flattened.
-	// Need to copy members one by one.
-	auto rhs = to_expression(value);
+	// Do not use to_expression as that will unflatten access chains.
+	string basename;
+	if (const auto *var = maybe_get<SPIRVariable>(id))
+		basename = to_name(var->self);
+	else if (const auto *expr = maybe_get<SPIRExpression>(id))
+		basename = expr->expression;
+	else
+		basename = to_expression(id);
 
-	// Store result locally.
-	// Since we're declaring a variable potentially multiple times here,
-	// store the variable in an isolated scope.
-	begin_scope();
-	statement(variable_decl_function_local(var), " = ", rhs, ";");
+	return basename;
+}
 
-	auto &type = get<SPIRType>(var.basetype);
-	for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
+void CompilerGLSL::store_flattened_struct(const string &basename, uint32_t rhs_id, const SPIRType &type,
+                                          const SmallVector<uint32_t> &indices)
+{
+	SmallVector<uint32_t> sub_indices = indices;
+	sub_indices.push_back(0);
+
+	auto *member_type = &type;
+	for (auto &index : indices)
+		member_type = &get<SPIRType>(member_type->member_types[index]);
+
+	for (uint32_t i = 0; i < uint32_t(member_type->member_types.size()); i++)
 	{
-		// Flatten the varyings.
-		// Apply name transformation for flattened I/O blocks.
+		sub_indices.back() = i;
+		auto lhs = sanitize_underscores(join(basename, "_", to_member_name(*member_type, i)));
 
-		auto lhs = sanitize_underscores(join(to_name(var.self), "_", to_member_name(type, i)));
-		rhs = join(to_name(var.self), ".", to_member_name(type, i));
-		statement(lhs, " = ", rhs, ";");
+		if (get<SPIRType>(member_type->member_types[i]).basetype == SPIRType::Struct)
+		{
+			store_flattened_struct(lhs, rhs_id, type, sub_indices);
+		}
+		else
+		{
+			auto rhs = to_expression(rhs_id) + to_multi_member_reference(type, sub_indices);
+			statement(lhs, " = ", rhs, ";");
+		}
 	}
-	end_scope();
+}
+
+void CompilerGLSL::store_flattened_struct(uint32_t lhs_id, uint32_t value)
+{
+	auto &type = expression_type(lhs_id);
+	auto basename = to_flattened_access_chain_expression(lhs_id);
+	store_flattened_struct(basename, value, type, {});
 }
 
 std::string CompilerGLSL::flattened_access_chain(uint32_t base, const uint32_t *indices, uint32_t count,
@@ -8850,6 +8971,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
 		if (meta.storage_is_invariant)
 			set_decoration(ops[1], DecorationInvariant);
+		if (meta.flattened_struct)
+			flattened_structs[ops[1]] = true;
 
 		// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
 		// temporary which could be subject to invalidation.
@@ -8887,9 +9010,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		{
 			// Skip the write.
 		}
-		else if (var && flattened_structs.count(ops[0]))
+		else if (flattened_structs.count(ops[0]))
 		{
-			store_flattened_struct(*var, ops[1]);
+			store_flattened_struct(ops[0], ops[1]);
 			register_write(ops[0]);
 		}
 		else
@@ -11279,6 +11402,18 @@ string CompilerGLSL::to_member_reference(uint32_t, const SPIRType &type, uint32_
 	return join(".", to_member_name(type, index));
 }
 
+string CompilerGLSL::to_multi_member_reference(const SPIRType &type, const SmallVector<uint32_t> &indices)
+{
+	string ret;
+	auto *member_type = &type;
+	for (auto &index : indices)
+	{
+		ret += join(".", to_member_name(*member_type, index));
+		member_type = &get<SPIRType>(member_type->member_types[index]);
+	}
+	return ret;
+}
+
 void CompilerGLSL::add_member_name(SPIRType &type, uint32_t index)
 {
 	auto &memb = ir.meta[type.self].members;

+ 18 - 6
3rdparty/spirv-cross/spirv_glsl.hpp

@@ -57,7 +57,8 @@ enum AccessChainFlagBits
 	ACCESS_CHAIN_CHAIN_ONLY_BIT = 1 << 1,
 	ACCESS_CHAIN_PTR_CHAIN_BIT = 1 << 2,
 	ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
-	ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4
+	ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
+	ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5
 };
 typedef uint32_t AccessChainFlags;
 
@@ -488,6 +489,10 @@ protected:
 	void emit_push_constant_block_glsl(const SPIRVariable &var);
 	void emit_interface_block(const SPIRVariable &type);
 	void emit_flattened_io_block(const SPIRVariable &var, const char *qual);
+	void emit_flattened_io_block_struct(const std::string &basename, const SPIRType &type, const char *qual,
+	                                    const SmallVector<uint32_t> &indices);
+	void emit_flattened_io_block_member(const std::string &basename, const SPIRType &type, const char *qual,
+	                                    const SmallVector<uint32_t> &indices);
 	void emit_block_chain(SPIRBlock &block);
 	void emit_hoisted_temporaries(SmallVector<std::pair<TypeID, ID>> &temporaries);
 	std::string constant_value_macro_name(uint32_t id);
@@ -560,6 +565,9 @@ protected:
 	std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
 	                                  AccessChainMeta *meta);
 
+	virtual void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
+	                                                    spv::StorageClass storage, bool &is_packed);
+
 	std::string access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
 	                         AccessChainMeta *meta = nullptr, bool ptr_chain = false);
 
@@ -604,6 +612,7 @@ protected:
 	void strip_enclosed_expression(std::string &expr);
 	std::string to_member_name(const SPIRType &type, uint32_t index);
 	virtual std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain);
+	std::string to_multi_member_reference(const SPIRType &type, const SmallVector<uint32_t> &indices);
 	std::string type_to_glsl_constructor(const SPIRType &type);
 	std::string argument_decl(const SPIRFunction::Parameter &arg);
 	virtual std::string to_qualifiers_glsl(uint32_t id);
@@ -664,11 +673,14 @@ protected:
 	std::unordered_set<uint32_t> flushed_phi_variables;
 
 	std::unordered_set<uint32_t> flattened_buffer_blocks;
-	std::unordered_set<uint32_t> flattened_structs;
-
-	std::string load_flattened_struct(SPIRVariable &var);
-	std::string to_flattened_struct_member(const SPIRVariable &var, uint32_t index);
-	void store_flattened_struct(SPIRVariable &var, uint32_t value);
+	std::unordered_map<uint32_t, bool> flattened_structs;
+
+	std::string load_flattened_struct(const std::string &basename, const SPIRType &type);
+	std::string to_flattened_struct_member(const std::string &basename, const SPIRType &type, uint32_t index);
+	void store_flattened_struct(uint32_t lhs_id, uint32_t value);
+	void store_flattened_struct(const std::string &basename, uint32_t rhs, const SPIRType &type,
+	                            const SmallVector<uint32_t> &indices);
+	std::string to_flattened_access_chain_expression(uint32_t id);
 
 	// Usage tracking. If a temporary is used more than once, use the temporary instead to
 	// avoid AST explosion when SPIRV is generated with pure SSA and doesn't write stuff to variables.

+ 2 - 2
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -2005,7 +2005,7 @@ void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
 		{
 			// Flatten the top-level struct so we can use packoffset,
 			// this restriction is similar to GLSL where layout(offset) is not possible on sub-structs.
-			flattened_structs.insert(var.self);
+			flattened_structs[var.self] = false;
 
 			// Prefer the block name if possible.
 			auto buffer_name = to_name(type.self, false);
@@ -2110,7 +2110,7 @@ void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
 				                       ") cannot be expressed with either HLSL packing layout or packoffset."));
 			}
 
-			flattened_structs.insert(var.self);
+			flattened_structs[var.self] = false;
 			type.member_name_cache.clear();
 			add_resource_name(var.self);
 			auto &memb = ir.meta[type.self].members;

+ 50 - 18
3rdparty/spirv-cross/spirv_msl.cpp

@@ -1278,6 +1278,11 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				uint32_t base_id = ops[0];
 				if (global_var_ids.find(base_id) != global_var_ids.end())
 					added_arg_ids.insert(base_id);
+				
+                		uint32_t rvalue_id = ops[1];
+                		if (global_var_ids.find(rvalue_id) != global_var_ids.end())
+                    			added_arg_ids.insert(rvalue_id);
+				
 				break;
 			}
 
@@ -3335,15 +3340,10 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 
 		auto &physical_type = get<SPIRType>(physical_type_id);
 
-		static const char *swizzle_lut[] = {
-			".x",
-			".xy",
-			".xyz",
-			"",
-		};
-
 		if (is_matrix(type))
 		{
+			const char *packed_pfx = lhs_packed_type ? "packed_" : "";
+
 			// Packed matrices are stored as arrays of packed vectors, so we need
 			// to assign the vectors one at a time.
 			// For row-major matrices, we need to transpose the *right-hand* side,
@@ -3352,6 +3352,8 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 			// Lots of cases to cover here ...
 
 			bool rhs_transpose = rhs_e && rhs_e->need_transpose;
+			SPIRType write_type = type;
+			string cast_expr;
 
 			// We're dealing with transpose manually.
 			if (rhs_transpose)
@@ -3361,17 +3363,18 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 			{
 				// We're dealing with transpose manually.
 				lhs_e->need_transpose = false;
+				write_type.vecsize = type.columns;
+				write_type.columns = 1;
 
-				const char *store_swiz = "";
 				if (physical_type.columns != type.columns)
-					store_swiz = swizzle_lut[type.columns - 1];
+					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
 
 				if (rhs_transpose)
 				{
 					// If RHS is also transposed, we can just copy row by row.
 					for (uint32_t i = 0; i < type.vecsize; i++)
 					{
-						statement(to_enclosed_expression(lhs_expression), "[", i, "]", store_swiz, " = ",
+						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
 						          to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
 					}
 				}
@@ -3394,7 +3397,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 						}
 						rhs_row += ")";
 
-						statement(to_enclosed_expression(lhs_expression), "[", i, "]", store_swiz, " = ", rhs_row, ";");
+						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
 					}
 				}
 
@@ -3403,9 +3406,10 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 			}
 			else
 			{
-				const char *store_swiz = "";
+				write_type.columns = 1;
+
 				if (physical_type.vecsize != type.vecsize)
-					store_swiz = swizzle_lut[type.vecsize - 1];
+					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
 
 				if (rhs_transpose)
 				{
@@ -3427,7 +3431,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 						}
 						rhs_row += ")";
 
-						statement(to_enclosed_expression(lhs_expression), "[", i, "]", store_swiz, " = ", rhs_row, ";");
+						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
 					}
 				}
 				else
@@ -3435,7 +3439,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 					// Copy column-by-column.
 					for (uint32_t i = 0; i < type.columns; i++)
 					{
-						statement(to_enclosed_expression(lhs_expression), "[", i, "]", store_swiz, " = ",
+						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
 						          to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
 					}
 				}
@@ -3449,6 +3453,10 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 		{
 			lhs_e->need_transpose = false;
 
+			SPIRType write_type = type;
+			write_type.vecsize = 1;
+			write_type.columns = 1;
+
 			// Storing a column to a row-major matrix. Unroll the write.
 			for (uint32_t c = 0; c < type.vecsize; c++)
 			{
@@ -3456,7 +3464,8 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 				auto column_index = lhs_expr.find_last_of('[');
 				if (column_index != string::npos)
 				{
-					statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
+					statement("((device ", type_to_glsl(write_type), "*)&",
+					          lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
 					          to_extract_component_expression(rhs_expression, c), ";");
 				}
 			}
@@ -3478,7 +3487,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
 
 			// Unpack the expression so we can store to it with a float or float2.
 			// It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
-			lhs = enclose_expression(lhs) + swizzle_lut[type.vecsize - 1];
+			lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
 			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
 				statement(lhs, " = ", rhs, ";");
 		}
@@ -5859,6 +5868,23 @@ bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
 	       (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
 }
 
+void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
+                                                         spv::StorageClass storage, bool &is_packed)
+{
+	// If there is any risk of writes happening with the access chain in question,
+	// and there is a risk of concurrent write access to other components,
+	// we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
+	// The MSL compiler refuses to allow component-level access for any non-packed vector types.
+	if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
+	{
+		const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
+		expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
+
+		// Further indexing should happen with packed rules (array index, not swizzle).
+		is_packed = true;
+	}
+}
+
 // Override for MSL-specific syntax instructions
 void CompilerMSL::emit_instruction(const Instruction &instruction)
 {
@@ -10116,7 +10142,13 @@ uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::Base
 	// If a binding has not been specified, revert to incrementing resource indices.
 	uint32_t resource_index;
 
-	if (allocate_argument_buffer_ids)
+	if (type_is_msl_framebuffer_fetch(type))
+	{
+		// Frame-buffer fetch gets its fallback resource index from the input attachment index,
+		// which is then treated as color index.
+		resource_index = get_decoration(var.self, DecorationInputAttachmentIndex);
+	}
+	else if (allocate_argument_buffer_ids)
 	{
 		// Allocate from a flat ID binding space.
 		resource_index = next_metal_resource_ids[var_desc_set];

+ 2 - 0
3rdparty/spirv-cross/spirv_msl.hpp

@@ -793,6 +793,8 @@ protected:
 
 	void analyze_sampled_image_usage();
 
+	void prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
+	                                            bool &is_packed) override;
 	bool emit_tessellation_access_chain(const uint32_t *ops, uint32_t length);
 	bool emit_tessellation_io_load(uint32_t result_type, uint32_t id, uint32_t ptr);
 	bool is_out_of_bounds_tessellation_level(uint32_t id_lhs);