
Updated spirv-cross.

Бранимир Караџић, 6 months ago
parent
commit
7c79acf98e

+ 3 - 0
3rdparty/spirv-cross/spirv_common.hpp

@@ -1035,6 +1035,9 @@ struct SPIRFunction : IVariant
 	// consider arrays value types.
 	SmallVector<ID> constant_arrays_needed_on_stack;
 
+	// Does this function (or any function called by it), emit geometry?
+	bool emits_geometry = false;
+
 	bool active = false;
 	bool flush_undeclared = true;
 	bool do_combined_parameters = true;

+ 40 - 3
3rdparty/spirv-cross/spirv_cross.cpp

@@ -82,7 +82,7 @@ bool Compiler::variable_storage_is_aliased(const SPIRVariable &v)
 	            ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
 	bool image = type.basetype == SPIRType::Image;
 	bool counter = type.basetype == SPIRType::AtomicCounter;
-	bool buffer_reference = type.storage == StorageClassPhysicalStorageBufferEXT;
+	bool buffer_reference = type.storage == StorageClassPhysicalStorageBuffer;
 
 	bool is_restrict;
 	if (ssbo)
@@ -484,7 +484,7 @@ void Compiler::register_write(uint32_t chain)
 			}
 		}
 
-		if (type.storage == StorageClassPhysicalStorageBufferEXT || variable_storage_is_aliased(*var))
+		if (type.storage == StorageClassPhysicalStorageBuffer || variable_storage_is_aliased(*var))
 			flush_all_aliased_variables();
 		else if (var)
 			flush_dependees(*var);
@@ -4362,6 +4362,39 @@ bool Compiler::may_read_undefined_variable_in_block(const SPIRBlock &block, uint
 	return true;
 }
 
+bool Compiler::GeometryEmitDisocveryHandler::handle(spv::Op opcode, const uint32_t *, uint32_t)
+{
+	if (opcode == OpEmitVertex || opcode == OpEndPrimitive)
+	{
+		for (auto *func : function_stack)
+			func->emits_geometry = true;
+	}
+
+	return true;
+}
+
+bool Compiler::GeometryEmitDisocveryHandler::begin_function_scope(const uint32_t *stream, uint32_t)
+{
+	auto &callee = compiler.get<SPIRFunction>(stream[2]);
+	function_stack.push_back(&callee);
+	return true;
+}
+
+bool Compiler::GeometryEmitDisocveryHandler::end_function_scope([[maybe_unused]] const uint32_t *stream, uint32_t)
+{
+	assert(function_stack.back() == &compiler.get<SPIRFunction>(stream[2]));
+	function_stack.pop_back();
+
+	return true;
+}
+
+void Compiler::discover_geometry_emitters()
+{
+	GeometryEmitDisocveryHandler handler(*this);
+
+	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), handler);
+}
+
 Bitset Compiler::get_buffer_block_flags(VariableID id) const
 {
 	return ir.get_buffer_block_flags(get<SPIRVariable>(id));
@@ -5194,7 +5227,7 @@ bool Compiler::PhysicalStorageBufferPointerHandler::type_is_bda_block_entry(uint
 
 uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_minimum_scalar_alignment(const SPIRType &type) const
 {
-	if (type.storage == spv::StorageClassPhysicalStorageBufferEXT)
+	if (type.storage == spv::StorageClassPhysicalStorageBuffer)
 		return 8;
 	else if (type.basetype == SPIRType::Struct)
 	{
@@ -5298,6 +5331,10 @@ uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_base_non_block_type_
 
 void Compiler::PhysicalStorageBufferPointerHandler::analyze_non_block_types_from_block(const SPIRType &type)
 {
+	if (analyzed_type_ids.count(type.self))
+		return;
+	analyzed_type_ids.insert(type.self);
+
 	for (auto &member : type.member_types)
 	{
 		auto &subtype = compiler.get<SPIRType>(member);

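The new handler walks the call graph from the entry point and, whenever it sees OpEmitVertex or OpEndPrimitive, marks every function currently on the call stack. For context, a minimal stand-alone C++ sketch of that idea with made-up Function and call-graph types (not the SPIRV-Cross traversal API), assuming an acyclic call graph as SPIR-V requires:

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for SPIRFunction; names are illustrative only.
struct Function
{
	std::vector<std::string> callees; // functions called by this one
	bool has_emit = false;            // body contains OpEmitVertex/OpEndPrimitive
	bool emits_geometry = false;      // result of the discovery pass
};

// Keep the current call stack while walking; when an emitting function is
// reached, mark every function still on the stack, mirroring how the handler
// marks all of function_stack when it sees an emit opcode.
static void discover(std::unordered_map<std::string, Function> &funcs,
                     const std::string &name, std::vector<Function *> &stack)
{
	Function &fn = funcs.at(name);
	stack.push_back(&fn);

	if (fn.has_emit)
		for (Function *f : stack)
			f->emits_geometry = true;

	for (const std::string &callee : fn.callees)
		discover(funcs, callee, stack);

	stack.pop_back();
}

int main()
{
	std::unordered_map<std::string, Function> funcs;
	funcs["main"] = Function{ { "helper" }, false };
	funcs["helper"] = Function{ { "emit_quad" }, false };
	funcs["emit_quad"] = Function{ {}, true };
	funcs["unrelated"] = Function{ {}, false };

	std::vector<Function *> stack;
	discover(funcs, "main", stack);

	for (const auto &kv : funcs)
		std::printf("%s emits_geometry=%d\n", kv.first.c_str(), kv.second.emits_geometry ? 1 : 0);
}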
+ 17 - 0
3rdparty/spirv-cross/spirv_cross.hpp

@@ -1054,6 +1054,7 @@ protected:
 		std::unordered_set<uint32_t> non_block_types;
 		std::unordered_map<uint32_t, PhysicalBlockMeta> physical_block_type_meta;
 		std::unordered_map<uint32_t, PhysicalBlockMeta *> access_chain_to_physical_block;
+		std::unordered_set<uint32_t> analyzed_type_ids;
 
 		void mark_aligned_access(uint32_t id, const uint32_t *args, uint32_t length);
 		PhysicalBlockMeta *find_block_meta(uint32_t id) const;
@@ -1072,6 +1073,22 @@ protected:
 	                              bool single_function);
 	bool may_read_undefined_variable_in_block(const SPIRBlock &block, uint32_t var);
 
+	struct GeometryEmitDisocveryHandler : OpcodeHandler
+	{
+		explicit GeometryEmitDisocveryHandler(Compiler &compiler_)
+		    : compiler(compiler_)
+		{
+		}
+		Compiler &compiler;
+
+		bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override;
+		bool begin_function_scope(const uint32_t *, uint32_t) override;
+		bool end_function_scope(const uint32_t *, uint32_t) override;
+		SmallVector<SPIRFunction *> function_stack;
+	};
+
+	void discover_geometry_emitters();
+
 	// Finds all resources that are written to from inside the critical section, if present.
 	// The critical section is delimited by OpBeginInvocationInterlockEXT and
 	// OpEndInvocationInterlockEXT instructions. In MSL and HLSL, any resources written

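The analyzed_type_ids set added above turns analyze_non_block_types_from_block into a visited-set walk: each type is analyzed at most once, which both avoids redundant work and terminates recursion on self-referential buffer_reference types such as linked lists. A minimal sketch of the same guard, using toy types rather than SPIRType:

#include <cstdint>
#include <cstdio>
#include <unordered_set>
#include <vector>

// Toy stand-in for a struct type whose members may reference other types,
// possibly including itself, the way buffer_reference blocks can.
struct Type
{
	uint32_t self;                 // type id
	std::vector<uint32_t> members; // ids of member types
};

static void analyze(const std::vector<Type> &types, const Type &type,
                    std::unordered_set<uint32_t> &analyzed_type_ids)
{
	// Without this guard a self-referential type would recurse forever,
	// and shared subtypes would be re-analyzed once per reference.
	if (analyzed_type_ids.count(type.self))
		return;
	analyzed_type_ids.insert(type.self);

	std::printf("analyzing type %u\n", type.self);
	for (uint32_t member : type.members)
		analyze(types, types[member], analyzed_type_ids);
}

int main()
{
	// Type 0 is a "linked list node": it references itself and type 1.
	std::vector<Type> types = { { 0, { 0, 1 } }, { 1, {} } };
	std::unordered_set<uint32_t> analyzed_type_ids;
	analyze(types, types[0], analyzed_type_ids);
}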
+ 1 - 0
3rdparty/spirv-cross/spirv_cross_c.cpp

@@ -55,6 +55,7 @@
 #ifdef _MSC_VER
 #pragma warning(push)
 #pragma warning(disable : 4996)
+#pragma warning(disable : 4065) // switch with 'default' but not 'case'.
 #endif
 
 #ifndef SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS

+ 53 - 40
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -545,7 +545,7 @@ void CompilerGLSL::find_static_extensions()
 	if (options.separate_shader_objects && !options.es && options.version < 410)
 		require_extension_internal("GL_ARB_separate_shader_objects");
 
-	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
+	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
 	{
 		if (!options.vulkan_semantics)
 			SPIRV_CROSS_THROW("GL_EXT_buffer_reference is only supported in Vulkan GLSL.");
@@ -557,7 +557,7 @@ void CompilerGLSL::find_static_extensions()
 	}
 	else if (ir.addressing_model != AddressingModelLogical)
 	{
-		SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64EXT addressing models are supported.");
+		SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64 addressing models are supported.");
 	}
 
 	// Check for nonuniform qualifier and passthrough.
@@ -708,7 +708,7 @@ string CompilerGLSL::compile()
 
 	// Shaders might cast unrelated data to pointers of non-block types.
 	// Find all such instances and make sure we can cast the pointers to a synthesized block type.
-	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
+	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
 		analyze_non_block_pointer_types();
 
 	uint32_t pass_count = 0;
@@ -1542,14 +1542,14 @@ uint32_t CompilerGLSL::type_to_packed_base_size(const SPIRType &type, BufferPack
 uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bitset &flags,
                                                 BufferPackingStandard packing)
 {
-	// If using PhysicalStorageBufferEXT storage class, this is a pointer,
+	// If using PhysicalStorageBuffer storage class, this is a pointer,
 	// and is 64-bit.
 	if (is_physical_pointer(type))
 	{
 		if (!type.pointer)
-			SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers.");
+			SPIRV_CROSS_THROW("Types in PhysicalStorageBuffer must be pointers.");
 
-		if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
+		if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
 		{
 			if (packing_is_vec4_padded(packing) && type_is_array_of_pointers(type))
 				return 16;
@@ -1557,7 +1557,7 @@ uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bits
 				return 8;
 		}
 		else
-			SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT.");
+			SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64 must be used for PhysicalStorageBuffer.");
 	}
 	else if (is_array(type))
 	{
@@ -1665,17 +1665,17 @@ uint32_t CompilerGLSL::type_to_packed_array_stride(const SPIRType &type, const B
 
 uint32_t CompilerGLSL::type_to_packed_size(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing)
 {
-	// If using PhysicalStorageBufferEXT storage class, this is a pointer,
+	// If using PhysicalStorageBuffer storage class, this is a pointer,
 	// and is 64-bit.
 	if (is_physical_pointer(type))
 	{
 		if (!type.pointer)
-			SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers.");
+			SPIRV_CROSS_THROW("Types in PhysicalStorageBuffer must be pointers.");
 
-		if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
+		if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
 			return 8;
 		else
-			SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT.");
+			SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64 must be used for PhysicalStorageBuffer.");
 	}
 	else if (is_array(type))
 	{
@@ -3638,6 +3638,36 @@ void CompilerGLSL::emit_resources()
 
 	bool emitted = false;
 
+	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
+	{
+		// Output buffer reference block forward declarations.
+		ir.for_each_typed_id<SPIRType>([&](uint32_t id, SPIRType &type)
+		{
+			if (is_physical_pointer(type))
+			{
+				bool emit_type = true;
+				if (!is_physical_pointer_to_buffer_block(type))
+				{
+					// Only forward-declare if we intend to emit it in the non_block_pointer types.
+					// Otherwise, these are just "benign" pointer types that exist as a result of access chains.
+					emit_type = std::find(physical_storage_non_block_pointer_types.begin(),
+					                      physical_storage_non_block_pointer_types.end(),
+					                      id) != physical_storage_non_block_pointer_types.end();
+				}
+
+				if (emit_type)
+				{
+					emit_buffer_reference_block(id, true);
+					emitted = true;
+				}
+			}
+		});
+	}
+
+	if (emitted)
+		statement("");
+	emitted = false;
+
 	// If emitted Vulkan GLSL,
 	// emit specialization constants as actual floats,
 	// spec op expressions will redirect to the constant name.
@@ -3747,30 +3777,10 @@ void CompilerGLSL::emit_resources()
 
 	emitted = false;
 
-	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
+	if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
 	{
 		// Output buffer reference blocks.
-		// Do this in two stages, one with forward declaration,
-		// and one without. Buffer reference blocks can reference themselves
-		// to support things like linked lists.
-		ir.for_each_typed_id<SPIRType>([&](uint32_t id, SPIRType &type) {
-			if (is_physical_pointer(type))
-			{
-				bool emit_type = true;
-				if (!is_physical_pointer_to_buffer_block(type))
-				{
-					// Only forward-declare if we intend to emit it in the non_block_pointer types.
-					// Otherwise, these are just "benign" pointer types that exist as a result of access chains.
-					emit_type = std::find(physical_storage_non_block_pointer_types.begin(),
-					                      physical_storage_non_block_pointer_types.end(),
-					                      id) != physical_storage_non_block_pointer_types.end();
-				}
-
-				if (emit_type)
-					emit_buffer_reference_block(id, true);
-			}
-		});
-
+		// Buffer reference blocks can reference themselves to support things like linked lists.
 		for (auto type : physical_storage_non_block_pointer_types)
 			emit_buffer_reference_block(type, false);
 
@@ -10317,7 +10327,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 		if (!is_ptr_chain)
 			mod_flags &= ~ACCESS_CHAIN_PTR_CHAIN_BIT;
 		access_chain_internal_append_index(expr, base, type, mod_flags, access_chain_is_arrayed, index);
-		check_physical_type_cast(expr, type, physical_type);
+		if (check_physical_type_cast(expr, type, physical_type))
+			physical_type = 0;
 	};
 
 	for (uint32_t i = 0; i < count; i++)
@@ -10825,8 +10836,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 	return expr;
 }
 
-void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t)
+bool CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t)
 {
+	return false;
 }
 
 bool CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
@@ -15337,8 +15349,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 	case OpConvertUToPtr:
 	{
 		auto &type = get<SPIRType>(ops[0]);
-		if (type.storage != StorageClassPhysicalStorageBufferEXT)
-			SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertUToPtr.");
+		if (type.storage != StorageClassPhysicalStorageBuffer)
+			SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBuffer is supported by OpConvertUToPtr.");
 
 		auto &in_type = expression_type(ops[2]);
 		if (in_type.vecsize == 2)
@@ -15353,8 +15365,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 	{
 		auto &type = get<SPIRType>(ops[0]);
 		auto &ptr_type = expression_type(ops[2]);
-		if (ptr_type.storage != StorageClassPhysicalStorageBufferEXT)
-			SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertPtrToU.");
+		if (ptr_type.storage != StorageClassPhysicalStorageBuffer)
+			SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBuffer is supported by OpConvertPtrToU.");
 
 		if (type.vecsize == 2)
 			require_extension_internal("GL_EXT_buffer_reference_uvec2");
@@ -16143,7 +16155,7 @@ string CompilerGLSL::to_array_size(const SPIRType &type, uint32_t index)
 
 string CompilerGLSL::type_to_array_glsl(const SPIRType &type, uint32_t)
 {
-	if (type.pointer && type.storage == StorageClassPhysicalStorageBufferEXT && type.basetype != SPIRType::Struct)
+	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer && type.basetype != SPIRType::Struct)
 	{
 		// We are using a wrapped pointer type, and we should not emit any array declarations here.
 		return "";
@@ -16856,6 +16868,7 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags)
 			{
 				// Recursively emit functions which are called.
 				uint32_t id = ops[2];
+
 				emit_function(get<SPIRFunction>(id), ir.meta[ops[1]].decoration.decoration_flags);
 			}
 		}

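check_physical_type_cast now reports whether it actually emitted a cast, so access_chain_internal can clear physical_type and stop treating the rest of the chain as the raw ulong-backed physical type. A small stand-alone sketch of that caller-side contract with placeholder types (not the real SPIRV-Cross signatures):

#include <cstdint>
#include <cstdio>
#include <string>

// Placeholder for SPIRType; only the type name matters here.
struct Type { std::string name; };

// Mimics the new contract: return true when a cast back to the logical pointer
// type was appended, meaning the caller should forget the physical type id.
static bool check_physical_type_cast(std::string &expr, const Type *type, uint32_t physical_type)
{
	if (physical_type == 0)
		return false; // nothing to undo
	expr = "((" + type->name + ")" + expr + ")";
	return true;
}

int main()
{
	Type t{ "DataPtr" };
	std::string expr = "raw_addr";
	uint32_t physical_type = 42; // pretend id of the ulong-backed type

	// Caller-side pattern from access_chain_internal: once the cast is emitted,
	// clear physical_type so later chain indices use the logical type again.
	if (check_physical_type_cast(expr, &t, physical_type))
		physical_type = 0;

	std::printf("%s (physical_type=%u)\n", expr.c_str(), physical_type);
}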
+ 2 - 2
3rdparty/spirv-cross/spirv_glsl.hpp

@@ -769,7 +769,7 @@ protected:
 	spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
 	virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
 
-	virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
+	virtual bool check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
 	virtual bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
 	                                                    spv::StorageClass storage, bool &is_packed);
 
@@ -799,7 +799,7 @@ protected:
 	std::string declare_temporary(uint32_t type, uint32_t id);
 	void emit_uninitialized_temporary(uint32_t type, uint32_t id);
 	SPIRExpression &emit_uninitialized_temporary_expression(uint32_t type, uint32_t id);
-	void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist);
+	virtual void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist);
 	std::string to_non_uniform_aware_expression(uint32_t id);
 	std::string to_atomic_ptr_expression(uint32_t id);
 	std::string to_expression(uint32_t id, bool register_expression_read = true);

+ 213 - 9
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -1117,7 +1117,9 @@ void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unord
 		else
 		{
 			auto decl_type = type;
-			if (execution.model == ExecutionModelMeshEXT || has_decoration(var.self, DecorationPerVertexKHR))
+			if (execution.model == ExecutionModelMeshEXT ||
+			    (execution.model == ExecutionModelGeometry && var.storage == StorageClassInput) ||
+			    has_decoration(var.self, DecorationPerVertexKHR))
 			{
 				decl_type.array.erase(decl_type.array.begin());
 				decl_type.array_size_literal.erase(decl_type.array_size_literal.begin());
@@ -1834,7 +1836,7 @@ void CompilerHLSL::emit_resources()
 	if (!output_variables.empty() || !active_output_builtins.empty())
 	{
 		sort(output_variables.begin(), output_variables.end(), variable_compare);
-		require_output = !is_mesh_shader;
+		require_output = !(is_mesh_shader || execution.model == ExecutionModelGeometry);
 
 		statement(is_mesh_shader ? "struct gl_MeshPerVertexEXT" : "struct SPIRV_Cross_Output");
 		begin_scope();
@@ -2678,6 +2680,83 @@ void CompilerHLSL::emit_mesh_tasks(SPIRBlock &block)
 	}
 }
 
+void CompilerHLSL::emit_geometry_stream_append()
+{
+	begin_scope();
+	statement("SPIRV_Cross_Output stage_output;");
+
+	active_output_builtins.for_each_bit(
+	    [&](uint32_t i)
+	    {
+		    if (i == BuiltInPointSize && hlsl_options.shader_model > 30)
+			    return;
+		    switch (static_cast<BuiltIn>(i))
+		    {
+		    case BuiltInClipDistance:
+			    for (uint32_t clip = 0; clip < clip_distance_count; clip++)
+				    statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
+				              clip, "];");
+			    break;
+		    case BuiltInCullDistance:
+			    for (uint32_t cull = 0; cull < cull_distance_count; cull++)
+				    statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
+				              cull, "];");
+			    break;
+		    case BuiltInSampleMask:
+			    statement("stage_output.gl_SampleMask = gl_SampleMask[0];");
+			    break;
+		    default:
+		    {
+			    auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
+			    statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
+		    }
+		    break;
+		    }
+	    });
+
+	ir.for_each_typed_id<SPIRVariable>(
+	    [&](uint32_t, SPIRVariable &var)
+	    {
+		    auto &type = this->get<SPIRType>(var.basetype);
+		    bool block = has_decoration(type.self, DecorationBlock);
+
+		    if (var.storage != StorageClassOutput)
+			    return;
+
+		    if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
+		        interface_variable_exists_in_entry_point(var.self))
+		    {
+			    if (block)
+			    {
+				    auto type_name = to_name(type.self);
+				    auto var_name = to_name(var.self);
+				    for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
+				    {
+					    auto mbr_name = to_member_name(type, mbr_idx);
+					    auto flat_name = join(type_name, "_", mbr_name);
+					    statement("stage_output.", flat_name, " = ", var_name, ".", mbr_name, ";");
+				    }
+			    }
+			    else
+			    {
+				    auto name = to_name(var.self);
+				    if (hlsl_options.shader_model <= 30 && get_entry_point().model == ExecutionModelFragment)
+				    {
+					    string output_filler;
+					    for (uint32_t size = type.vecsize; size < 4; ++size)
+						    output_filler += ", 0.0";
+					    statement("stage_output.", name, " = float4(", name, output_filler, ");");
+				    }
+				    else
+					    statement("stage_output.", name, " = ", name, ";");
+			    }
+		    }
+	    });
+
+	statement("geometry_stream.Append(stage_output);");
+	end_scope();
+}
+
 void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
 {
 	auto &type = get<SPIRType>(var.basetype);
@@ -2940,6 +3019,8 @@ string CompilerHLSL::get_inner_entry_point_name() const
 		return "frag_main";
 	else if (execution.model == ExecutionModelGLCompute)
 		return "comp_main";
+	else if (execution.model == ExecutionModelGeometry)
+		return "geom_main";
 	else if (execution.model == ExecutionModelMeshEXT)
 		return "mesh_main";
 	else if (execution.model == ExecutionModelTaskEXT)
@@ -2948,6 +3029,25 @@ string CompilerHLSL::get_inner_entry_point_name() const
 		SPIRV_CROSS_THROW("Unsupported execution model.");
 }
 
+uint32_t CompilerHLSL::input_vertices_from_execution_mode(spirv_cross::SPIREntryPoint &execution) const
+{
+	uint32_t input_vertices = 1;
+
+	if (execution.flags.get(ExecutionModeInputLines))
+		input_vertices = 2;
+	else if (execution.flags.get(ExecutionModeInputLinesAdjacency))
+		input_vertices = 4;
+	else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency))
+		input_vertices = 6;
+	else if (execution.flags.get(ExecutionModeTriangles))
+		input_vertices = 3;
+	else if (execution.flags.get(ExecutionModeInputPoints))
+		input_vertices = 1;
+	else
+		SPIRV_CROSS_THROW("Unsupported execution model.");
+	return input_vertices;
+}
+
 void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
 {
 	if (func.self != ir.default_entry_point)
@@ -3041,6 +3141,38 @@ void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &ret
 			var->parameter = &arg;
 	}
 
+	if ((func.self == ir.default_entry_point || func.emits_geometry) &&
+	    get_entry_point().model == ExecutionModelGeometry)
+	{
+		auto &execution = get_entry_point();
+
+		uint32_t input_vertices = input_vertices_from_execution_mode(execution);
+
+		const char *prim;
+		if (execution.flags.get(ExecutionModeInputLinesAdjacency))
+			prim = "lineadj";
+		else if (execution.flags.get(ExecutionModeInputLines))
+			prim = "line";
+		else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency))
+			prim = "triangleadj";
+		else if (execution.flags.get(ExecutionModeTriangles))
+			prim = "triangle";
+		else
+			prim = "point";
+
+		const char *stream_type;
+		if (execution.flags.get(ExecutionModeOutputPoints))
+			stream_type = "PointStream";
+		else if (execution.flags.get(ExecutionModeOutputLineStrip))
+			stream_type = "LineStream";
+		else
+			stream_type = "TriangleStream";
+
+		if (func.self == ir.default_entry_point)
+			arglist.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]"));
+		arglist.push_back(join("inout ", stream_type, "<SPIRV_Cross_Output> ", "geometry_stream"));
+	}
+
 	decl += merge(arglist);
 	decl += ")";
 	statement(decl);
@@ -3050,13 +3182,50 @@ void CompilerHLSL::emit_hlsl_entry_point()
 {
 	SmallVector<string> arguments;
 
-	if (require_input)
+	if (require_input && get_entry_point().model != ExecutionModelGeometry)
 		arguments.push_back("SPIRV_Cross_Input stage_input");
 
 	auto &execution = get_entry_point();
 
+	uint32_t input_vertices = 1;
+
 	switch (execution.model)
 	{
+	case ExecutionModelGeometry:
+	{
+		input_vertices = input_vertices_from_execution_mode(execution);
+
+		string prim;
+		if (execution.flags.get(ExecutionModeInputLinesAdjacency))
+			prim = "lineadj";
+		else if (execution.flags.get(ExecutionModeInputLines))
+			prim = "line";
+		else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency))
+			prim = "triangleadj";
+		else if (execution.flags.get(ExecutionModeTriangles))
+			prim = "triangle";
+		else
+			prim = "point";
+
+		string stream_type;
+		if (execution.flags.get(ExecutionModeOutputPoints))
+		{
+			stream_type = "PointStream";
+		}
+		else if (execution.flags.get(ExecutionModeOutputLineStrip))
+		{
+			stream_type = "LineStream";
+		}
+		else
+		{
+			stream_type = "TriangleStream";
+		}
+
+		statement("[maxvertexcount(", execution.output_vertices, ")]");
+		arguments.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]"));
+		arguments.push_back(join("inout ", stream_type, "<SPIRV_Cross_Output> ", "geometry_stream"));
+		break;
+	}
 	case ExecutionModelTaskEXT:
 	case ExecutionModelTaskEXT:
 	case ExecutionModelMeshEXT:
 	case ExecutionModelGLCompute:
 				}
 				}
 				else
 				else
 				{
 				{
-					statement(name, " = stage_input.", name, ";");
+					if (execution.model == ExecutionModelGeometry)
+					{
+						statement("for (int i = 0; i < ", input_vertices, "; i++)");
+						begin_scope();
+						statement(name, "[i] = stage_input[i].", name, ";");
+						end_scope();
+					}
+					else
+						statement(name, " = stage_input.", name, ";");
 				}
 				}
 			}
 			}
 		}
 		}
 	});
 	});
 
 
 	// Run the shader.
 	// Run the shader.
-	if (execution.model == ExecutionModelVertex ||
-	    execution.model == ExecutionModelFragment ||
-	    execution.model == ExecutionModelGLCompute ||
-	    execution.model == ExecutionModelMeshEXT ||
-	    execution.model == ExecutionModelTaskEXT)
+	if (execution.model == ExecutionModelVertex || execution.model == ExecutionModelFragment ||
+	    execution.model == ExecutionModelGLCompute || execution.model == ExecutionModelMeshEXT ||
+	    execution.model == ExecutionModelGeometry || execution.model == ExecutionModelTaskEXT)
 	{
 		// For mesh shaders, we receive special arguments that we must pass down as function arguments.
 		// HLSL does not support proper reference types for passing these IO blocks,
@@ -3378,8 +3553,16 @@ void CompilerHLSL::emit_hlsl_entry_point()
 		SmallVector<string> arglist;
 		auto &func = get<SPIRFunction>(ir.default_entry_point);
 		// The arguments are marked out, avoid detecting reads and emitting inout.
+
 		for (auto &arg : func.arguments)
 			arglist.push_back(to_expression(arg.id, false));
+
+		if (execution.model == ExecutionModelGeometry)
+		{
+			arglist.push_back("stage_input");
+			arglist.push_back("geometry_stream");
+		}
+
 		statement(get_inner_entry_point_name(), "(", merge(arglist), ");");
 	}
 	else
@@ -4206,6 +4389,14 @@ bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
 	return false;
 }
 
+void CompilerHLSL::append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist)
+{
+	CompilerGLSL::append_global_func_args(func, index, arglist);
+
+	if (func.emits_geometry)
+		arglist.push_back("geometry_stream");
+}
+
 string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
 {
 	if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
@@ -6594,6 +6785,16 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 		statement("SetMeshOutputCounts(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");");
 		break;
 	}
+	case OpEmitVertex:
+	{
+		emit_geometry_stream_append();
+		break;
+	}
+	case OpEndPrimitive:
+	{
+		statement("geometry_stream.RestartStrip();");
+		break;
+	}
 	default:
 		CompilerGLSL::emit_instruction(instruction);
 		break;
@@ -6812,6 +7013,9 @@ string CompilerHLSL::compile()
 	if (get_execution_model() == ExecutionModelMeshEXT)
 		analyze_meshlet_writes();
 
+	if (get_execution_model() == ExecutionModelGeometry)
+		discover_geometry_emitters();
+
 	// Subpass input needs SV_Position.
 	if (need_subpass_input)
 		active_input_builtins.set(BuiltInFragCoord);

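The geometry-shader entry point work maps the SPIR-V input/output execution modes to an HLSL primitive keyword, a stage_input array size, and a stream type, as in input_vertices_from_execution_mode and the prim/stream selection above. A stand-alone C++ sketch of that mapping, using toy mode flags rather than the real SPIR-V enums:

#include <cstdio>
#include <set>
#include <string>

// Toy execution-mode flags; the real code queries SPIREntryPoint::flags.
enum Mode { InputPoints, InputLines, InputLinesAdjacency, Triangles, InputTrianglesAdjacency,
            OutputPoints, OutputLineStrip, OutputTriangleStrip };

struct GSInfo { unsigned input_vertices; const char *prim; const char *stream; };

// The input primitive picks both the stage_input[n] size and the
// lineadj/line/triangleadj/triangle/point keyword; the output primitive picks
// PointStream/LineStream/TriangleStream.
static GSInfo classify(const std::set<Mode> &flags)
{
	GSInfo info{ 1, "point", "TriangleStream" };
	if (flags.count(InputLinesAdjacency))          { info.input_vertices = 4; info.prim = "lineadj"; }
	else if (flags.count(InputLines))              { info.input_vertices = 2; info.prim = "line"; }
	else if (flags.count(InputTrianglesAdjacency)) { info.input_vertices = 6; info.prim = "triangleadj"; }
	else if (flags.count(Triangles))               { info.input_vertices = 3; info.prim = "triangle"; }

	if (flags.count(OutputPoints))         info.stream = "PointStream";
	else if (flags.count(OutputLineStrip)) info.stream = "LineStream";
	return info;
}

int main()
{
	GSInfo gs = classify({ Triangles, OutputTriangleStrip });
	std::printf("%s SPIRV_Cross_Input stage_input[%u], inout %s<SPIRV_Cross_Output> geometry_stream\n",
	            gs.prim, gs.input_vertices, gs.stream);
}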
+ 4 - 0
3rdparty/spirv-cross/spirv_hlsl.hpp

@@ -231,6 +231,7 @@ private:
 	std::string image_type_hlsl(const SPIRType &type, uint32_t id);
 	std::string image_type_hlsl_modern(const SPIRType &type, uint32_t id);
 	std::string image_type_hlsl_legacy(const SPIRType &type, uint32_t id);
+	uint32_t input_vertices_from_execution_mode(SPIREntryPoint &execution) const;
 	void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override;
 	void emit_hlsl_entry_point();
 	void emit_header() override;
@@ -259,6 +260,8 @@ private:
 	std::string to_interpolation_qualifiers(const Bitset &flags) override;
 	std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override;
 	bool emit_complex_bitcast(uint32_t result_type, uint32_t id, uint32_t op0) override;
+	void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist) override;
+
 	std::string to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id) override;
 	std::string to_sampler_expression(uint32_t id);
 	std::string to_resource_binding(const SPIRVariable &var);
@@ -286,6 +289,7 @@ private:
 	                        uint32_t base_offset = 0) override;
 	void emit_rayquery_function(const char *commited, const char *candidate, const uint32_t *ops);
 	void emit_mesh_tasks(SPIRBlock &block) override;
+	void emit_geometry_stream_append();
 
 	const char *to_storage_qualifiers_glsl(const SPIRVariable &var) override;
 	void replace_illegal_names() override;

+ 343 - 30
3rdparty/spirv-cross/spirv_msl.cpp

@@ -2222,6 +2222,27 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				break;
 			}
 
+			case OpGroupNonUniformFAdd:
+			case OpGroupNonUniformFMul:
+			case OpGroupNonUniformFMin:
+			case OpGroupNonUniformFMax:
+			case OpGroupNonUniformIAdd:
+			case OpGroupNonUniformIMul:
+			case OpGroupNonUniformSMin:
+			case OpGroupNonUniformSMax:
+			case OpGroupNonUniformUMin:
+			case OpGroupNonUniformUMax:
+			case OpGroupNonUniformBitwiseAnd:
+			case OpGroupNonUniformBitwiseOr:
+			case OpGroupNonUniformBitwiseXor:
+			case OpGroupNonUniformLogicalAnd:
+			case OpGroupNonUniformLogicalOr:
+			case OpGroupNonUniformLogicalXor:
+				if ((get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) &&
+				    ops[3] == GroupOperationClusteredReduce)
+					added_arg_ids.insert(builtin_subgroup_invocation_id_id);
+				break;
+
 			case OpDemoteToHelperInvocation:
 				if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
 					added_arg_ids.insert(builtin_helper_invocation_id);
@@ -7026,6 +7047,105 @@ void CompilerMSL::emit_custom_functions()
 			statement("");
 			break;
 
+			// C++ disallows partial specializations of function templates,
+			// hence the use of a struct.
+			// clang-format off
+#define FUNC_SUBGROUP_CLUSTERED(spv, msl, combine, op, ident) \
+		case SPVFuncImplSubgroupClustered##spv: \
+			statement("template<uint N, uint offset>"); \
+			statement("struct spvClustered" #spv "Detail;"); \
+			statement(""); \
+			statement("// Base cases"); \
+			statement("template<>"); \
+			statement("struct spvClustered" #spv "Detail<1, 0>"); \
+			begin_scope(); \
+			statement("template<typename T>"); \
+			statement("static T op(T value, uint)"); \
+			begin_scope(); \
+			statement("return value;"); \
+			end_scope(); \
+			end_scope_decl(); \
+			statement(""); \
+			statement("template<uint offset>"); \
+			statement("struct spvClustered" #spv "Detail<1, offset>"); \
+			begin_scope(); \
+			statement("template<typename T>"); \
+			statement("static T op(T value, uint lid)"); \
+			begin_scope(); \
+			statement("// If the target lane is inactive, then return identity."); \
+			if (msl_options.use_quadgroup_operation()) \
+				statement("if (!extract_bits((quad_vote::vote_t)quad_active_threads_mask(), (lid ^ offset), 1))"); \
+			else \
+				statement("if (!extract_bits(as_type<uint2>((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1))"); \
+			statement("    return " #ident ";"); \
+			if (msl_options.use_quadgroup_operation()) \
+				statement("return quad_shuffle_xor(value, offset);"); \
+			else \
+				statement("return simd_shuffle_xor(value, offset);"); \
+			end_scope(); \
+			end_scope_decl(); \
+			statement(""); \
+			statement("template<>"); \
+			statement("struct spvClustered" #spv "Detail<4, 0>"); \
+			begin_scope(); \
+			statement("template<typename T>"); \
+			statement("static T op(T value, uint)"); \
+			begin_scope(); \
+			statement("return quad_" #msl "(value);"); \
+			end_scope(); \
+			end_scope_decl(); \
+			statement(""); \
+			statement("template<uint offset>"); \
+			statement("struct spvClustered" #spv "Detail<4, offset>"); \
+			begin_scope(); \
+			statement("template<typename T>"); \
+			statement("static T op(T value, uint lid)"); \
+			begin_scope(); \
+			statement("// Here, we care if any of the lanes in the quad are active."); \
+			statement("uint quad_mask = extract_bits(as_type<uint2>((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4);"); \
+			statement("if (!quad_mask)"); \
+			statement("    return " #ident ";"); \
+			statement("// But we need to make sure we shuffle from an active lane."); \
+			if (msl_options.use_quadgroup_operation()) \
+				SPIRV_CROSS_THROW("Subgroup size with quadgroup operation cannot exceed 4."); \
+			else \
+				statement("return simd_shuffle(quad_" #msl "(value), ((lid ^ offset) & ~3) | ctz(quad_mask));"); \
+			end_scope(); \
+			end_scope_decl(); \
+			statement(""); \
+			statement("// General case"); \
+			statement("template<uint N, uint offset>"); \
+			statement("struct spvClustered" #spv "Detail"); \
+			begin_scope(); \
+			statement("template<typename T>"); \
+			statement("static T op(T value, uint lid)"); \
+			begin_scope(); \
+			statement("return " combine(msl, op, "spvClustered" #spv "Detail<N/2, offset>::op(value, lid)", "spvClustered" #spv "Detail<N/2, offset + N/2>::op(value, lid)") ";"); \
+			end_scope(); \
+			end_scope_decl(); \
+			statement(""); \
+			statement("template<uint N, typename T>"); \
+			statement("T spvClustered_" #msl "(T value, uint lid)"); \
+			begin_scope(); \
+			statement("return spvClustered" #spv "Detail<N, 0>::op(value, lid);"); \
+			end_scope(); \
+			statement(""); \
+			break
+#define BINOP(msl, op, l, r) l " " #op " " r
+#define BINFUNC(msl, op, l, r) #msl "(" l ", " r ")"
+
+		FUNC_SUBGROUP_CLUSTERED(Add, sum, BINOP, +, 0);
+		FUNC_SUBGROUP_CLUSTERED(Mul, product, BINOP, *, 1);
+		FUNC_SUBGROUP_CLUSTERED(Min, min, BINFUNC, , numeric_limits<T>::max());
+		FUNC_SUBGROUP_CLUSTERED(Max, max, BINFUNC, , numeric_limits<T>::min());
+		FUNC_SUBGROUP_CLUSTERED(And, and, BINOP, &, ~T(0));
+		FUNC_SUBGROUP_CLUSTERED(Or, or, BINOP, |, 0);
+		FUNC_SUBGROUP_CLUSTERED(Xor, xor, BINOP, ^, 0);
+			// clang-format on
+#undef FUNC_SUBGROUP_CLUSTERED
+#undef BINOP
+#undef BINFUNC
+
 		case SPVFuncImplQuadBroadcast:
 			statement("template<typename T>");
 			statement("inline T spvQuadBroadcast(T value, uint lane)");
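The comment at the top of the new helper block notes that C++ forbids partial specialization of function templates, which is why each spvClustered helper is generated as a struct with a static template method plus a thin function wrapper. A plain C++ illustration of that pattern (a host-side stand-in that recursively sums a cluster from an array, with no Metal intrinsics):

#include <cstdio>

// Function templates cannot be partially specialized, but class templates can,
// so the recursion lives in a struct with a static method and is wrapped by a
// small function template - the same shape as the generated MSL helpers.
template <unsigned N, unsigned Offset>
struct ClusteredAddDetail
{
	template <typename T>
	static T op(const T *lanes, unsigned lid)
	{
		// General case: combine the two halves of the cluster.
		return ClusteredAddDetail<N / 2, Offset>::op(lanes, lid) +
		       ClusteredAddDetail<N / 2, Offset + N / 2>::op(lanes, lid);
	}
};

// Partial specialization for the single-lane base case.
template <unsigned Offset>
struct ClusteredAddDetail<1, Offset>
{
	template <typename T>
	static T op(const T *lanes, unsigned lid)
	{
		return lanes[lid ^ Offset]; // stand-in for simd_shuffle_xor(value, Offset)
	}
};

template <unsigned N, typename T>
T clustered_add(const T *lanes, unsigned lid)
{
	return ClusteredAddDetail<N, 0>::op(lanes, lid);
}

int main()
{
	float lanes[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
	// Sum of the 4-wide cluster containing lane 5 (lanes 4..7) = 26.
	std::printf("%g\n", clustered_add<4>(lanes, 5));
}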
@@ -9126,7 +9246,7 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t
 
 // If the physical type of a physical buffer pointer has been changed
 // to a ulong or ulongn vector, add a cast back to the pointer type.
-void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
+bool CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
 {
 	auto *p_physical_type = maybe_get<SPIRType>(physical_type);
 	if (p_physical_type &&
@@ -9137,7 +9257,10 @@ void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *ty
 			expr += ".x";
 
 		expr = join("((", type_to_glsl(*type), ")", expr, ")");
+		return true;
 	}
+
+	return false;
 }
 
 // Override for MSL-specific syntax instructions
@@ -9840,9 +9963,9 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 
 	case OpControlBarrier:
 		// In GLSL a memory barrier is often followed by a control barrier.
-		// But in MSL, memory barriers are also control barriers, so don't
+		// But in MSL, memory barriers are also control barriers (before MSL 3.2), so don't
 		// emit a simple control barrier if a memory barrier has just been emitted.
-		if (previous_instruction_opcode != OpMemoryBarrier)
+		if (previous_instruction_opcode != OpMemoryBarrier || msl_options.supports_msl_version(3, 2))
 			emit_barrier(ops[0], ops[1], ops[2]);
 		break;
 
@@ -10441,10 +10564,20 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
 		return;
 
 	string bar_stmt;
-	if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
-		bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
+
+	if (!id_exe_scope && msl_options.supports_msl_version(3, 2))
+	{
+		// Just took 10 years to get a proper barrier, but hey!
+		bar_stmt = "atomic_thread_fence";
+	}
 	else
-		bar_stmt = "threadgroup_barrier";
+	{
+		if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
+			bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
+		else
+			bar_stmt = "threadgroup_barrier";
+	}
+
 	bar_stmt += "(";
 
 	uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
@@ -10452,7 +10585,8 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
 	// Use the | operator to combine flags if we can.
 	if (msl_options.supports_msl_version(1, 2))
 	{
-		string mem_flags = "";
+		string mem_flags;
+
 		// For tesc shaders, this also affects objects in the Output storage class.
 		// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
 		if (is_tesc_shader() ||
@@ -10493,6 +10627,55 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
 			bar_stmt += "mem_flags::mem_none";
 	}
 
+	if (!id_exe_scope && msl_options.supports_msl_version(3, 2))
+	{
+		// If there's no device-related memory in the barrier, demote to workgroup scope.
+		// glslang seems to emit device scope even for memoryBarrierShared().
+		if (mem_scope == ScopeDevice &&
+		    (mem_sem & (MemorySemanticsUniformMemoryMask |
+		                MemorySemanticsImageMemoryMask |
+		                MemorySemanticsCrossWorkgroupMemoryMask)) == 0)
+		{
+			mem_scope = ScopeWorkgroup;
+		}
+
+		// MSL 3.2 only supports seq_cst or relaxed.
+		if (mem_sem & (MemorySemanticsAcquireReleaseMask |
+		               MemorySemanticsAcquireMask |
+		               MemorySemanticsReleaseMask |
+		               MemorySemanticsSequentiallyConsistentMask))
+		{
+			bar_stmt += ", memory_order_seq_cst";
+		}
+		else
+		{
+			bar_stmt += ", memory_order_relaxed";
+		}
+
+		switch (mem_scope)
+		{
+		case ScopeDevice:
+			bar_stmt += ", thread_scope_device";
+			break;
+
+		case ScopeWorkgroup:
+			bar_stmt += ", thread_scope_threadgroup";
+			break;
+
+		case ScopeSubgroup:
+			bar_stmt += ", thread_scope_subgroup";
+			break;
+
+		case ScopeInvocation:
+			bar_stmt += ", thread_scope_thread";
+			break;
+
+		default:
+			// The default argument is device, which is conservative.
+			break;
+		}
+	}
+
 	bar_stmt += ");";
 	bar_stmt += ");";
 
 
 	statement(bar_stmt);
 	statement(bar_stmt);
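On MSL 3.2 the plain memory-barrier path above emits atomic_thread_fence and appends a memory order and a thread scope derived from the SPIR-V memory semantics and scope. A small C++ sketch of just that decision logic, using made-up constants rather than the real SPIR-V encodings:

#include <cstdint>
#include <cstdio>
#include <string>

// Toy subset of SPIR-V memory-semantics bits and scopes; values are illustrative only.
enum : uint32_t
{
	SemAcquire = 1 << 0, SemRelease = 1 << 1, SemAcquireRelease = 1 << 2, SemSeqCst = 1 << 3,
	SemUniformMemory = 1 << 4, SemImageMemory = 1 << 5, SemCrossWorkgroupMemory = 1 << 6,
};
enum Scope { ScopeDevice, ScopeWorkgroup, ScopeSubgroup, ScopeInvocation };

// Mirrors the MSL 3.2 path in emit_barrier(): seq_cst only when an ordering bit is
// present, and device scope demoted to threadgroup scope when no device-visible
// memory class participates in the barrier.
static std::string fence_arguments(uint32_t mem_sem, Scope mem_scope)
{
	if (mem_scope == ScopeDevice &&
	    (mem_sem & (SemUniformMemory | SemImageMemory | SemCrossWorkgroupMemory)) == 0)
		mem_scope = ScopeWorkgroup;

	std::string args = (mem_sem & (SemAcquire | SemRelease | SemAcquireRelease | SemSeqCst))
	                       ? ", memory_order_seq_cst" : ", memory_order_relaxed";

	switch (mem_scope)
	{
	case ScopeDevice:     args += ", thread_scope_device";      break;
	case ScopeWorkgroup:  args += ", thread_scope_threadgroup"; break;
	case ScopeSubgroup:   args += ", thread_scope_subgroup";    break;
	case ScopeInvocation: args += ", thread_scope_thread";      break;
	}
	return args;
}

int main()
{
	// memoryBarrierShared()-style input: device scope, but only workgroup-visible memory,
	// so the scope is demoted. The mem_flags argument here is just an example value.
	std::printf("atomic_thread_fence(mem_flags::mem_threadgroup%s);\n",
	            fence_arguments(SemAcquireRelease, ScopeDevice).c_str());
}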
@@ -13663,9 +13846,17 @@ string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
 	return get_type_address_space(type, argument.self, true);
 }
 
-bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags)
+bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags) const
+{
+	// Using volatile for coherent pre-3.2 is definitely not correct, but it's something.
+	// MSL 3.2 adds actual coherent qualifiers.
+	return flags.get(DecorationVolatile) ||
+	       (flags.get(DecorationCoherent) && !msl_options.supports_msl_version(3, 2));
+}
+
+bool CompilerMSL::decoration_flags_signal_coherent(const Bitset &flags) const
 {
-	return flags.get(DecorationVolatile) || flags.get(DecorationCoherent);
+	return flags.get(DecorationCoherent) && msl_options.supports_msl_version(3, 2);
 }
 
 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
@@ -13677,8 +13868,17 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 	    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
 		flags = get_buffer_block_flags(id);
 	else
+	{
 		flags = get_decoration_bitset(id);
 
+		if (type.basetype == SPIRType::Struct &&
+		    (has_decoration(type.self, DecorationBlock) ||
+		     has_decoration(type.self, DecorationBufferBlock)))
+		{
+			flags.merge_or(ir.get_buffer_block_type_flags(type));
+		}
+	}
+
 	const char *addr_space = nullptr;
 	switch (type.storage)
 	{
@@ -13687,7 +13887,6 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 		break;
 
 	case StorageClassStorageBuffer:
-	case StorageClassPhysicalStorageBuffer:
 	{
 		// For arguments from variable pointers, we use the write count deduction, so
 		// we should not assume any constness here. Only for global SSBOs.
@@ -13695,10 +13894,19 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 		if (!var || has_decoration(type.self, DecorationBlock))
 			readonly = flags.get(DecorationNonWritable);
 
+		if (decoration_flags_signal_coherent(flags))
+			readonly = false;
+
 		addr_space = readonly ? "const device" : "device";
 		break;
 	}
 
+	case StorageClassPhysicalStorageBuffer:
+		// We cannot fully trust NonWritable coming from glslang due to a bug in buffer_reference handling.
+		// There isn't much gain in emitting const in C++ languages anyway.
+		addr_space = "device";
+		break;
+
 	case StorageClassUniform:
 	case StorageClassUniformConstant:
 	case StorageClassPushConstant:
@@ -13787,7 +13995,9 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 		addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
 	}
 
-	if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread"))
+	if (decoration_flags_signal_coherent(flags) && strcmp(addr_space, "device") == 0)
+		return join("coherent device");
+	else if (decoration_flags_signal_volatile(flags) && strcmp(addr_space, "thread") != 0)
 		return join("volatile ", addr_space);
 	else
 		return addr_space;
@@ -15411,7 +15621,8 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 
 	bool constref = !arg.alias_global_variable && !passed_by_value && is_pointer(var_type) && arg.write_count == 0;
 	// Framebuffer fetch is plain value, const looks out of place, but it is not wrong.
-	if (type_is_msl_framebuffer_fetch(type))
+	// readonly coming from glslang is not reliable in all cases.
+	if (type_is_msl_framebuffer_fetch(type) || type_storage == StorageClassPhysicalStorageBuffer)
 		constref = false;
 	else if (type_storage == StorageClassUniformConstant)
 		constref = true;
@@ -16639,6 +16850,10 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb
 	// Otherwise it may be set based on whether the image is read from or written to within the shader.
 	if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
 	{
+		auto *p_var = maybe_get_backing_variable(id);
+		if (p_var && p_var->basevariable)
+			p_var = maybe_get<SPIRVariable>(p_var->basevariable);
+
 		switch (img_type.access)
 		{
 		case AccessQualifierReadOnly:
@@ -16655,9 +16870,6 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb
 
 		default:
 		{
-			auto *p_var = maybe_get_backing_variable(id);
-			if (p_var && p_var->basevariable)
-				p_var = maybe_get<SPIRVariable>(p_var->basevariable);
 			if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
 			{
 				img_type_name += ", access::";
@@ -16670,6 +16882,9 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb
 			break;
 		}
 		}
+
+		if (p_var && has_decoration(p_var->self, DecorationCoherent) && msl_options.supports_msl_version(3, 2))
+			img_type_name += ", memory_coherence_device";
 	}
 
 	img_type_name += ">";
@@ -16924,11 +17139,10 @@ case OpGroupNonUniform##op: \
 			emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \
 		else if (operation == GroupOperationClusteredReduce) \
 		{ \
-			/* Only cluster sizes of 4 are supported. */ \
 			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
-			if (cluster_size != 4) \
-				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
-			emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
+			if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \
+				add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \
+			emit_subgroup_cluster_op(result_type, id, cluster_size, ops[op_idx], #msl_op); \
 		} \
 		else \
 			SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -16953,11 +17167,10 @@ case OpGroupNonUniform##op: \
 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
 		else if (operation == GroupOperationClusteredReduce) \
 		else if (operation == GroupOperationClusteredReduce) \
 		{ \
 		{ \
-			/* Only cluster sizes of 4 are supported. */ \
 			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
-			if (cluster_size != 4) \
-				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
-			emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
+			if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \
+				add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \
+			emit_subgroup_cluster_op(result_type, id, cluster_size, ops[op_idx], #msl_op); \
 		} \
 		else \
 			SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -16976,11 +17189,10 @@ case OpGroupNonUniform##op: \
 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
 		else if (operation == GroupOperationClusteredReduce) \
 		{ \
-			/* Only cluster sizes of 4 are supported. */ \
 			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
-			if (cluster_size != 4) \
-				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
-			emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \
+			if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \
+				add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \
+			emit_subgroup_cluster_op_cast(result_type, id, cluster_size, ops[op_idx], #msl_op, type, type); \
 		} \
 		else \
 			SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -16996,9 +17208,11 @@ case OpGroupNonUniform##op: \
 	MSL_GROUP_OP(BitwiseAnd, and)
 	MSL_GROUP_OP(BitwiseOr, or)
 	MSL_GROUP_OP(BitwiseXor, xor)
-	MSL_GROUP_OP(LogicalAnd, and)
-	MSL_GROUP_OP(LogicalOr, or)
-	MSL_GROUP_OP(LogicalXor, xor)
+	// Metal doesn't support boolean types in SIMD-group operations, so we
+	// have to emit some casts.
+	MSL_GROUP_OP_CAST(LogicalAnd, and, SPIRType::UShort)
+	MSL_GROUP_OP_CAST(LogicalOr, or, SPIRType::UShort)
+	MSL_GROUP_OP_CAST(LogicalXor, xor, SPIRType::UShort)
 		// clang-format on
 #undef MSL_GROUP_OP
 #undef MSL_GROUP_OP_CAST
@@ -17026,6 +17240,83 @@ case OpGroupNonUniform##op: \
 	register_control_dependent_expression(id);
 }

+void CompilerMSL::emit_subgroup_cluster_op(uint32_t result_type, uint32_t result_id, uint32_t cluster_size,
+                                           uint32_t op0, const char *op)
+{
+	if (get_execution_model() == ExecutionModelFragment && !msl_options.supports_msl_version(2, 2))
+	{
+		if (cluster_size == 4)
+		{
+			emit_unary_func_op(result_type, result_id, op0, join("quad_", op).c_str());
+			return;
+		}
+		SPIRV_CROSS_THROW("Cluster sizes other than 4 in fragment shaders require MSL 2.2.");
+	}
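+	// Non-quad path: call the spvClustered_<op> helper pulled in via SPVFuncImplSubgroupClustered*,
+	// passing the calling lane's subgroup invocation ID alongside the value.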
+	bool forward = should_forward(op0);
+	emit_op(result_type, result_id,
+	        join("spvClustered_", op, "<", cluster_size, ">(", to_unpacked_expression(op0), ", ",
+	             to_expression(builtin_subgroup_invocation_id_id), ")"),
+	        forward);
+	inherit_expression_dependencies(result_id, op0);
+}
+
+void CompilerMSL::emit_subgroup_cluster_op_cast(uint32_t result_type, uint32_t result_id, uint32_t cluster_size,
+                                                uint32_t op0, const char *op, SPIRType::BaseType input_type,
+                                                SPIRType::BaseType expected_result_type)
+{
+	if (get_execution_model() == ExecutionModelFragment && !msl_options.supports_msl_version(2, 2))
+	{
+		if (cluster_size == 4)
+		{
+			emit_unary_func_op_cast(result_type, result_id, op0, join("quad_", op).c_str(), input_type,
+			                        expected_result_type);
+			return;
+		}
+		SPIRV_CROSS_THROW("Cluster sizes other than 4 in fragment shaders require MSL 2.2.");
+	}
+
+	auto &out_type = get<SPIRType>(result_type);
+	auto &expr_type = expression_type(op0);
+	auto expected_type = out_type;
+
+	// Bit-widths might be different in unary cases because we use it for SConvert/UConvert and friends.
+	expected_type.basetype = input_type;
+	expected_type.width = expr_type.width;
+
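+	// Booleans get a plain value cast into the expected integer type; any other basetype mismatch is bitcast.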
+	string cast_op;
+	if (expr_type.basetype != input_type)
+	{
+		if (expr_type.basetype == SPIRType::Boolean)
+			cast_op = join(type_to_glsl(expected_type), "(", to_unpacked_expression(op0), ")");
+		else
+			cast_op = bitcast_glsl(expected_type, op0);
+	}
+	else
+		cast_op = to_unpacked_expression(op0);
+
+	string sg_op = join("spvClustered_", op, "<", cluster_size, ">");
+	string expr;
+	if (out_type.basetype != expected_result_type)
+	{
+		expected_type.basetype = expected_result_type;
+		expected_type.width = out_type.width;
+		if (out_type.basetype == SPIRType::Boolean)
+			expr = type_to_glsl(out_type);
+		else
+			expr = bitcast_glsl_op(out_type, expected_type);
+		expr += '(';
+		expr += join(sg_op, "(", cast_op, ", ", to_expression(builtin_subgroup_invocation_id_id), ")");
+		expr += ')';
+	}
+	else
+	{
+		expr += join(sg_op, "(", cast_op, ", ", to_expression(builtin_subgroup_invocation_id_id), ")");
+	}
+
+	emit_op(result_type, result_id, expr, should_forward(op0));
+	inherit_expression_dependencies(result_id, op0);
+}
+
 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
 {
 	if (out_type.basetype == in_type.basetype)
@@ -18097,6 +18388,28 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
 		}
 		break;

+	case OpGroupNonUniformFAdd:
+	case OpGroupNonUniformFMul:
+	case OpGroupNonUniformFMin:
+	case OpGroupNonUniformFMax:
+	case OpGroupNonUniformIAdd:
+	case OpGroupNonUniformIMul:
+	case OpGroupNonUniformSMin:
+	case OpGroupNonUniformSMax:
+	case OpGroupNonUniformUMin:
+	case OpGroupNonUniformUMax:
+	case OpGroupNonUniformBitwiseAnd:
+	case OpGroupNonUniformBitwiseOr:
+	case OpGroupNonUniformBitwiseXor:
+	case OpGroupNonUniformLogicalAnd:
+	case OpGroupNonUniformLogicalOr:
+	case OpGroupNonUniformLogicalXor:
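+		// Only ClusteredReduce needs the subgroup invocation ID, and only when the spvClustered_
+		// helpers are used (i.e. not the quad fallback taken by pre-2.2 fragment shaders).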
+		if ((compiler.get_execution_model() != ExecutionModelFragment ||
+		     compiler.msl_options.supports_msl_version(2, 2)) &&
+		    args[3] == GroupOperationClusteredReduce)
+			needs_subgroup_invocation_id = true;
+		break;
+
 	case OpArrayLength:
 	{
 		auto *var = compiler.maybe_get_backing_variable(args[2]);

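For reference, a rough sketch of what the new ClusteredReduce path generates for a bitwise AND with a cluster size of 8 (the variable names and SSA id are illustrative; the spvClustered_* helper bodies are added elsewhere in this change):

	// SPIR-V: OpGroupNonUniformBitwiseAnd %uint Subgroup ClusteredReduce %value 8
	uint _42 = spvClustered_and<8>(value, gl_SubgroupInvocationID);

Pre-MSL 2.2 fragment shaders keep the old quad_and(value) form, which is why only a cluster size of 4 is accepted there.
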
+ 31 - 2
3rdparty/spirv-cross/spirv_msl.hpp

@@ -818,6 +818,29 @@ protected:
 		SPVFuncImplSubgroupShuffleUp,
 		SPVFuncImplSubgroupShuffleDown,
 		SPVFuncImplSubgroupRotate,
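+		// Clustered-reduce helpers. The float/integer/logical flavours of an operation alias
+		// one shared helper implementation.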
+		SPVFuncImplSubgroupClusteredAdd,
+		SPVFuncImplSubgroupClusteredFAdd = SPVFuncImplSubgroupClusteredAdd,
+		SPVFuncImplSubgroupClusteredIAdd = SPVFuncImplSubgroupClusteredAdd,
+		SPVFuncImplSubgroupClusteredMul,
+		SPVFuncImplSubgroupClusteredFMul = SPVFuncImplSubgroupClusteredMul,
+		SPVFuncImplSubgroupClusteredIMul = SPVFuncImplSubgroupClusteredMul,
+		SPVFuncImplSubgroupClusteredMin,
+		SPVFuncImplSubgroupClusteredFMin = SPVFuncImplSubgroupClusteredMin,
+		SPVFuncImplSubgroupClusteredSMin = SPVFuncImplSubgroupClusteredMin,
+		SPVFuncImplSubgroupClusteredUMin = SPVFuncImplSubgroupClusteredMin,
+		SPVFuncImplSubgroupClusteredMax,
+		SPVFuncImplSubgroupClusteredFMax = SPVFuncImplSubgroupClusteredMax,
+		SPVFuncImplSubgroupClusteredSMax = SPVFuncImplSubgroupClusteredMax,
+		SPVFuncImplSubgroupClusteredUMax = SPVFuncImplSubgroupClusteredMax,
+		SPVFuncImplSubgroupClusteredAnd,
+		SPVFuncImplSubgroupClusteredBitwiseAnd = SPVFuncImplSubgroupClusteredAnd,
+		SPVFuncImplSubgroupClusteredLogicalAnd = SPVFuncImplSubgroupClusteredAnd,
+		SPVFuncImplSubgroupClusteredOr,
+		SPVFuncImplSubgroupClusteredBitwiseOr = SPVFuncImplSubgroupClusteredOr,
+		SPVFuncImplSubgroupClusteredLogicalOr = SPVFuncImplSubgroupClusteredOr,
+		SPVFuncImplSubgroupClusteredXor,
+		SPVFuncImplSubgroupClusteredBitwiseXor = SPVFuncImplSubgroupClusteredXor,
+		SPVFuncImplSubgroupClusteredLogicalXor = SPVFuncImplSubgroupClusteredXor,
 		SPVFuncImplQuadBroadcast,
 		SPVFuncImplQuadSwap,
 		SPVFuncImplReflectScalar,
@@ -871,6 +894,11 @@ protected:
 	void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override;
 	void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) override;
 	void emit_subgroup_op(const Instruction &i) override;
+	void emit_subgroup_cluster_op(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, uint32_t op0,
+	                              const char *op);
+	void emit_subgroup_cluster_op_cast(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, uint32_t op0,
+	                                   const char *op, SPIRType::BaseType input_type,
+	                                   SPIRType::BaseType expected_result_type);
 	std::string to_texture_op(const Instruction &i, bool sparse, bool *forward,
 	                          SmallVector<uint32_t> &inherited_expressions) override;
 	void emit_fixup() override;
@@ -1084,7 +1112,8 @@ protected:
 	bool validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const;
 	std::string get_argument_address_space(const SPIRVariable &argument);
 	std::string get_type_address_space(const SPIRType &type, uint32_t id, bool argument = false);
-	static bool decoration_flags_signal_volatile(const Bitset &flags);
+	bool decoration_flags_signal_volatile(const Bitset &flags) const;
+	bool decoration_flags_signal_coherent(const Bitset &flags) const;
 	const char *to_restrict(uint32_t id, bool space);
 	SPIRType &get_stage_in_struct_type();
 	SPIRType &get_stage_out_struct_type();
@@ -1154,7 +1183,7 @@ protected:
 	bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
 	                                            bool &is_packed) override;
 	void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length);
-	void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;
+	bool check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;

 	bool emit_tessellation_access_chain(const uint32_t *ops, uint32_t length);
 	bool emit_tessellation_io_load(uint32_t result_type, uint32_t id, uint32_t ptr);