Browse Source

Updated spirv-cross.

Бранимир Караџић 3 years ago
parent
commit
041c4c75ff

+ 14 - 0
3rdparty/spirv-cross/main.cpp

@@ -672,6 +672,8 @@ struct CLIArguments
 	bool msl_emulate_subgroups = false;
 	uint32_t msl_fixed_subgroup_size = 0;
 	bool msl_force_sample_rate_shading = false;
+	bool msl_manual_helper_invocation_updates = true;
+	bool msl_check_discarded_frag_stores = false;
 	const char *msl_combined_sampler_suffix = nullptr;
 	bool glsl_emit_push_constant_as_ubo = false;
 	bool glsl_emit_ubo_as_plain_uniforms = false;
@@ -934,6 +936,13 @@ static void print_help_msl()
 	                "\t\tIf 0, assume variable subgroup size as actually exposed by Metal.\n"
 	                "\t[--msl-force-sample-rate-shading]:\n\t\tForce fragment shaders to run per sample.\n"
 	                "\t\tThis adds a [[sample_id]] parameter if none is already present.\n"
+	                "\t[--msl-no-manual-helper-invocation-updates]:\n\t\tDo not manually update the HelperInvocation builtin when a fragment is discarded.\n"
+	                "\t\tSome Metal devices have a bug where simd_is_helper_thread() does not return true\n"
+	                "\t\tafter the fragment is discarded. This behavior is required by Vulkan and SPIR-V, however.\n"
+	                "\t[--msl-check-discarded-frag-stores]:\n\t\tAdd additional checks to resource stores in a fragment shader.\n"
+	                "\t\tSome Metal devices have a bug where stores to resources from a fragment shader\n"
+	                "\t\tcontinue to execute, even when the fragment is discarded. These checks\n"
+	                "\t\tprevent these stores from executing.\n"
 	                "\t[--msl-combined-sampler-suffix <suffix>]:\n\t\tUses a custom suffix for combined samplers.\n");
 	// clang-format on
 }
@@ -1205,6 +1214,8 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
 		msl_opts.emulate_subgroups = args.msl_emulate_subgroups;
 		msl_opts.fixed_subgroup_size = args.msl_fixed_subgroup_size;
 		msl_opts.force_sample_rate_shading = args.msl_force_sample_rate_shading;
+		msl_opts.manual_helper_invocation_updates = args.msl_manual_helper_invocation_updates;
+		msl_opts.check_discarded_frag_stores = args.msl_check_discarded_frag_stores;
 		msl_opts.ios_support_base_vertex_instance = true;
 		msl_comp->set_msl_options(msl_opts);
 		for (auto &v : args.msl_discrete_descriptor_sets)
@@ -1751,6 +1762,9 @@ static int main_inner(int argc, char *argv[])
 	cbs.add("--msl-fixed-subgroup-size",
 	        [&args](CLIParser &parser) { args.msl_fixed_subgroup_size = parser.next_uint(); });
 	cbs.add("--msl-force-sample-rate-shading", [&args](CLIParser &) { args.msl_force_sample_rate_shading = true; });
+	cbs.add("--msl-no-manual-helper-invocation-updates",
+	        [&args](CLIParser &) { args.msl_manual_helper_invocation_updates = false; });
+	cbs.add("--msl-check-discarded-frag-stores", [&args](CLIParser &) { args.msl_check_discarded_frag_stores = true; });
 	cbs.add("--msl-combined-sampler-suffix", [&args](CLIParser &parser) {
 		args.msl_combined_sampler_suffix = parser.next_string();
 	});

+ 29 - 1
3rdparty/spirv-cross/spirv_common.hpp

@@ -643,7 +643,8 @@ struct SPIRExtension : IVariant
 		SPV_AMD_shader_explicit_vertex_parameter,
 		SPV_AMD_shader_trinary_minmax,
 		SPV_AMD_gcn_shader,
-		NonSemanticDebugPrintf
+		NonSemanticDebugPrintf,
+		NonSemanticShaderDebugInfo
 	};
 
 	explicit SPIRExtension(Extension ext_)
@@ -1796,6 +1797,33 @@ static inline bool opcode_is_sign_invariant(spv::Op opcode)
 	}
 }
 
+static inline bool opcode_can_promote_integer_implicitly(spv::Op opcode)
+{
+	switch (opcode)
+	{
+	case spv::OpSNegate:
+	case spv::OpNot:
+	case spv::OpBitwiseAnd:
+	case spv::OpBitwiseOr:
+	case spv::OpBitwiseXor:
+	case spv::OpShiftLeftLogical:
+	case spv::OpShiftRightLogical:
+	case spv::OpShiftRightArithmetic:
+	case spv::OpIAdd:
+	case spv::OpISub:
+	case spv::OpIMul:
+	case spv::OpSDiv:
+	case spv::OpUDiv:
+	case spv::OpSRem:
+	case spv::OpUMod:
+	case spv::OpSMod:
+		return true;
+
+	default:
+		return false;
+	}
+}
+
 struct SetBindingPair
 {
 	uint32_t desc_set;

+ 0 - 2
3rdparty/spirv-cross/spirv_cpp.cpp

@@ -274,8 +274,6 @@ void CompilerCPP::emit_resources()
 	if (emitted)
 		statement("");
 
-	declare_undefined_values();
-
 	statement("inline void init(spirv_cross_shader& s)");
 	begin_scope();
 	statement(resource_type, "::init(s);");

+ 1 - 1
3rdparty/spirv-cross/spirv_cross.cpp

@@ -725,7 +725,7 @@ bool Compiler::InterfaceVariableAccessHandler::handle(Op opcode, const uint32_t
 
 	case OpExtInst:
 	{
-		if (length < 5)
+		if (length < 3)
 			return false;
 		auto &extension_set = compiler.get<SPIRExtension>(args[2]);
 		switch (extension_set.ext)

+ 8 - 0
3rdparty/spirv-cross/spirv_cross_c.cpp

@@ -723,6 +723,14 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c
 	case SPVC_COMPILER_OPTION_MSL_SHADER_PATCH_INPUT_BUFFER_INDEX:
 		options->msl.shader_patch_input_buffer_index = value;
 		break;
+
+	case SPVC_COMPILER_OPTION_MSL_MANUAL_HELPER_INVOCATION_UPDATES:
+		options->msl.manual_helper_invocation_updates = value != 0;
+		break;
+
+	case SPVC_COMPILER_OPTION_MSL_CHECK_DISCARDED_FRAG_STORES:
+		options->msl.check_discarded_frag_stores = value != 0;
+		break;
 #endif
 
 	default:

+ 3 - 1
3rdparty/spirv-cross/spirv_cross_c.h

@@ -40,7 +40,7 @@ extern "C" {
 /* Bumped if ABI or API breaks backwards compatibility. */
 #define SPVC_C_API_VERSION_MAJOR 0
 /* Bumped if APIs or enumerations are added in a backwards compatible way. */
-#define SPVC_C_API_VERSION_MINOR 51
+#define SPVC_C_API_VERSION_MINOR 52
 /* Bumped if internal implementation details change. */
 #define SPVC_C_API_VERSION_PATCH 0
 
@@ -718,6 +718,8 @@ typedef enum spvc_compiler_option
 
 	SPVC_COMPILER_OPTION_MSL_RAW_BUFFER_TESE_INPUT = 79 | SPVC_COMPILER_OPTION_MSL_BIT,
 	SPVC_COMPILER_OPTION_MSL_SHADER_PATCH_INPUT_BUFFER_INDEX = 80 | SPVC_COMPILER_OPTION_MSL_BIT,
+	SPVC_COMPILER_OPTION_MSL_MANUAL_HELPER_INVOCATION_UPDATES = 81 | SPVC_COMPILER_OPTION_MSL_BIT,
+	SPVC_COMPILER_OPTION_MSL_CHECK_DISCARDED_FRAG_STORES = 82 | SPVC_COMPILER_OPTION_MSL_BIT,
 
 	SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff
 } spvc_compiler_option;

+ 9 - 2
3rdparty/spirv-cross/spirv_cross_containers.hpp

@@ -210,7 +210,8 @@ public:
 		buffer_capacity = N;
 	}
 
-	SmallVector(const T *arg_list_begin, const T *arg_list_end) SPIRV_CROSS_NOEXCEPT : SmallVector()
+	template <typename U>
+	SmallVector(const U *arg_list_begin, const U *arg_list_end) SPIRV_CROSS_NOEXCEPT : SmallVector()
 	{
 		auto count = size_t(arg_list_end - arg_list_begin);
 		reserve(count);
@@ -219,7 +220,13 @@ public:
 		this->buffer_size = count;
 	}
 
-	SmallVector(std::initializer_list<T> init) SPIRV_CROSS_NOEXCEPT : SmallVector(init.begin(), init.end())
+	template <typename U>
+	SmallVector(std::initializer_list<U> init) SPIRV_CROSS_NOEXCEPT : SmallVector(init.begin(), init.end())
+	{
+	}
+
+	template <typename U, size_t M>
+	SmallVector(const U (&init)[M]) SPIRV_CROSS_NOEXCEPT : SmallVector(init, init + M)
 	{
 	}
 

+ 5 - 4
3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp

@@ -66,7 +66,7 @@ ParsedIR &ParsedIR::operator=(ParsedIR &&other) SPIRV_CROSS_NOEXCEPT
 		meta = std::move(other.meta);
 		for (int i = 0; i < TypeCount; i++)
 			ids_for_type[i] = std::move(other.ids_for_type[i]);
-		ids_for_constant_or_type = std::move(other.ids_for_constant_or_type);
+		ids_for_constant_undef_or_type = std::move(other.ids_for_constant_undef_or_type);
 		ids_for_constant_or_variable = std::move(other.ids_for_constant_or_variable);
 		declared_capabilities = std::move(other.declared_capabilities);
 		declared_extensions = std::move(other.declared_extensions);
@@ -102,7 +102,7 @@ ParsedIR &ParsedIR::operator=(const ParsedIR &other)
 		meta = other.meta;
 		for (int i = 0; i < TypeCount; i++)
 			ids_for_type[i] = other.ids_for_type[i];
-		ids_for_constant_or_type = other.ids_for_constant_or_type;
+		ids_for_constant_undef_or_type = other.ids_for_constant_undef_or_type;
 		ids_for_constant_or_variable = other.ids_for_constant_or_variable;
 		declared_capabilities = other.declared_capabilities;
 		declared_extensions = other.declared_extensions;
@@ -934,7 +934,7 @@ void ParsedIR::add_typed_id(Types type, ID id)
 		{
 		case TypeConstant:
 			ids_for_constant_or_variable.push_back(id);
-			ids_for_constant_or_type.push_back(id);
+			ids_for_constant_undef_or_type.push_back(id);
 			break;
 
 		case TypeVariable:
@@ -943,7 +943,8 @@ void ParsedIR::add_typed_id(Types type, ID id)
 
 		case TypeType:
 		case TypeConstantOp:
-			ids_for_constant_or_type.push_back(id);
+		case TypeUndef:
+			ids_for_constant_undef_or_type.push_back(id);
 			break;
 
 		default:

+ 2 - 2
3rdparty/spirv-cross/spirv_cross_parsed_ir.hpp

@@ -74,8 +74,8 @@ public:
 	// Special purpose lists which contain a union of types.
 	// This is needed so we can declare specialization constants and structs in an interleaved fashion,
 	// among other things.
-	// Constants can be of struct type, and struct array sizes can use specialization constants.
-	SmallVector<ID> ids_for_constant_or_type;
+	// Constants can be undef or of struct type, and struct array sizes can use specialization constants.
+	SmallVector<ID> ids_for_constant_undef_or_type;
 	SmallVector<ID> ids_for_constant_or_variable;
 
 	// We need to keep track of the width the Ops that contains a type for the

+ 279 - 52
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -3146,9 +3146,30 @@ void CompilerGLSL::fixup_implicit_builtin_block_names(ExecutionModel model)
 			{
 				auto flags = get_buffer_block_flags(var.self);
 				if (flags.get(DecorationPerPrimitiveEXT))
+				{
 					set_name(var.self, "gl_MeshPrimitivesEXT");
+					set_name(type.self, "gl_MeshPerPrimitiveEXT");
+				}
 				else
+				{
 					set_name(var.self, "gl_MeshVerticesEXT");
+					set_name(type.self, "gl_MeshPerVertexEXT");
+				}
+			}
+		}
+
+		if (model == ExecutionModelMeshEXT && var.storage == StorageClassOutput && !block)
+		{
+			auto *m = ir.find_meta(var.self);
+			if (m && m->decoration.builtin)
+			{
+				auto builtin_type = m->decoration.builtin_type;
+				if (builtin_type == BuiltInPrimitivePointIndicesEXT)
+					set_name(var.self, "gl_PrimitivePointIndicesEXT");
+				else if (builtin_type == BuiltInPrimitiveLineIndicesEXT)
+					set_name(var.self, "gl_PrimitiveLineIndicesEXT");
+				else if (builtin_type == BuiltInPrimitiveTriangleIndicesEXT)
+					set_name(var.self, "gl_PrimitiveTriangleIndicesEXT");
 			}
 		}
 	});
@@ -3395,27 +3416,6 @@ void CompilerGLSL::emit_declared_builtin_block(StorageClass storage, ExecutionMo
 	statement("");
 }
 
-void CompilerGLSL::declare_undefined_values()
-{
-	bool emitted = false;
-	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
-		auto &type = this->get<SPIRType>(undef.basetype);
-		// OpUndef can be void for some reason ...
-		if (type.basetype == SPIRType::Void)
-			return;
-
-		string initializer;
-		if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
-			initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
-
-		statement(variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
-		emitted = true;
-	});
-
-	if (emitted)
-		statement("");
-}
-
 bool CompilerGLSL::variable_is_lut(const SPIRVariable &var) const
 {
 	bool statically_assigned = var.statically_assigned && var.static_expression != ID(0) && var.remapped_variable;
@@ -3516,7 +3516,7 @@ void CompilerGLSL::emit_resources()
 	//
 	{
 		auto loop_lock = ir.create_loop_hard_lock();
-		for (auto &id_ : ir.ids_for_constant_or_type)
+		for (auto &id_ : ir.ids_for_constant_undef_or_type)
 		{
 			auto &id = ir.ids[id_];
 
@@ -3569,6 +3569,22 @@ void CompilerGLSL::emit_resources()
 					emit_struct(*type);
 				}
 			}
+			else if (id.get_type() == TypeUndef)
+			{
+				auto &undef = id.get<SPIRUndef>();
+				auto &type = this->get<SPIRType>(undef.basetype);
+				// OpUndef can be void for some reason ...
+				if (type.basetype == SPIRType::Void)
+					return;
+
+				string initializer;
+				if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
+					initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
+
+				// FIXME: If used in a constant, we must declare it as one.
+				statement(variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
+				emitted = true;
+			}
 		}
 	}
 
@@ -3785,8 +3801,6 @@ void CompilerGLSL::emit_resources()
 
 	if (emitted)
 		statement("");
-
-	declare_undefined_values();
 }
 
 void CompilerGLSL::emit_output_variable_initializer(const SPIRVariable &var)
@@ -4859,6 +4873,9 @@ string CompilerGLSL::to_expression(uint32_t id, bool register_expression_read)
 				}
 			}
 
+			if (expression_is_forwarded(id))
+				return constant_expression(c);
+
 			return to_name(id);
 		}
 		else if (c.is_used_as_lut)
@@ -4930,6 +4947,80 @@ string CompilerGLSL::to_expression(uint32_t id, bool register_expression_read)
 	}
 }
 
+SmallVector<ConstantID> CompilerGLSL::get_composite_constant_ids(ConstantID const_id)
+{
+	if (auto *constant = maybe_get<SPIRConstant>(const_id))
+	{
+		const auto &type = get<SPIRType>(constant->constant_type);
+		if (is_array(type) || type.basetype == SPIRType::Struct)
+			return constant->subconstants;
+		if (is_matrix(type))
+			return constant->m.id;
+		if (is_vector(type))
+			return constant->m.c[0].id;
+		SPIRV_CROSS_THROW("Unexpected scalar constant!");
+	}
+	if (!const_composite_insert_ids.count(const_id))
+		SPIRV_CROSS_THROW("Unimplemented for this OpSpecConstantOp!");
+	return const_composite_insert_ids[const_id];
+}
+
+void CompilerGLSL::fill_composite_constant(SPIRConstant &constant, TypeID type_id,
+                                           const SmallVector<ConstantID> &initializers)
+{
+	auto &type = get<SPIRType>(type_id);
+	constant.specialization = true;
+	if (is_array(type) || type.basetype == SPIRType::Struct)
+	{
+		constant.subconstants = initializers;
+	}
+	else if (is_matrix(type))
+	{
+		constant.m.columns = type.columns;
+		for (uint32_t i = 0; i < type.columns; ++i)
+		{
+			constant.m.id[i] = initializers[i];
+			constant.m.c[i].vecsize = type.vecsize;
+		}
+	}
+	else if (is_vector(type))
+	{
+		constant.m.c[0].vecsize = type.vecsize;
+		for (uint32_t i = 0; i < type.vecsize; ++i)
+			constant.m.c[0].id[i] = initializers[i];
+	}
+	else
+		SPIRV_CROSS_THROW("Unexpected scalar in SpecConstantOp CompositeInsert!");
+}
+
+void CompilerGLSL::set_composite_constant(ConstantID const_id, TypeID type_id,
+                                          const SmallVector<ConstantID> &initializers)
+{
+	if (maybe_get<SPIRConstantOp>(const_id))
+	{
+		const_composite_insert_ids[const_id] = initializers;
+		return;
+	}
+
+	auto &constant = set<SPIRConstant>(const_id, type_id);
+	fill_composite_constant(constant, type_id, initializers);
+	forwarded_temporaries.insert(const_id);
+}
+
+TypeID CompilerGLSL::get_composite_member_type(TypeID type_id, uint32_t member_idx)
+{
+	auto &type = get<SPIRType>(type_id);
+	if (is_array(type))
+		return type.parent_type;
+	if (type.basetype == SPIRType::Struct)
+		return type.member_types[member_idx];
+	if (is_matrix(type))
+		return type.parent_type;
+	if (is_vector(type))
+		return type.parent_type;
+	SPIRV_CROSS_THROW("Shouldn't reach lower than vector handling OpSpecConstantOp CompositeInsert!");
+}
+
 string CompilerGLSL::constant_op_expression(const SPIRConstantOp &cop)
 {
 	auto &type = get<SPIRType>(cop.basetype);
@@ -5034,10 +5125,21 @@ string CompilerGLSL::constant_op_expression(const SPIRConstantOp &cop)
 		for (uint32_t i = 2; i < uint32_t(cop.arguments.size()); i++)
 		{
 			uint32_t index = cop.arguments[i];
-			if (index >= left_components)
+			if (index == 0xFFFFFFFF)
+			{
+				SPIRConstant c;
+				c.constant_type = type.parent_type;
+				assert(type.parent_type != ID(0));
+				expr += constant_expression(c);
+			}
+			else if (index >= left_components)
+			{
 				expr += right_arg + "." + "xyzw"[index - left_components];
+			}
 			else
+			{
 				expr += left_arg + "." + "xyzw"[index];
+			}
 
 			if (i + 1 < uint32_t(cop.arguments.size()))
 				expr += ", ";
@@ -5055,7 +5157,30 @@ string CompilerGLSL::constant_op_expression(const SPIRConstantOp &cop)
 	}
 
 	case OpCompositeInsert:
-		SPIRV_CROSS_THROW("OpCompositeInsert spec constant op is not supported.");
+	{
+		SmallVector<ConstantID> new_init = get_composite_constant_ids(cop.arguments[1]);
+		uint32_t idx;
+		uint32_t target_id = cop.self;
+		uint32_t target_type_id = cop.basetype;
+		// We have to drill down to the part we want to modify, and create new
+		// constants for each containing part.
+		for (idx = 2; idx < cop.arguments.size() - 1; ++idx)
+		{
+			uint32_t new_const = ir.increase_bound_by(1);
+			uint32_t old_const = new_init[cop.arguments[idx]];
+			new_init[cop.arguments[idx]] = new_const;
+			set_composite_constant(target_id, target_type_id, new_init);
+			new_init = get_composite_constant_ids(old_const);
+			target_id = new_const;
+			target_type_id = get_composite_member_type(target_type_id, cop.arguments[idx]);
+		}
+		// Now replace the initializer with the one from this instruction.
+		new_init[cop.arguments[idx]] = cop.arguments[0];
+		set_composite_constant(target_id, target_type_id, new_init);
+		SPIRConstant tmp_const(cop.basetype);
+		fill_composite_constant(tmp_const, cop.basetype, const_composite_insert_ids[cop.self]);
+		return constant_expression(tmp_const);
+	}
 
 	default:
 		// Some opcodes are unimplemented here, these are currently not possible to test from glslang.
@@ -5206,20 +5331,31 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c, bool inside_bloc
 		uint32_t subconstant_index = 0;
 		for (auto &elem : c.subconstants)
 		{
-			auto &subc = get<SPIRConstant>(elem);
-			if (subc.specialization)
+			if (auto *op = maybe_get<SPIRConstantOp>(elem))
+			{
+				res += constant_op_expression(*op);
+			}
+			else if (maybe_get<SPIRUndef>(elem) != nullptr)
+			{
 				res += to_name(elem);
+			}
 			else
 			{
-				if (type.array.empty() && type.basetype == SPIRType::Struct)
+				auto &subc = get<SPIRConstant>(elem);
+				if (subc.specialization && !expression_is_forwarded(elem))
+					res += to_name(elem);
+				else
 				{
-					// When we get down to emitting struct members, override the block-like information.
-					// For constants, we can freely mix and match block-like state.
-					inside_block_like_struct_scope =
-							has_member_decoration(type.self, subconstant_index, DecorationOffset);
-				}
+					if (type.array.empty() && type.basetype == SPIRType::Struct)
+					{
+						// When we get down to emitting struct members, override the block-like information.
+						// For constants, we can freely mix and match block-like state.
+						inside_block_like_struct_scope =
+						    has_member_decoration(type.self, subconstant_index, DecorationOffset);
+					}
 
-				res += constant_expression(subc, inside_block_like_struct_scope);
+					res += constant_expression(subc, inside_block_like_struct_scope);
+				}
 			}
 
 			if (&elem != &c.subconstants.back())
@@ -5984,6 +6120,14 @@ void CompilerGLSL::emit_unary_op(uint32_t result_type, uint32_t result_id, uint3
 	inherit_expression_dependencies(result_id, op0);
 }
 
+void CompilerGLSL::emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op)
+{
+	auto &type = get<SPIRType>(result_type);
+	bool forward = should_forward(op0);
+	emit_op(result_type, result_id, join(type_to_glsl(type), "(", op, to_enclosed_unpacked_expression(op0), ")"), forward);
+	inherit_expression_dependencies(result_id, op0);
+}
+
 void CompilerGLSL::emit_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op)
 {
 	// Various FP arithmetic opcodes such as add, sub, mul will hit this.
@@ -6127,7 +6271,9 @@ bool CompilerGLSL::emit_complex_bitcast(uint32_t result_type, uint32_t id, uint3
 }
 
 void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
-                                       const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type)
+                                       const char *op, SPIRType::BaseType input_type,
+                                       bool skip_cast_if_equal_type,
+                                       bool implicit_integer_promotion)
 {
 	string cast_op0, cast_op1;
 	auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type);
@@ -6136,17 +6282,23 @@ void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id,
 	// We might have casted away from the result type, so bitcast again.
 	// For example, arithmetic right shift with uint inputs.
 	// Special case boolean outputs since relational opcodes output booleans instead of int/uint.
+	auto bitop = join(cast_op0, " ", op, " ", cast_op1);
 	string expr;
-	if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean)
+
+	if (implicit_integer_promotion)
+	{
+		// Simple value cast.
+		expr = join(type_to_glsl(out_type), '(', bitop, ')');
+	}
+	else if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean)
 	{
 		expected_type.basetype = input_type;
-		expr = bitcast_glsl_op(out_type, expected_type);
-		expr += '(';
-		expr += join(cast_op0, " ", op, " ", cast_op1);
-		expr += ')';
+		expr = join(bitcast_glsl_op(out_type, expected_type), '(', bitop, ')');
 	}
 	else
-		expr += join(cast_op0, " ", op, " ", cast_op1);
+	{
+		expr = std::move(bitop);
+	}
 
 	emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1));
 	inherit_expression_dependencies(result_id, op0);
@@ -9189,6 +9341,14 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 					break;
 				}
 			}
+			else if (backend.force_merged_mesh_block && i == 0 && var &&
+			         !is_builtin_variable(*var) && var->storage == StorageClassOutput)
+			{
+				if (is_per_primitive_variable(*var))
+					expr = join("gl_MeshPrimitivesEXT[", to_expression(index, register_expression_read), "].", expr);
+				else
+					expr = join("gl_MeshVerticesEXT[", to_expression(index, register_expression_read), "].", expr);
+			}
 			else if (options.flatten_multidimensional_arrays && dimension_flatten)
 			{
 				// If we are flattening multidimensional arrays, do manual stride computation.
@@ -9238,7 +9398,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 			if (index >= type->member_types.size())
 				SPIRV_CROSS_THROW("Member index is out of bounds!");
 
-			BuiltIn builtin;
+			BuiltIn builtin = BuiltInMax;
 			if (is_member_builtin(*type, index, &builtin) && access_chain_needs_stage_io_builtin_translation(base))
 			{
 				if (access_chain_is_arrayed)
@@ -9258,7 +9418,13 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 				else if (flatten_member_reference)
 					expr += join("_", to_member_name(*type, index));
 				else
-					expr += to_member_reference(base, *type, index, ptr_chain);
+				{
+					// Any pointer de-refences for values are handled in the first access chain.
+					// For pointer chains, the pointer-ness is resolved through an array access.
+					// The only time this is not true is when accessing array of SSBO/UBO.
+					// This case is explicitly handled.
+					expr += to_member_reference(base, *type, index, ptr_chain || i != 0);
+				}
 			}
 
 			if (has_member_decoration(type->self, index, DecorationInvariant))
@@ -9901,9 +10067,32 @@ bool CompilerGLSL::should_dereference(uint32_t id)
 	if (auto *var = maybe_get<SPIRVariable>(id))
 		return var->phi_variable;
 
-	// If id is an access chain, we should not dereference it.
 	if (auto *expr = maybe_get<SPIRExpression>(id))
-		return !expr->access_chain;
+	{
+		// If id is an access chain, we should not dereference it.
+		if (expr->access_chain)
+			return false;
+
+		// If id is a forwarded copy of a variable pointer, we should not dereference it.
+		SPIRVariable *var = nullptr;
+		while (expr->loaded_from && expression_is_forwarded(expr->self))
+		{
+			auto &src_type = expression_type(expr->loaded_from);
+			// To be a copy, the pointer and its source expression must be the
+			// same type. Can't check type.self, because for some reason that's
+			// usually the base type with pointers stripped off. This check is
+			// complex enough that I've hoisted it out of the while condition.
+			if (src_type.pointer != type.pointer || src_type.pointer_depth != type.pointer ||
+			    src_type.parent_type != type.parent_type)
+				break;
+			if ((var = maybe_get<SPIRVariable>(expr->loaded_from)))
+				break;
+			if (!(expr = maybe_get<SPIRExpression>(expr->loaded_from)))
+				break;
+		}
+
+		return !var || var->phi_variable;
+	}
 
 	// Otherwise, we should dereference this pointer expression.
 	return true;
@@ -10751,8 +10940,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 
 #define GLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
 #define GLSL_BOP_CAST(op, type) \
-	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
+	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, \
+	                    opcode_is_sign_invariant(opcode), implicit_integer_promotion)
 #define GLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
+#define GLSL_UOP_CAST(op) emit_unary_op_cast(ops[0], ops[1], ops[2], #op)
 #define GLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
 #define GLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
 #define GLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
@@ -10766,6 +10957,13 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 	auto int_type = to_signed_basetype(integer_width);
 	auto uint_type = to_unsigned_basetype(integer_width);
 
+	// Handle C implicit integer promotion rules.
+	// If we get implicit promotion to int, need to make sure we cast by value to intended return type,
+	// otherwise, future sign-dependent operations and bitcasts will break.
+	bool implicit_integer_promotion = integer_width < 32 && backend.implicit_c_integer_promotion_rules &&
+	                                  opcode_can_promote_integer_implicitly(opcode) &&
+	                                  get<SPIRType>(ops[0]).vecsize == 1;
+
 	opcode = get_remapped_spirv_op(opcode);
 
 	switch (opcode)
@@ -11491,7 +11689,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 			// RHS expression is immutable, so just forward it.
 			// Copying these things really make no sense, but
 			// seems to be allowed anyways.
-			auto &e = set<SPIRExpression>(id, to_expression(rhs), result_type, true);
+			auto &e = emit_op(result_type, id, to_expression(rhs), true, true);
 			if (pointer)
 			{
 				auto *var = maybe_get_backing_variable(rhs);
@@ -11600,6 +11798,12 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		break;
 
 	case OpSNegate:
+		if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0])
+			GLSL_UOP_CAST(-);
+		else
+			GLSL_UOP(-);
+		break;
+
 	case OpFNegate:
 		GLSL_UOP(-);
 		break;
@@ -11744,6 +11948,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		auto expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(",
 		                 to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")");
 
+		if (implicit_integer_promotion)
+			expr = join(type_to_glsl(get<SPIRType>(result_type)), '(', expr, ')');
+
 		emit_op(result_type, result_id, expr, forward);
 		inherit_expression_dependencies(result_id, op0);
 		inherit_expression_dependencies(result_id, op1);
@@ -11841,7 +12048,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 	}
 
 	case OpNot:
-		GLSL_UOP(~);
+		if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0])
+			GLSL_UOP_CAST(~);
+		else
+			GLSL_UOP(~);
 		break;
 
 	case OpUMod:
@@ -13099,7 +13309,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		{
 			emit_spv_amd_gcn_shader_op(ops[0], ops[1], ops[3], &ops[4], length - 4);
 		}
-		else if (ext == SPIRExtension::SPV_debug_info)
+		else if (ext == SPIRExtension::SPV_debug_info ||
+		         ext == SPIRExtension::NonSemanticShaderDebugInfo)
 		{
 			break; // Ignore SPIR-V debug information extended instructions.
 		}
@@ -13965,7 +14176,7 @@ string CompilerGLSL::to_qualifiers_glsl(uint32_t id)
 
 	if (var && var->storage == StorageClassWorkgroup && !backend.shared_is_implied)
 		res += "shared ";
-	else if (var && var->storage == StorageClassTaskPayloadWorkgroupEXT)
+	else if (var && var->storage == StorageClassTaskPayloadWorkgroupEXT && !backend.shared_is_implied)
 		res += "taskPayloadSharedEXT ";
 
 	res += to_interpolation_qualifiers(flags);
@@ -16715,7 +16926,7 @@ void CompilerGLSL::reorder_type_alias()
 			if (alias_itr < master_itr)
 			{
 				// Must also swap the type order for the constant-type joined array.
-				auto &joined_types = ir.ids_for_constant_or_type;
+				auto &joined_types = ir.ids_for_constant_undef_or_type;
 				auto alt_alias_itr = find(begin(joined_types), end(joined_types), *alias_itr);
 				auto alt_master_itr = find(begin(joined_types), end(joined_types), *master_itr);
 				assert(alt_alias_itr != end(joined_types));
@@ -17210,6 +17421,22 @@ bool CompilerGLSL::is_stage_output_block_member_masked(const SPIRVariable &var,
 	}
 }
 
+bool CompilerGLSL::is_per_primitive_variable(const SPIRVariable &var) const
+{
+	if (has_decoration(var.self, DecorationPerPrimitiveEXT))
+		return true;
+
+	auto &type = get<SPIRType>(var.basetype);
+	if (!has_decoration(type.self, DecorationBlock))
+		return false;
+
+	for (uint32_t i = 0, n = uint32_t(type.member_types.size()); i < n; i++)
+		if (!has_member_decoration(type.self, i, DecorationPerPrimitiveEXT))
+			return false;
+
+	return true;
+}
+
 bool CompilerGLSL::is_stage_output_location_masked(uint32_t location, uint32_t component) const
 {
 	return masked_output_locations.count({ location, component }) != 0;

+ 12 - 4
3rdparty/spirv-cross/spirv_glsl.hpp

@@ -602,6 +602,7 @@ protected:
 		bool allow_precision_qualifiers = false;
 		bool can_swizzle_scalar = false;
 		bool force_gl_in_out_block = false;
+		bool force_merged_mesh_block = false;
 		bool can_return_array = true;
 		bool allow_truncated_access_chain = false;
 		bool supports_extensions = false;
@@ -619,6 +620,7 @@ protected:
 		bool support_64bit_switch = false;
 		bool workgroup_size_is_hidden = false;
 		bool requires_relaxed_precision_analysis = false;
+		bool implicit_c_integer_promotion_rules = false;
 	} backend;
 
 	void emit_struct(SPIRType &type);
@@ -691,7 +693,7 @@ protected:
 	void emit_unrolled_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op,
 	                             bool negate, SPIRType::BaseType expected_type);
 	void emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op,
-	                         SPIRType::BaseType input_type, bool skip_cast_if_equal_type);
+	                         SPIRType::BaseType input_type, bool skip_cast_if_equal_type, bool implicit_integer_promotion);
 
 	SPIRType binary_op_bitcast_helper(std::string &cast_op0, std::string &cast_op1, SPIRType::BaseType &input_type,
 	                                  uint32_t op0, uint32_t op1, bool skip_cast_if_equal_type);
@@ -702,6 +704,7 @@ protected:
 	                                  uint32_t false_value);
 
 	void emit_unary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op);
+	void emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op);
 	bool expression_is_forwarded(uint32_t id) const;
 	bool expression_suppresses_usage_tracking(uint32_t id) const;
 	bool expression_read_implies_multiple_reads(uint32_t id) const;
@@ -767,7 +770,7 @@ protected:
 	std::string address_of_expression(const std::string &expr);
 	void strip_enclosed_expression(std::string &expr);
 	std::string to_member_name(const SPIRType &type, uint32_t index);
-	virtual std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain);
+	virtual std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved);
 	std::string to_multi_member_reference(const SPIRType &type, const SmallVector<uint32_t> &indices);
 	std::string type_to_glsl_constructor(const SPIRType &type);
 	std::string argument_decl(const SPIRFunction::Parameter &arg);
@@ -934,8 +937,6 @@ protected:
 
 	bool type_is_empty(const SPIRType &type);
 
-	virtual void declare_undefined_values();
-
 	bool can_use_io_location(spv::StorageClass storage, bool block);
 	const Instruction *get_next_instruction_in_block(const Instruction &instr);
 	static uint32_t mask_relevant_memory_semantics(uint32_t semantics);
@@ -980,6 +981,7 @@ protected:
 	bool is_stage_output_builtin_masked(spv::BuiltIn builtin) const;
 	bool is_stage_output_variable_masked(const SPIRVariable &var) const;
 	bool is_stage_output_block_member_masked(const SPIRVariable &var, uint32_t index, bool strip_array) const;
+	bool is_per_primitive_variable(const SPIRVariable &var) const;
 	uint32_t get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array) const;
 	uint32_t get_declared_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array) const;
 	std::unordered_set<LocationComponentPair, InternalHasher> masked_output_locations;
@@ -987,6 +989,12 @@ protected:
 
 private:
 	void init();
+
+	SmallVector<ConstantID> get_composite_constant_ids(ConstantID const_id);
+	void fill_composite_constant(SPIRConstant &constant, TypeID type_id, const SmallVector<ConstantID> &initializers);
+	void set_composite_constant(ConstantID const_id, TypeID type_id, const SmallVector<ConstantID> &initializers);
+	TypeID get_composite_member_type(TypeID type_id, uint32_t member_idx);
+	std::unordered_map<uint32_t, SmallVector<ConstantID>> const_composite_insert_ids;
 };
 } // namespace SPIRV_CROSS_NAMESPACE
 

+ 487 - 68
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -603,36 +603,80 @@ void CompilerHLSL::emit_builtin_outputs_in_struct()
 			break;
 
 		case BuiltInClipDistance:
+		{
+			static const char *types[] = { "float", "float2", "float3", "float4" };
+
 			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
-			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
+			if (execution.model == ExecutionModelMeshEXT)
 			{
-				uint32_t to_declare = clip_distance_count - clip;
-				if (to_declare > 4)
-					to_declare = 4;
+				if (clip_distance_count > 4)
+					SPIRV_CROSS_THROW("Clip distance count > 4 not supported for mesh shaders.");
 
-				uint32_t semantic_index = clip / 4;
+				if (clip_distance_count == 1)
+				{
+					// Avoids having to hack up access_chain code. Makes it trivially indexable.
+					statement("float gl_ClipDistance[1] : SV_ClipDistance;");
+				}
+				else
+				{
+					// Replace array with vector directly, avoids any weird fixup path.
+					statement(types[clip_distance_count - 1], " gl_ClipDistance : SV_ClipDistance;");
+				}
+			}
+			else
+			{
+				for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
+				{
+					uint32_t to_declare = clip_distance_count - clip;
+					if (to_declare > 4)
+						to_declare = 4;
 
-				static const char *types[] = { "float", "float2", "float3", "float4" };
-				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
-				          " : SV_ClipDistance", semantic_index, ";");
+					uint32_t semantic_index = clip / 4;
+
+					statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
+					          " : SV_ClipDistance", semantic_index, ";");
+				}
 			}
 			break;
+		}
 
 		case BuiltInCullDistance:
+		{
+			static const char *types[] = { "float", "float2", "float3", "float4" };
+
 			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
-			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
+			if (execution.model == ExecutionModelMeshEXT)
 			{
-				uint32_t to_declare = cull_distance_count - cull;
-				if (to_declare > 4)
-					to_declare = 4;
+				if (cull_distance_count > 4)
+					SPIRV_CROSS_THROW("Cull distance count > 4 not supported for mesh shaders.");
 
-				uint32_t semantic_index = cull / 4;
+				if (cull_distance_count == 1)
+				{
+					// Avoids having to hack up access_chain code. Makes it trivially indexable.
+					statement("float gl_CullDistance[1] : SV_CullDistance;");
+				}
+				else
+				{
+					// Replace array with vector directly, avoids any weird fixup path.
+					statement(types[cull_distance_count - 1], " gl_CullDistance : SV_CullDistance;");
+				}
+			}
+			else
+			{
+				for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
+				{
+					uint32_t to_declare = cull_distance_count - cull;
+					if (to_declare > 4)
+						to_declare = 4;
 
-				static const char *types[] = { "float", "float2", "float3", "float4" };
-				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
-				          " : SV_CullDistance", semantic_index, ";");
+					uint32_t semantic_index = cull / 4;
+
+					statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
+					          " : SV_CullDistance", semantic_index, ";");
+				}
 			}
 			break;
+		}
 
 		case BuiltInPointSize:
 			// If point_size_compat is enabled, just ignore PointSize.
@@ -644,14 +688,69 @@ void CompilerHLSL::emit_builtin_outputs_in_struct()
 				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
 
 		case BuiltInLayer:
-			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelGeometry)
-				SPIRV_CROSS_THROW("Render target array index output is only supported in GS 5.0 or higher.");
+		case BuiltInPrimitiveId:
+		case BuiltInViewportIndex:
+		case BuiltInPrimitiveShadingRateKHR:
+		case BuiltInCullPrimitiveEXT:
+			// per-primitive attributes handled separatly
+			break;
+
+		case BuiltInPrimitivePointIndicesEXT:
+		case BuiltInPrimitiveLineIndicesEXT:
+		case BuiltInPrimitiveTriangleIndicesEXT:
+			// meshlet local-index buffer handled separatly
+			break;
+
+		default:
+			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
+		}
+
+		if (type && semantic)
+			statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
+	    });
+}
+
+void CompilerHLSL::emit_builtin_primitive_outputs_in_struct()
+{
+	active_output_builtins.for_each_bit([&](uint32_t i) {
+		const char *type = nullptr;
+		const char *semantic = nullptr;
+		auto builtin = static_cast<BuiltIn>(i);
+		switch (builtin)
+		{
+		case BuiltInLayer:
+		{
+			const ExecutionModel model = get_entry_point().model;
+			if (hlsl_options.shader_model < 50 ||
+			    (model != ExecutionModelGeometry && model != ExecutionModelMeshEXT))
+				SPIRV_CROSS_THROW("Render target array index output is only supported in GS/MS 5.0 or higher.");
 			type = "uint";
 			semantic = "SV_RenderTargetArrayIndex";
 			break;
+		}
+
+		case BuiltInPrimitiveId:
+			type = "uint";
+			semantic = "SV_PrimitiveID";
+			break;
+
+		case BuiltInViewportIndex:
+			type = "uint";
+			semantic = "SV_ViewportArrayIndex";
+			break;
+
+		case BuiltInPrimitiveShadingRateKHR:
+			type = "uint";
+			semantic = "SV_ShadingRate";
+			break;
+
+		case BuiltInCullPrimitiveEXT:
+			type = "bool";
+			semantic = "SV_CullPrimitive";
+			break;
 
 		default:
-			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
+			break;
 		}
 
 		if (type && semantic)
@@ -981,17 +1080,25 @@ void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unord
 		}
 		else
 		{
-			statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(type, name), " : ",
+			auto decl_type = type;
+			if (execution.model == ExecutionModelMeshEXT)
+			{
+				decl_type.array.erase(decl_type.array.begin());
+				decl_type.array_size_literal.erase(decl_type.array_size_literal.begin());
+			}
+			statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(decl_type, name), " : ",
 			          semantic, ";");
 
 			// Structs and arrays should consume more locations.
-			uint32_t consumed_locations = type_to_consumed_locations(type);
+			uint32_t consumed_locations = type_to_consumed_locations(decl_type);
 			for (uint32_t i = 0; i < consumed_locations; i++)
 				active_locations.insert(location_number + i);
 		}
 	}
 	else
+	{
 		statement(variable_decl(type, name), " : ", binding, ";");
+	}
 }
 
 std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
@@ -1071,6 +1178,18 @@ void CompilerHLSL::emit_builtin_variables()
 		if (init_itr != builtin_to_initializer.end())
 			init_expr = join(" = ", to_expression(init_itr->second));
 
+		if (get_execution_model() == ExecutionModelMeshEXT)
+		{
+			if (builtin == BuiltInPosition || builtin == BuiltInPointSize || builtin == BuiltInClipDistance ||
+			    builtin == BuiltInCullDistance || builtin == BuiltInLayer || builtin == BuiltInPrimitiveId ||
+			    builtin == BuiltInViewportIndex || builtin == BuiltInCullPrimitiveEXT ||
+			    builtin == BuiltInPrimitiveShadingRateKHR || builtin == BuiltInPrimitivePointIndicesEXT ||
+			    builtin == BuiltInPrimitiveLineIndicesEXT || builtin == BuiltInPrimitiveTriangleIndicesEXT)
+			{
+				return;
+			}
+		}
+
 		switch (builtin)
 		{
 		case BuiltInFragCoord:
@@ -1171,6 +1290,13 @@ void CompilerHLSL::emit_builtin_variables()
 			type = "uint";
 			break;
 
+		case BuiltInViewportIndex:
+		case BuiltInPrimitiveShadingRateKHR:
+		case BuiltInPrimitiveLineIndicesEXT:
+		case BuiltInCullPrimitiveEXT:
+			type = "uint";
+			break;
+
 		default:
 			SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
 		}
@@ -1283,7 +1409,7 @@ void CompilerHLSL::emit_specialization_constants_and_structs()
 	});
 
 	auto loop_lock = ir.create_loop_hard_lock();
-	for (auto &id_ : ir.ids_for_constant_or_type)
+	for (auto &id_ : ir.ids_for_constant_undef_or_type)
 	{
 		auto &id = ir.ids[id_];
 
@@ -1345,6 +1471,21 @@ void CompilerHLSL::emit_specialization_constants_and_structs()
 				emit_struct(type);
 			}
 		}
+		else if (id.get_type() == TypeUndef)
+		{
+			auto &undef = id.get<SPIRUndef>();
+			auto &type = this->get<SPIRType>(undef.basetype);
+			// OpUndef can be void for some reason ...
+			if (type.basetype == SPIRType::Void)
+				return;
+
+			string initializer;
+			if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
+				initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
+
+			statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
+			emitted = true;
+		}
 	}
 
 	if (emitted)
@@ -1365,12 +1506,12 @@ void CompilerHLSL::replace_illegal_names()
 		"double", "DomainShader", "dword",
 		"else", "export", "false", "float", "for", "fxgroup",
 		"GeometryShader", "groupshared", "half", "HullShader",
-		"if", "in", "inline", "inout", "InputPatch", "int", "interface",
+		"indices", "if", "in", "inline", "inout", "InputPatch", "int", "interface",
 		"line", "lineadj", "linear", "LineStream",
 		"matrix", "min16float", "min10float", "min16int", "min16uint",
 		"namespace", "nointerpolation", "noperspective", "NULL",
 		"out", "OutputPatch",
-		"packoffset", "pass", "pixelfragment", "PixelShader", "point",
+		"payload", "packoffset", "pass", "pixelfragment", "PixelShader", "point",
 		"PointStream", "precise", "RasterizerState", "RenderTargetView",
 		"return", "register", "row_major", "RWBuffer", "RWByteAddressBuffer",
 		"RWStructuredBuffer", "RWTexture1D", "RWTexture1DArray", "RWTexture2D",
@@ -1381,40 +1522,32 @@ void CompilerHLSL::replace_illegal_names()
 		"Texture1DArray", "Texture2D", "Texture2DArray", "Texture2DMS", "Texture2DMSArray",
 		"Texture3D", "TextureCube", "TextureCubeArray", "true", "typedef", "triangle",
 		"triangleadj", "TriangleStream", "uint", "uniform", "unorm", "unsigned",
-		"vector", "vertexfragment", "VertexShader", "void", "volatile", "while",
+		"vector", "vertexfragment", "VertexShader", "vertices", "void", "volatile", "while",
 	};
 
 	CompilerGLSL::replace_illegal_names(keywords);
 	CompilerGLSL::replace_illegal_names();
 }
 
-void CompilerHLSL::declare_undefined_values()
-{
-	bool emitted = false;
-	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
-		auto &type = this->get<SPIRType>(undef.basetype);
-		// OpUndef can be void for some reason ...
-		if (type.basetype == SPIRType::Void)
-			return;
-
-		string initializer;
-		if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
-			initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
-
-		statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
-		emitted = true;
-	});
-
-	if (emitted)
-		statement("");
-}
-
 void CompilerHLSL::emit_resources()
 {
 	auto &execution = get_entry_point();
 
 	replace_illegal_names();
 
+	switch (execution.model)
+	{
+	case ExecutionModelGeometry:
+	case ExecutionModelTessellationControl:
+	case ExecutionModelTessellationEvaluation:
+	case ExecutionModelMeshEXT:
+		fixup_implicit_builtin_block_names(execution.model);
+		break;
+
+	default:
+		break;
+	}
+
 	emit_specialization_constants_and_structs();
 	emit_composite_constants();
 
@@ -1487,18 +1620,21 @@ void CompilerHLSL::emit_resources()
 	// Emit builtin input and output variables here.
 	emit_builtin_variables();
 
-	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
-		auto &type = this->get<SPIRType>(var.basetype);
+	if (execution.model != ExecutionModelMeshEXT)
+	{
+		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+			auto &type = this->get<SPIRType>(var.basetype);
 
-		if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
-		    (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
-		    interface_variable_exists_in_entry_point(var.self))
-		{
-			// Builtin variables are handled separately.
-			emit_interface_block_globally(var);
-			emitted = true;
-		}
-	});
+			if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
+			   (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
+			   interface_variable_exists_in_entry_point(var.self))
+			{
+				// Builtin variables are handled separately.
+				emit_interface_block_globally(var);
+				emitted = true;
+			}
+		});
+	}
 
 	if (emitted)
 		statement("");
@@ -1612,23 +1748,48 @@ void CompilerHLSL::emit_resources()
 		statement("");
 	}
 
+	const bool is_mesh_shader = execution.model == ExecutionModelMeshEXT;
 	if (!output_variables.empty() || !active_output_builtins.empty())
 	{
-		require_output = true;
-		statement("struct SPIRV_Cross_Output");
+		sort(output_variables.begin(), output_variables.end(), variable_compare);
+		require_output = !is_mesh_shader;
 
+		statement(is_mesh_shader ? "struct gl_MeshPerVertexEXT" : "struct SPIRV_Cross_Output");
 		begin_scope();
-		sort(output_variables.begin(), output_variables.end(), variable_compare);
 		for (auto &var : output_variables)
 		{
-			if (var.block)
+			if (is_per_primitive_variable(*var.var))
+				continue;
+			if (var.block && is_mesh_shader && var.block_member_index != 0)
+				continue;
+			if (var.block && !is_mesh_shader)
 				emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_outputs);
 			else
 				emit_interface_block_in_struct(*var.var, active_outputs);
 		}
 		emit_builtin_outputs_in_struct();
+		if (!is_mesh_shader)
+			emit_builtin_primitive_outputs_in_struct();
 		end_scope_decl();
 		statement("");
+
+		if (is_mesh_shader)
+		{
+			statement("struct gl_MeshPerPrimitiveEXT");
+			begin_scope();
+			for (auto &var : output_variables)
+			{
+				if (!is_per_primitive_variable(*var.var))
+					continue;
+				if (var.block && var.block_member_index != 0)
+					continue;
+
+				emit_interface_block_in_struct(*var.var, active_outputs);
+			}
+			emit_builtin_primitive_outputs_in_struct();
+			end_scope_decl();
+			statement("");
+		}
 	}
 
 	// Global variables.
@@ -1638,7 +1799,8 @@ void CompilerHLSL::emit_resources()
 		if (is_hidden_variable(var, true))
 			continue;
 
-		if (var.storage != StorageClassOutput)
+		if (var.storage != StorageClassOutput &&
+		    var.storage != StorageClassTaskPayloadWorkgroupEXT)
 		{
 			if (!variable_is_lut(var))
 			{
@@ -1672,8 +1834,6 @@ void CompilerHLSL::emit_resources()
 	if (emitted)
 		statement("");
 
-	declare_undefined_values();
-
 	if (requires_op_fmod)
 	{
 		static const char *types[] = {
@@ -2164,6 +2324,194 @@ void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char
 	}
 }
 
+void CompilerHLSL::analyze_meshlet_writes()
+{
+	uint32_t id_per_vertex = 0;
+	uint32_t id_per_primitive = 0;
+	bool need_per_primitive = false;
+	bool need_per_vertex = false;
+
+	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+		auto &type = this->get<SPIRType>(var.basetype);
+		bool block = has_decoration(type.self, DecorationBlock);
+		if (var.storage == StorageClassOutput && block && is_builtin_variable(var))
+		{
+			auto flags = get_buffer_block_flags(var.self);
+			if (flags.get(DecorationPerPrimitiveEXT))
+				id_per_primitive = var.self;
+			else
+				id_per_vertex = var.self;
+		}
+		else if (var.storage == StorageClassOutput)
+		{
+			Bitset flags;
+			if (block)
+				flags = get_buffer_block_flags(var.self);
+			else
+				flags = get_decoration_bitset(var.self);
+
+			if (flags.get(DecorationPerPrimitiveEXT))
+				need_per_primitive = true;
+			else
+				need_per_vertex = true;
+		}
+	});
+
+	// If we have per-primitive outputs, and no per-primitive builtins,
+	// empty version of gl_MeshPerPrimitiveEXT will be emitted.
+	// If we don't use block IO for vertex output, we'll also need to synthesize the PerVertex block.
+
+	const auto generate_block = [&](const char *block_name, const char *instance_name, bool per_primitive) -> uint32_t {
+		auto &execution = get_entry_point();
+
+		uint32_t op_type = ir.increase_bound_by(4);
+		uint32_t op_arr = op_type + 1;
+		uint32_t op_ptr = op_type + 2;
+		uint32_t op_var = op_type + 3;
+
+		auto &type = set<SPIRType>(op_type);
+		type.basetype = SPIRType::Struct;
+		set_name(op_type, block_name);
+		set_decoration(op_type, DecorationBlock);
+		if (per_primitive)
+			set_decoration(op_type, DecorationPerPrimitiveEXT);
+
+		auto &arr = set<SPIRType>(op_arr, type);
+		arr.parent_type = type.self;
+		arr.array.push_back(per_primitive ? execution.output_primitives : execution.output_vertices);
+		arr.array_size_literal.push_back(true);
+
+		auto &ptr = set<SPIRType>(op_ptr, arr);
+		ptr.parent_type = arr.self;
+		ptr.pointer = true;
+		ptr.pointer_depth++;
+		ptr.storage = StorageClassOutput;
+		set_decoration(op_ptr, DecorationBlock);
+		set_name(op_ptr, block_name);
+
+		auto &var = set<SPIRVariable>(op_var, op_ptr, StorageClassOutput);
+		if (per_primitive)
+			set_decoration(op_var, DecorationPerPrimitiveEXT);
+		set_name(op_var, instance_name);
+		execution.interface_variables.push_back(var.self);
+
+		return op_var;
+	};
+
+	if (id_per_vertex == 0 && need_per_vertex)
+		id_per_vertex = generate_block("gl_MeshPerVertexEXT", "gl_MeshVerticesEXT", false);
+	if (id_per_primitive == 0 && need_per_primitive)
+		id_per_primitive = generate_block("gl_MeshPerPrimitiveEXT", "gl_MeshPrimitivesEXT", true);
+
+	unordered_set<uint32_t> processed_func_ids;
+	analyze_meshlet_writes(ir.default_entry_point, id_per_vertex, id_per_primitive, processed_func_ids);
+}
+
+void CompilerHLSL::analyze_meshlet_writes(uint32_t func_id, uint32_t id_per_vertex, uint32_t id_per_primitive,
+                                          std::unordered_set<uint32_t> &processed_func_ids)
+{
+	// Avoid processing a function more than once
+	if (processed_func_ids.find(func_id) != processed_func_ids.end())
+		return;
+	processed_func_ids.insert(func_id);
+
+	auto &func = get<SPIRFunction>(func_id);
+	// Recursively establish global args added to functions on which we depend.
+	for (auto& block : func.blocks)
+	{
+		auto &b = get<SPIRBlock>(block);
+		for (auto &i : b.ops)
+		{
+			auto ops = stream(i);
+			auto op = static_cast<Op>(i.op);
+
+			switch (op)
+			{
+			case OpFunctionCall:
+			{
+				// Then recurse into the function itself to extract globals used internally in the function
+				uint32_t inner_func_id = ops[2];
+				analyze_meshlet_writes(inner_func_id, id_per_vertex, id_per_primitive, processed_func_ids);
+				auto &inner_func = get<SPIRFunction>(inner_func_id);
+				for (auto &iarg : inner_func.arguments)
+				{
+					if (!iarg.alias_global_variable)
+						continue;
+
+					bool already_declared = false;
+					for (auto &arg : func.arguments)
+					{
+						if (arg.id == iarg.id)
+						{
+							already_declared = true;
+							break;
+						}
+					}
+
+					if (!already_declared)
+					{
+						// basetype is effectively ignored here since we declare the argument
+						// with explicit types. Just pass down a valid type.
+						func.arguments.push_back({ expression_type_id(iarg.id), iarg.id,
+						                           iarg.read_count, iarg.write_count, true });
+					}
+				}
+				break;
+			}
+
+			case OpStore:
+			case OpLoad:
+			case OpInBoundsAccessChain:
+			case OpAccessChain:
+			case OpPtrAccessChain:
+			case OpInBoundsPtrAccessChain:
+			case OpArrayLength:
+			{
+				auto *var = maybe_get<SPIRVariable>(ops[op == OpStore ? 0 : 2]);
+				if (var && (var->storage == StorageClassOutput || var->storage == StorageClassTaskPayloadWorkgroupEXT))
+				{
+					bool already_declared = false;
+					auto builtin_type = BuiltIn(get_decoration(var->self, DecorationBuiltIn));
+
+					uint32_t var_id = var->self;
+					if (var->storage != StorageClassTaskPayloadWorkgroupEXT &&
+						builtin_type != BuiltInPrimitivePointIndicesEXT &&
+						builtin_type != BuiltInPrimitiveLineIndicesEXT &&
+						builtin_type != BuiltInPrimitiveTriangleIndicesEXT)
+					{
+						var_id = is_per_primitive_variable(*var) ? id_per_primitive : id_per_vertex;
+					}
+
+					for (auto &arg : func.arguments)
+					{
+						if (arg.id == var_id)
+						{
+							already_declared = true;
+							break;
+						}
+					}
+
+					if (!already_declared)
+					{
+						// basetype is effectively ignored here since we declare the argument
+						// with explicit types. Just pass down a valid type.
+						uint32_t type_id = expression_type_id(var_id);
+						if (var->storage == StorageClassTaskPayloadWorkgroupEXT)
+							func.arguments.push_back({ type_id, var_id, 1u, 0u, true });
+						else
+							func.arguments.push_back({ type_id, var_id, 1u, 1u, true });
+					}
+				}
+				break;
+			}
+
+			default:
+				break;
+			}
+		}
+	}
+}
+
 string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
 {
 	auto &flags = get_member_decoration_bitset(type.self, index);
@@ -2459,6 +2807,8 @@ string CompilerHLSL::get_inner_entry_point_name() const
 		return "frag_main";
 	else if (execution.model == ExecutionModelGLCompute)
 		return "comp_main";
+	else if (execution.model == ExecutionModelMeshEXT)
+		return "mesh_main";
 	else
 		SPIRV_CROSS_THROW("Unsupported execution model.");
 }
@@ -2572,8 +2922,58 @@ void CompilerHLSL::emit_hlsl_entry_point()
 
 	switch (execution.model)
 	{
+	case ExecutionModelMeshEXT:
+	case ExecutionModelMeshNV:
 	case ExecutionModelGLCompute:
 	{
+		if (execution.model == ExecutionModelMeshEXT)
+		{
+			if (execution.flags.get(ExecutionModeOutputTrianglesEXT))
+				statement("[outputtopology(\"triangle\")]");
+			else if (execution.flags.get(ExecutionModeOutputLinesEXT))
+				statement("[outputtopology(\"line\")]");
+			else if (execution.flags.get(ExecutionModeOutputPoints))
+				SPIRV_CROSS_THROW("Topology mode \"points\" is not supported in DirectX");
+
+			auto &func = get<SPIRFunction>(ir.default_entry_point);
+			for (auto &arg : func.arguments)
+			{
+				auto &var = get<SPIRVariable>(arg.id);
+				auto &base_type = get<SPIRType>(var.basetype);
+				bool block = has_decoration(base_type.self, DecorationBlock);
+				if (var.storage == StorageClassTaskPayloadWorkgroupEXT)
+				{
+					arguments.push_back("in payload " + variable_decl(var));
+				}
+				else if (block)
+				{
+					auto flags = get_buffer_block_flags(var.self);
+					if (flags.get(DecorationPerPrimitiveEXT) || has_decoration(arg.id, DecorationPerPrimitiveEXT))
+					{
+						arguments.push_back("out primitives gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[" +
+						                    std::to_string(execution.output_primitives) + "]");
+					}
+					else
+					{
+						arguments.push_back("out vertices gl_MeshPerVertexEXT gl_MeshVerticesEXT[" +
+						                    std::to_string(execution.output_vertices) + "]");
+					}
+				}
+				else
+				{
+					if (execution.flags.get(ExecutionModeOutputTrianglesEXT))
+					{
+						arguments.push_back("out indices uint3 gl_PrimitiveTriangleIndicesEXT[" +
+						                    std::to_string(execution.output_primitives) + "]");
+					}
+					else
+					{
+						arguments.push_back("out indices uint2 gl_PrimitiveLineIndicesEXT[" +
+						                    std::to_string(execution.output_primitives) + "]");
+					}
+				}
+			}
+		}
 		SpecializationConstant wg_x, wg_y, wg_z;
 		get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
 
@@ -2795,9 +3195,18 @@ void CompilerHLSL::emit_hlsl_entry_point()
 	// Run the shader.
 	if (execution.model == ExecutionModelVertex ||
 	    execution.model == ExecutionModelFragment ||
-	    execution.model == ExecutionModelGLCompute)
-	{
-		statement(get_inner_entry_point_name(), "();");
+	    execution.model == ExecutionModelGLCompute ||
+	    execution.model == ExecutionModelMeshEXT)
+	{
+		// For mesh shaders, we receive special arguments that we must pass down as function arguments.
+		// HLSL does not support proper reference types for passing these IO blocks,
+		// but DXC post-inlining seems to magically fix it up anyways *shrug*.
+		SmallVector<string> arglist;
+		auto &func = get<SPIRFunction>(ir.default_entry_point);
+		// The arguments are marked out, avoid detecting reads and emitting inout.
+		for (auto &arg : func.arguments)
+			arglist.push_back(to_expression(arg.id, false));
+		statement(get_inner_entry_point_name(), "(", merge(arglist), ");");
 	}
 	else
 		SPIRV_CROSS_THROW("Unsupported shader stage.");
@@ -4965,7 +5374,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 
 #define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
 #define HLSL_BOP_CAST(op, type) \
-	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
+	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
 #define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
 #define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
 #define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
@@ -5926,6 +6335,12 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".WorldRayDirection()"), false);
 		break;
 	}
+	case OpSetMeshOutputsEXT:
+	{
+		statement("SetMeshOutputCounts(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");");
+		break;
+	}
+
 	default:
 		CompilerGLSL::emit_instruction(instruction);
 		break;
@@ -6126,6 +6541,8 @@ string CompilerHLSL::compile()
 	backend.can_return_array = false;
 	backend.nonuniform_qualifier = "NonUniformResourceIndex";
 	backend.support_case_fallthrough = false;
+	backend.force_merged_mesh_block = get_execution_model() == ExecutionModelMeshEXT;
+	backend.force_gl_in_out_block = backend.force_merged_mesh_block;
 
 	// SM 4.1 does not support precise for some reason.
 	backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;
@@ -6138,6 +6555,8 @@ string CompilerHLSL::compile()
 	update_active_builtins();
 	analyze_image_and_sampler_usage();
 	analyze_interlocked_resource_usage();
+	if (get_execution_model() == ExecutionModelMeshEXT)
+		analyze_meshlet_writes();
 
 	// Subpass input needs SV_Position.
 	if (need_subpass_input)

+ 6 - 3
3rdparty/spirv-cross/spirv_hlsl.hpp

@@ -230,14 +230,13 @@ private:
 	void emit_hlsl_entry_point();
 	void emit_header() override;
 	void emit_resources();
-	void declare_undefined_values() override;
 	void emit_interface_block_globally(const SPIRVariable &type);
 	void emit_interface_block_in_struct(const SPIRVariable &var, std::unordered_set<uint32_t> &active_locations);
-	void emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
-	                                           uint32_t location,
+	void emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index, uint32_t location,
 	                                           std::unordered_set<uint32_t> &active_locations);
 	void emit_builtin_inputs_in_struct();
 	void emit_builtin_outputs_in_struct();
+	void emit_builtin_primitive_outputs_in_struct();
 	void emit_texture_op(const Instruction &i, bool sparse) override;
 	void emit_instruction(const Instruction &instruction) override;
 	void emit_glsl_op(uint32_t result_type, uint32_t result_id, uint32_t op, const uint32_t *args,
@@ -355,6 +354,10 @@ private:
 		TypeUnpackUint64
 	};
 
+	void analyze_meshlet_writes();
+	void analyze_meshlet_writes(uint32_t func_id, uint32_t id_per_vertex, uint32_t id_per_primitive,
+	                            std::unordered_set<uint32_t> &processed_func_ids);
+
 	BitcastType get_bitcast_type(uint32_t result_type, uint32_t op0);
 
 	void emit_builtin_variables();

+ 323 - 60
3rdparty/spirv-cross/spirv_msl.cpp

@@ -259,8 +259,8 @@ void CompilerMSL::build_implicit_builtins()
 
 	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
 	    need_tese_params || need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
-	    needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || has_additional_fixed_sample_mask() ||
-	    need_local_invocation_index || need_workgroup_size)
+	    needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || needs_helper_invocation ||
+		has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size)
 	{
 		bool has_frag_coord = false;
 		bool has_sample_id = false;
@@ -274,6 +274,7 @@ void CompilerMSL::build_implicit_builtins()
 		bool has_subgroup_size = false;
 		bool has_view_idx = false;
 		bool has_layer = false;
+		bool has_helper_invocation = false;
 		bool has_local_invocation_index = false;
 		bool has_workgroup_size = false;
 		uint32_t workgroup_id_type = 0;
@@ -430,6 +431,13 @@ void CompilerMSL::build_implicit_builtins()
 				}
 			}
 
+			if (needs_helper_invocation && builtin == BuiltInHelperInvocation)
+			{
+				builtin_helper_invocation_id = var.self;
+				mark_implicit_builtin(StorageClassInput, BuiltInHelperInvocation, var.self);
+				has_helper_invocation = true;
+			}
+
 			if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
 			{
 				builtin_local_invocation_index_id = var.self;
@@ -806,6 +814,35 @@ void CompilerMSL::build_implicit_builtins()
 			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
 		}
 
+		if (!has_helper_invocation && needs_helper_invocation)
+		{
+			uint32_t offset = ir.increase_bound_by(3);
+			uint32_t type_id = offset;
+			uint32_t type_ptr_id = offset + 1;
+			uint32_t var_id = offset + 2;
+
+			// Create gl_HelperInvocation.
+			SPIRType bool_type;
+			bool_type.basetype = SPIRType::Boolean;
+			bool_type.width = 8;
+			bool_type.vecsize = 1;
+			set<SPIRType>(type_id, bool_type);
+
+			SPIRType bool_type_ptr_in;
+			bool_type_ptr_in = bool_type;
+			bool_type_ptr_in.pointer = true;
+			bool_type_ptr_in.pointer_depth++;
+			bool_type_ptr_in.parent_type = type_id;
+			bool_type_ptr_in.storage = StorageClassInput;
+
+			auto &ptr_in_type = set<SPIRType>(type_ptr_id, bool_type_ptr_in);
+			ptr_in_type.self = type_id;
+			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
+			set_decoration(var_id, DecorationBuiltIn, BuiltInHelperInvocation);
+			builtin_helper_invocation_id = var_id;
+			mark_implicit_builtin(StorageClassInput, BuiltInHelperInvocation, var_id);
+		}
+
 		if (need_local_invocation_index && !has_local_invocation_index)
 		{
 			uint32_t offset = ir.increase_bound_by(2);
@@ -1415,8 +1452,6 @@ string CompilerMSL::compile()
 	backend.basic_uint8_type = "uchar";
 	backend.basic_int16_type = "short";
 	backend.basic_uint16_type = "ushort";
-	backend.discard_literal = "discard_fragment()";
-	backend.demote_literal = "discard_fragment()";
 	backend.boolean_mix_function = "select";
 	backend.swizzle_is_function = false;
 	backend.shared_is_implied = false;
@@ -1439,6 +1474,7 @@ string CompilerMSL::compile()
 	// Arrays which are part of buffer objects are never considered to be value types (just plain C-style).
 	backend.array_is_value_type_in_buffer_blocks = false;
 	backend.support_pointer_to_pointer = true;
+	backend.implicit_c_integer_promotion_rules = true;
 
 	capture_output_to_buffer = msl_options.capture_output_to_buffer;
 	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
@@ -1460,6 +1496,20 @@ string CompilerMSL::compile()
 	preprocess_op_codes();
 	build_implicit_builtins();
 
+	if (needs_manual_helper_invocation_updates() &&
+	    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
+	{
+		string discard_expr =
+		    join(builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), " = true, discard_fragment()");
+		backend.discard_literal = discard_expr;
+		backend.demote_literal = discard_expr;
+	}
+	else
+	{
+		backend.discard_literal = "discard_fragment()";
+		backend.demote_literal = "discard_fragment()";
+	}
+
 	fixup_image_load_store_access();
 
 	set_enabled_interface_variables(get_active_interface_variables());
@@ -1564,7 +1614,8 @@ void CompilerMSL::preprocess_op_codes()
 
 	// Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
 	// resources must disable rasterization and return void.
-	if (preproc.uses_resource_write)
+	if ((preproc.uses_buffer_write && !msl_options.supports_msl_version(2, 1)) ||
+	    (preproc.uses_image_write && !msl_options.supports_msl_version(2, 2)))
 		is_rasterization_disabled = true;
 
 	// Tessellation control shaders are run as compute functions in Metal, and so
@@ -1586,6 +1637,27 @@ void CompilerMSL::preprocess_op_codes()
 	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
 	                          (need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses))))
 		needs_sample_id = true;
+	if (preproc.needs_helper_invocation)
+		needs_helper_invocation = true;
+
+	// OpKill is removed by the parser, so we need to identify those by inspecting
+	// blocks.
+	ir.for_each_typed_id<SPIRBlock>([&preproc](uint32_t, SPIRBlock &block) {
+		if (block.terminator == SPIRBlock::Kill)
+			preproc.uses_discard = true;
+	});
+
+	// Fragment shaders that both write to storage resources and discard fragments
+	// need checks on the writes, to work around Metal allowing these writes despite
+	// the fragment being dead.
+	if (msl_options.check_discarded_frag_stores && preproc.uses_discard &&
+	    (preproc.uses_buffer_write || preproc.uses_image_write))
+	{
+		frag_shader_needs_discard_checks = true;
+		needs_helper_invocation = true;
+		// Fragment discard store checks imply manual HelperInvocation updates.
+		msl_options.manual_helper_invocation_updates = true;
+	}
 
 	if (is_intersection_query())
 	{
@@ -1626,10 +1698,26 @@ void CompilerMSL::extract_global_variables_from_functions()
 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
 		// Some builtins resolve directly to a function call which does not need any declared variables.
 		// Skip these.
-		if (var.storage == StorageClassInput && has_decoration(var.self, DecorationBuiltIn) &&
-		    BuiltIn(get_decoration(var.self, DecorationBuiltIn)) == BuiltInHelperInvocation)
+		if (var.storage == StorageClassInput && has_decoration(var.self, DecorationBuiltIn))
 		{
-			return;
+			auto bi_type = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
+			if (bi_type == BuiltInHelperInvocation && !needs_manual_helper_invocation_updates())
+				return;
+			if (bi_type == BuiltInHelperInvocation && needs_manual_helper_invocation_updates())
+			{
+				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
+					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
+				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
+					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
+				// Make sure this is declared and initialized.
+				// Force this to have the proper name.
+				set_name(var.self, builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput));
+				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
+				entry_func.add_local_variable(var.self);
+				vars_needing_early_declaration.push_back(var.self);
+				entry_func.fixup_hooks_in.push_back([this, &var]()
+				                                    { statement(to_name(var.self), " = simd_is_helper_thread();"); });
+			}
 		}
 
 		if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
@@ -1745,6 +1833,9 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
 					added_arg_ids.insert(rvalue_id);
 
+				if (needs_frag_discard_checks())
+					added_arg_ids.insert(builtin_helper_invocation_id);
+
 				break;
 			}
 
@@ -1759,6 +1850,25 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				break;
 			}
 
+			case OpAtomicExchange:
+			case OpAtomicCompareExchange:
+			case OpAtomicStore:
+			case OpAtomicIIncrement:
+			case OpAtomicIDecrement:
+			case OpAtomicIAdd:
+			case OpAtomicISub:
+			case OpAtomicSMin:
+			case OpAtomicUMin:
+			case OpAtomicSMax:
+			case OpAtomicUMax:
+			case OpAtomicAnd:
+			case OpAtomicOr:
+			case OpAtomicXor:
+			case OpImageWrite:
+				if (needs_frag_discard_checks())
+					added_arg_ids.insert(builtin_helper_invocation_id);
+				break;
+
 			// Emulate texture2D atomic operations
 			case OpImageTexelPointer:
 			{
@@ -1840,6 +1950,17 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				break;
 			}
 
+			case OpDemoteToHelperInvocation:
+				if (needs_manual_helper_invocation_updates() &&
+				    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
+					added_arg_ids.insert(builtin_helper_invocation_id);
+				break;
+
+			case OpIsHelperInvocationEXT:
+				if (needs_manual_helper_invocation_updates())
+					added_arg_ids.insert(builtin_helper_invocation_id);
+				break;
+
 			case OpRayQueryInitializeKHR:
 			case OpRayQueryProceedKHR:
 			case OpRayQueryTerminateKHR:
@@ -1883,6 +2004,10 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				break;
 			}
 
+			if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill &&
+			    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
+				added_arg_ids.insert(builtin_helper_invocation_id);
+
 			// TODO: Add all other operations which can affect memory.
 			// We should consider a more unified system here to reduce boiler-plate.
 			// This kind of analysis is done in several places ...
@@ -7092,28 +7217,6 @@ static string inject_top_level_storage_qualifier(const string &expr, const strin
 	}
 }
 
-// Undefined global memory is not allowed in MSL.
-// Declare constant and init to zeros. Use {}, as global constructors can break Metal.
-void CompilerMSL::declare_undefined_values()
-{
-	bool emitted = false;
-	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
-		auto &type = this->get<SPIRType>(undef.basetype);
-		// OpUndef can be void for some reason ...
-		if (type.basetype == SPIRType::Void)
-			return;
-
-		statement(inject_top_level_storage_qualifier(
-				variable_decl(type, to_name(undef.self), undef.self),
-				"constant"),
-		          " = {};");
-		emitted = true;
-	});
-
-	if (emitted)
-		statement("");
-}
-
 void CompilerMSL::declare_constant_arrays()
 {
 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
@@ -7179,7 +7282,6 @@ void CompilerMSL::declare_complex_constant_arrays()
 void CompilerMSL::emit_resources()
 {
 	declare_constant_arrays();
-	declare_undefined_values();
 
 	// Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
 	emit_interface_block(stage_out_var_id);
@@ -7242,7 +7344,7 @@ void CompilerMSL::emit_specialization_constants_and_structs()
 	emitted = false;
 	declared_structs.clear();
 
-	for (auto &id_ : ir.ids_for_constant_or_type)
+	for (auto &id_ : ir.ids_for_constant_undef_or_type)
 	{
 		auto &id = ir.ids[id_];
 
@@ -7355,6 +7457,21 @@ void CompilerMSL::emit_specialization_constants_and_structs()
 				emit_struct(get<SPIRType>(type_id));
 			}
 		}
+		else if (id.get_type() == TypeUndef)
+		{
+			auto &undef = id.get<SPIRUndef>();
+			auto &type = get<SPIRType>(undef.basetype);
+			// OpUndef can be void for some reason ...
+			if (type.basetype == SPIRType::Void)
+				return;
+
+			// Undefined global memory is not allowed in MSL.
+			// Declare constant and init to zeros. Use {}, as global constructors can break Metal.
+			statement(
+			    inject_top_level_storage_qualifier(variable_decl(type, to_name(undef.self), undef.self), "constant"),
+			    " = {};");
+			emitted = true;
+		}
 	}
 
 	if (emitted)
@@ -8167,8 +8284,9 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 {
 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
 #define MSL_PTR_BOP(op) emit_binary_ptr_op(ops[0], ops[1], ops[2], ops[3], #op)
+	// MSL does care about implicit integer promotion, but those cases are all handled in common code.
 #define MSL_BOP_CAST(op, type) \
-	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
+	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
@@ -8614,9 +8732,16 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		args.base.is_fetch = true;
 		args.coord = coord_id;
 		args.lod = lod;
-		statement(join(to_expression(img_id), ".write(",
-		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
-		               CompilerMSL::to_function_args(args, &forward), ");"));
+
+		string expr;
+		if (needs_frag_discard_checks())
+			expr = join("(", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), " ? ((void)0) : ");
+		expr += join(to_expression(img_id), ".write(",
+		             remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
+		             CompilerMSL::to_function_args(args, &forward), ")");
+		if (needs_frag_discard_checks())
+			expr += ")";
+		statement(expr, ";");
 
 		if (p_var && variable_storage_is_aliased(*p_var))
 			flush_all_aliased_variables();
@@ -8771,14 +8896,34 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		break;
 
 	case OpStore:
-		if (is_out_of_bounds_tessellation_level(ops[0]))
-			break;
+	{
+		const auto &type = expression_type(ops[0]);
 
-		if (maybe_emit_array_assignment(ops[0], ops[1]))
+		if (is_out_of_bounds_tessellation_level(ops[0]))
 			break;
 
-		CompilerGLSL::emit_instruction(instruction);
+		if (needs_frag_discard_checks() &&
+		    (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform))
+		{
+			// If we're in a continue block, this kludge will make the block too complex
+			// to emit normally.
+			assert(current_emitting_block);
+			auto cont_type = continue_block_type(*current_emitting_block);
+			if (cont_type != SPIRBlock::ContinueNone && cont_type != SPIRBlock::ComplexLoop)
+			{
+				current_emitting_block->complex_continue = true;
+				force_recompile();
+			}
+			statement("if (!", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), ")");
+			begin_scope();
+		}
+		if (!maybe_emit_array_assignment(ops[0], ops[1]))
+			CompilerGLSL::emit_instruction(instruction);
+		if (needs_frag_discard_checks() &&
+		    (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform))
+			end_scope();
 		break;
+	}
 
 	// Compute barriers
 	case OpMemoryBarrier:
@@ -8935,12 +9080,33 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t op0 = ops[2];
 		uint32_t op1 = ops[3];
 		auto &type = get<SPIRType>(result_type);
+		auto input_type = opcode == OpSMulExtended ? int_type : uint_type;
+		auto &output_type = get_type(result_type);
+		string cast_op0, cast_op1;
+
+		auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false);
+
 		emit_uninitialized_temporary_expression(result_type, result_id);
 
-		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
-				  to_enclosed_unpacked_expression(op0), " * ", to_enclosed_unpacked_expression(op1), ";");
-		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(",
-				  to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ");");
+		string mullo_expr, mulhi_expr;
+		mullo_expr = join(cast_op0, " * ", cast_op1);
+		mulhi_expr = join("mulhi(", cast_op0, ", ", cast_op1, ")");
+
+		auto &low_type = get_type(output_type.member_types[0]);
+		auto &high_type = get_type(output_type.member_types[1]);
+		if (low_type.basetype != input_type)
+		{
+			expected_type.basetype = input_type;
+			mullo_expr = join(bitcast_glsl_op(low_type, expected_type), "(", mullo_expr, ")");
+		}
+		if (high_type.basetype != input_type)
+		{
+			expected_type.basetype = input_type;
+			mulhi_expr = join(bitcast_glsl_op(high_type, expected_type), "(", mulhi_expr, ")");
+		}
+
+		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", mullo_expr, ";");
+		statement(to_expression(result_id), ".", to_member_name(type, 1), " = ", mulhi_expr, ";");
 		break;
 	}
 
@@ -9025,7 +9191,10 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
 		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
-		emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
+		emit_op(ops[0], ops[1],
+		        needs_manual_helper_invocation_updates() ? builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput) :
+		                                                   "simd_is_helper_thread()",
+		        false);
 		break;
 
 	case OpBeginInvocationInterlockEXT:
@@ -9475,7 +9644,7 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
                                       uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
                                       bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
 {
-	string exp = string(op) + "(";
+	string exp;
 
 	auto &type = get_pointee_type(expression_type(obj));
 	auto expected_type = type.basetype;
@@ -9490,13 +9659,33 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 	auto remapped_type = type;
 	remapped_type.basetype = expected_type;
 
-	exp += "(";
 	auto *var = maybe_get_backing_variable(obj);
 	if (!var)
 		SPIRV_CROSS_THROW("No backing variable for atomic operation.");
+	const auto &res_type = get<SPIRType>(var->basetype);
+
+	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
 
+	bool check_discard = opcode != OpAtomicLoad && needs_frag_discard_checks() &&
+	                     ((res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image) ||
+	                      var->storage == StorageClassStorageBuffer || var->storage == StorageClassUniform);
+
+	if (check_discard)
+	{
+		if (is_atomic_compare_exchange_strong)
+		{
+			// We're already emitting a CAS loop here; a conditional won't hurt.
+			emit_uninitialized_temporary_expression(result_type, result_id);
+			statement("if (!", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), ")");
+			begin_scope();
+		}
+		else
+			exp = join("(!", builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput), " ? ");
+	}
+
+	exp += string(op) + "(";
+	exp += "(";
 	// Emulate texture2D atomic operations
-	const auto &res_type = get<SPIRType>(var->basetype);
 	if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
 	{
 		exp += "device";
@@ -9515,8 +9704,6 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 	exp += "&";
 	exp += to_enclosed_expression(obj);
 
-	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
-
 	if (is_atomic_compare_exchange_strong)
 	{
 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
@@ -9538,11 +9725,42 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 		// the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
 		// The function updates the comparitor value from the memory value, so the additional
 		// comparison test evaluates the memory value against the expected value.
-		emit_uninitialized_temporary_expression(result_type, result_id);
+		if (!check_discard)
+			emit_uninitialized_temporary_expression(result_type, result_id);
 		statement("do");
 		begin_scope();
 		statement(to_name(result_id), " = ", to_expression(op1), ";");
 		end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
+		if (check_discard)
+		{
+			end_scope();
+			statement("else");
+			begin_scope();
+			exp = "atomic_load_explicit(";
+			exp += "(";
+			// Emulate texture2D atomic operations
+			if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
+				exp += "device";
+			else
+				exp += get_argument_address_space(*var);
+
+			exp += " atomic_";
+			exp += type_to_glsl(remapped_type);
+			exp += "*)";
+
+			exp += "&";
+			exp += to_enclosed_expression(obj);
+
+			if (has_mem_order_2)
+				exp += string(", ") + get_memory_order(mem_order_2);
+			else
+				exp += string(", ") + get_memory_order(mem_order_1);
+
+			exp += ")";
+
+			statement(to_name(result_id), " = ", exp, ";");
+			end_scope();
+		}
 	}
 	else
 	{
@@ -9563,6 +9781,38 @@ void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id,
 
 		exp += ")";
 
+		if (check_discard)
+		{
+			exp += " : ";
+			if (strcmp(op, "atomic_store_explicit") != 0)
+			{
+				exp += "atomic_load_explicit(";
+				exp += "(";
+				// Emulate texture2D atomic operations
+				if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
+					exp += "device";
+				else
+					exp += get_argument_address_space(*var);
+
+				exp += " atomic_";
+				exp += type_to_glsl(remapped_type);
+				exp += "*)";
+
+				exp += "&";
+				exp += to_enclosed_expression(obj);
+
+				if (has_mem_order_2)
+					exp += string(", ") + get_memory_order(mem_order_2);
+				else
+					exp += string(", ") + get_memory_order(mem_order_1);
+
+				exp += ")";
+			}
+			else
+				exp += "((void)0)";
+			exp += ")";
+		}
+
 		if (expected_type != type.basetype)
 			exp = bitcast_expression(type, expected_type, exp);
 
@@ -14195,7 +14445,7 @@ void CompilerMSL::sync_entry_point_aliases_and_names()
 		entry.second.name = ir.meta[entry.first].decoration.alias;
 }
 
-string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
+string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved)
 {
 	auto *var = maybe_get_backing_variable(base);
 	// If this is a buffer array, we have to dereference the buffer pointers.
@@ -14214,7 +14464,7 @@ string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uin
 		declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
 	}
 
-	if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
+	if (declared_as_pointer || (!ptr_chain_is_resolved && should_dereference(base)))
 		return join("->", to_member_name(type, index));
 	else
 		return join(".", to_member_name(type, index));
@@ -15265,6 +15515,8 @@ string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
 		break;
 
 	case BuiltInHelperInvocation:
+		if (needs_manual_helper_invocation_updates())
+			break;
 		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
 		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
@@ -15978,6 +16230,10 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
 		suppress_missing_prototypes = true;
 		break;
 
+	case OpDemoteToHelperInvocationEXT:
+		uses_discard = true;
+		break;
+
 	// Emulate texture2D atomic operations
 	case OpImageTexelPointer:
 	{
@@ -15987,8 +16243,7 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
 	}
 
 	case OpImageWrite:
-		if (!compiler.msl_options.supports_msl_version(2, 2))
-			uses_resource_write = true;
+		uses_image_write = true;
 		break;
 
 	case OpStore:
@@ -16015,9 +16270,11 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
 		auto it = image_pointers.find(args[2]);
 		if (it != image_pointers.end())
 		{
+			uses_image_write = true;
 			compiler.atomic_image_vars.insert(it->second);
 		}
-		check_resource_write(args[2]);
+		else
+			check_resource_write(args[2]);
 		break;
 	}
 
@@ -16028,8 +16285,10 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
 		if (it != image_pointers.end())
 		{
 			compiler.atomic_image_vars.insert(it->second);
+			uses_image_write = true;
 		}
-		check_resource_write(args[0]);
+		else
+			check_resource_write(args[0]);
 		break;
 	}
 
@@ -16132,6 +16391,11 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
 		break;
 	}
 
+	case OpIsHelperInvocationEXT:
+		if (compiler.needs_manual_helper_invocation_updates())
+			needs_helper_invocation = true;
+		break;
+
 	default:
 		break;
 	}
@@ -16149,9 +16413,8 @@ void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
 {
 	auto *p_var = compiler.maybe_get_backing_variable(var_id);
 	StorageClass sc = p_var ? p_var->storage : StorageClassMax;
-	if (!compiler.msl_options.supports_msl_version(2, 1) &&
-	    (sc == StorageClassUniform || sc == StorageClassStorageBuffer))
-		uses_resource_write = true;
+	if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
+		uses_buffer_write = true;
 }
 
 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.

+ 32 - 3
3rdparty/spirv-cross/spirv_msl.hpp

@@ -458,6 +458,20 @@ public:
 		// the extra threads away.
 		bool force_sample_rate_shading = false;
 
+		// If set, gl_HelperInvocation will be set manually whenever a fragment is discarded.
+		// Some Metal devices have a bug where simd_is_helper_thread() does not return true
+		// after a fragment has been discarded. This is a workaround that is only expected to be needed
+		// until the bug is fixed in Metal; it is provided as an option to allow disabling it when that occurs.
+		bool manual_helper_invocation_updates = true;
+
+		// If set, extra checks will be emitted in fragment shaders to prevent writes
+		// from discarded fragments. Some Metal devices have a bug where writes to storage resources
+		// from discarded fragment threads continue to occur, despite the fragment being
+		// discarded. This is a workaround that is only expected to be needed until the
+		// bug is fixed in Metal; it is provided as an option so it can be enabled
+		// only when the bug is present.
+		bool check_discarded_frag_stores = false;
+
 		bool is_ios() const
 		{
 			return platform == iOS;
@@ -817,10 +831,9 @@ protected:
 	std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override;
 	bool emit_complex_bitcast(uint32_t result_id, uint32_t id, uint32_t op0) override;
 	bool skip_argument(uint32_t id) const override;
-	std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain) override;
+	std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved) override;
 	std::string to_qualifiers_glsl(uint32_t id) override;
 	void replace_illegal_names() override;
-	void declare_undefined_values() override;
 	void declare_constant_arrays();
 
 	void replace_illegal_entry_point_names();
@@ -1005,6 +1018,7 @@ protected:
 	uint32_t builtin_frag_coord_id = 0;
 	uint32_t builtin_sample_id_id = 0;
 	uint32_t builtin_sample_mask_id = 0;
+	uint32_t builtin_helper_invocation_id = 0;
 	uint32_t builtin_vertex_idx_id = 0;
 	uint32_t builtin_base_vertex_id = 0;
 	uint32_t builtin_instance_idx_id = 0;
@@ -1029,6 +1043,7 @@ protected:
 	uint32_t argument_buffer_padding_sampler_type_id = 0;
 
 	bool does_shader_write_sample_mask = false;
+	bool frag_shader_needs_discard_checks = false;
 
 	void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
 	void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
@@ -1113,6 +1128,7 @@ protected:
 	bool needs_subgroup_invocation_id = false;
 	bool needs_subgroup_size = false;
 	bool needs_sample_id = false;
+	bool needs_helper_invocation = false;
 	std::string qual_pos_var_name;
 	std::string stage_in_var_name = "in";
 	std::string stage_out_var_name = "out";
@@ -1180,6 +1196,16 @@ protected:
 
 	bool variable_storage_requires_stage_io(spv::StorageClass storage) const;
 
+	bool needs_manual_helper_invocation_updates() const
+	{
+		return msl_options.manual_helper_invocation_updates && msl_options.supports_msl_version(2, 3);
+	}
+	bool needs_frag_discard_checks() const
+	{
+		return get_execution_model() == spv::ExecutionModelFragment && msl_options.supports_msl_version(2, 3) &&
+		       msl_options.check_discarded_frag_stores && frag_shader_needs_discard_checks;
+	}
+
 	bool has_additional_fixed_sample_mask() const { return msl_options.additional_fixed_sample_mask != 0xffffffff; }
 	std::string additional_fixed_sample_mask_str() const;
 
@@ -1200,10 +1226,13 @@ protected:
 		std::unordered_map<uint32_t, uint32_t> image_pointers; // Emulate texture2D atomic operations
 		bool suppress_missing_prototypes = false;
 		bool uses_atomics = false;
-		bool uses_resource_write = false;
+		bool uses_image_write = false;
+		bool uses_buffer_write = false;
+		bool uses_discard = false;
 		bool needs_subgroup_invocation_id = false;
 		bool needs_subgroup_size = false;
 		bool needs_sample_id = false;
+		bool needs_helper_invocation = false;
 	};
 
 	// OpcodeHandler that scans for uses of sampled images

+ 13 - 9
3rdparty/spirv-cross/spirv_parser.cpp

@@ -275,24 +275,28 @@ void Parser::parse(const Instruction &instruction)
 	case OpExtInstImport:
 	{
 		uint32_t id = ops[0];
+
+		SPIRExtension::Extension spirv_ext = SPIRExtension::Unsupported;
+
 		auto ext = extract_string(ir.spirv, instruction.offset + 1);
 		if (ext == "GLSL.std.450")
-			set<SPIRExtension>(id, SPIRExtension::GLSL);
+			spirv_ext = SPIRExtension::GLSL;
 		else if (ext == "DebugInfo")
-			set<SPIRExtension>(id, SPIRExtension::SPV_debug_info);
+			spirv_ext = SPIRExtension::SPV_debug_info;
 		else if (ext == "SPV_AMD_shader_ballot")
-			set<SPIRExtension>(id, SPIRExtension::SPV_AMD_shader_ballot);
+			spirv_ext = SPIRExtension::SPV_AMD_shader_ballot;
 		else if (ext == "SPV_AMD_shader_explicit_vertex_parameter")
-			set<SPIRExtension>(id, SPIRExtension::SPV_AMD_shader_explicit_vertex_parameter);
+			spirv_ext = SPIRExtension::SPV_AMD_shader_explicit_vertex_parameter;
 		else if (ext == "SPV_AMD_shader_trinary_minmax")
-			set<SPIRExtension>(id, SPIRExtension::SPV_AMD_shader_trinary_minmax);
+			spirv_ext = SPIRExtension::SPV_AMD_shader_trinary_minmax;
 		else if (ext == "SPV_AMD_gcn_shader")
-			set<SPIRExtension>(id, SPIRExtension::SPV_AMD_gcn_shader);
+			spirv_ext = SPIRExtension::SPV_AMD_gcn_shader;
 		else if (ext == "NonSemantic.DebugPrintf")
-			set<SPIRExtension>(id, SPIRExtension::NonSemanticDebugPrintf);
-		else
-			set<SPIRExtension>(id, SPIRExtension::Unsupported);
+			spirv_ext = SPIRExtension::NonSemanticDebugPrintf;
+		else if (ext == "NonSemantic.Shader.DebugInfo.100")
+			spirv_ext = SPIRExtension::NonSemanticShaderDebugInfo;
 
+		set<SPIRExtension>(id, spirv_ext);
 		// Other SPIR-V extensions which have ExtInstrs are currently not supported.
 
 		break;