Przeglądaj źródła

Updated spirv-cross.

Бранимир Караџић 11 miesięcy temu
rodzic
commit
3434a173a2

+ 9 - 1
3rdparty/spirv-cross/spirv_common.hpp

@@ -578,7 +578,9 @@ struct SPIRType : IVariant
 		// Keep internal types at the end.
 		ControlPointArray,
 		Interpolant,
-		Char
+		Char,
+		// MSL specific type, that is used by 'object'(analog of 'task' from glsl) shader.
+		MeshGridProperties
 	};
 
 	// Scalar/vector/matrix support.
@@ -746,6 +748,10 @@ struct SPIRExpression : IVariant
 	// A list of expressions which this expression depends on.
 	SmallVector<ID> expression_dependencies;
 
+	// Similar as expression dependencies, but does not stop the tracking for force-temporary variables.
+	// We need to know the full chain from store back to any SSA variable.
+	SmallVector<ID> invariance_dependencies;
+
 	// By reading this expression, we implicitly read these expressions as well.
 	// Used by access chain Store and Load since we read multiple expressions in this case.
 	SmallVector<ID> implied_read_expressions;
@@ -1598,6 +1604,8 @@ struct AccessChainMeta
 	bool flattened_struct = false;
 	bool relaxed_precision = false;
 	bool access_meshlet_position_y = false;
+	bool chain_is_builtin = false;
+	spv::BuiltIn builtin = {};
 };
 
 enum ExtendedDecorations

+ 10 - 1
3rdparty/spirv-cross/spirv_cross.cpp

@@ -2569,6 +2569,15 @@ void Compiler::add_active_interface_variable(uint32_t var_id)
 
 void Compiler::inherit_expression_dependencies(uint32_t dst, uint32_t source_expression)
 {
+	auto *ptr_e = maybe_get<SPIRExpression>(dst);
+
+	if (is_position_invariant() && ptr_e && maybe_get<SPIRExpression>(source_expression))
+	{
+		auto &deps = ptr_e->invariance_dependencies;
+		if (std::find(deps.begin(), deps.end(), source_expression) == deps.end())
+			deps.push_back(source_expression);
+	}
+
 	// Don't inherit any expression dependencies if the expression in dst
 	// is not a forwarded temporary.
 	if (forwarded_temporaries.find(dst) == end(forwarded_temporaries) ||
@@ -2577,7 +2586,7 @@ void Compiler::inherit_expression_dependencies(uint32_t dst, uint32_t source_exp
 		return;
 	}
 
-	auto &e = get<SPIRExpression>(dst);
+	auto &e = *ptr_e;
 	auto *phi = maybe_get<SPIRVariable>(source_expression);
 	if (phi && phi->phi_variable)
 	{

+ 4 - 0
3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp

@@ -928,6 +928,8 @@ void ParsedIR::reset_all_of_type(Types type)
 
 void ParsedIR::add_typed_id(Types type, ID id)
 {
+	assert(id < ids.size());
+
 	if (loop_iteration_depth_hard != 0)
 		SPIRV_CROSS_THROW("Cannot add typed ID while looping over it.");
 
@@ -1030,6 +1032,8 @@ ParsedIR::LoopLock &ParsedIR::LoopLock::operator=(LoopLock &&other) SPIRV_CROSS_
 
 void ParsedIR::make_constant_null(uint32_t id, uint32_t type, bool add_to_typed_id_set)
 {
+	assert(id < ids.size());
+
 	auto &constant_type = get<SPIRType>(type);
 
 	if (constant_type.pointer)

+ 45 - 11
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -6438,7 +6438,7 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
 		if (splat)
 		{
 			res += convert_to_string(c.scalar(vector, 0));
-			if (is_legacy())
+			if (is_legacy() && !has_extension("GL_EXT_gpu_shader4"))
 			{
 				// Fake unsigned constant literals with signed ones if possible.
 				// Things like array sizes, etc, tend to be unsigned even though they could just as easily be signed.
@@ -6457,7 +6457,7 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
 				else
 				{
 					res += convert_to_string(c.scalar(vector, i));
-					if (is_legacy())
+					if (is_legacy() && !has_extension("GL_EXT_gpu_shader4"))
 					{
 						// Fake unsigned constant literals with signed ones if possible.
 						// Things like array sizes, etc, tend to be unsigned even though they could just as easily be signed.
@@ -10210,6 +10210,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 	bool pending_array_enclose = false;
 	bool dimension_flatten = false;
 	bool access_meshlet_position_y = false;
+	bool chain_is_builtin = false;
+	spv::BuiltIn chained_builtin = {};
 
 	if (auto *base_expr = maybe_get<SPIRExpression>(base))
 	{
@@ -10367,6 +10369,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 				auto builtin = ir.meta[base].decoration.builtin_type;
 				bool mesh_shader = get_execution_model() == ExecutionModelMeshEXT;
 
+				chain_is_builtin = true;
+				chained_builtin = builtin;
+
 				switch (builtin)
 				{
 				case BuiltInCullDistance:
@@ -10502,6 +10507,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 					{
 						access_meshlet_position_y = true;
 					}
+
+					chain_is_builtin = true;
+					chained_builtin = builtin;
 				}
 				else
 				{
@@ -10721,6 +10729,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 		meta->storage_physical_type = physical_type;
 		meta->relaxed_precision = relaxed_precision;
 		meta->access_meshlet_position_y = access_meshlet_position_y;
+		meta->chain_is_builtin = chain_is_builtin;
+		meta->builtin = chained_builtin;
 	}
 
 	return expr;
@@ -11766,13 +11776,13 @@ void CompilerGLSL::disallow_forwarding_in_expression_chain(const SPIRExpression
 	// Allow trivially forwarded expressions like OpLoad or trivial shuffles,
 	// these will be marked as having suppressed usage tracking.
 	// Our only concern is to make sure arithmetic operations are done in similar ways.
-	if (expression_is_forwarded(expr.self) && !expression_suppresses_usage_tracking(expr.self) &&
-	    forced_invariant_temporaries.count(expr.self) == 0)
+	if (forced_invariant_temporaries.count(expr.self) == 0)
 	{
-		force_temporary_and_recompile(expr.self);
+		if (!expression_suppresses_usage_tracking(expr.self))
+			force_temporary_and_recompile(expr.self);
 		forced_invariant_temporaries.insert(expr.self);
 
-		for (auto &dependent : expr.expression_dependencies)
+		for (auto &dependent : expr.invariance_dependencies)
 			disallow_forwarding_in_expression_chain(get<SPIRExpression>(dependent));
 	}
 }
@@ -12336,6 +12346,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 			flattened_structs[ops[1]] = true;
 		if (meta.relaxed_precision && backend.requires_relaxed_precision_analysis)
 			set_decoration(ops[1], DecorationRelaxedPrecision);
+		if (meta.chain_is_builtin)
+			set_decoration(ops[1], DecorationBuiltIn, meta.builtin);
 
 		// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
 		// temporary which could be subject to invalidation.
@@ -13229,13 +13241,24 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		uint32_t op0 = ops[2];
 		uint32_t op1 = ops[3];
 
-		// Needs special handling.
+		auto &out_type = get<SPIRType>(result_type);
+
 		bool forward = should_forward(op0) && should_forward(op1);
-		auto expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(",
-		                 to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")");
+		string cast_op0, cast_op1;
+		auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, int_type, op0, op1, false);
+
+		// Needs special handling.
+		auto expr = join(cast_op0, " - ", cast_op1, " * ", "(", cast_op0, " / ", cast_op1, ")");
 
 		if (implicit_integer_promotion)
+		{
 			expr = join(type_to_glsl(get<SPIRType>(result_type)), '(', expr, ')');
+		}
+		else if (out_type.basetype != int_type)
+		{
+			expected_type.basetype = int_type;
+			expr = join(bitcast_glsl_op(out_type, expected_type), '(', expr, ')');
+		}
 
 		emit_op(result_type, result_id, expr, forward);
 		inherit_expression_dependencies(result_id, op0);
@@ -15668,7 +15691,16 @@ string CompilerGLSL::argument_decl(const SPIRFunction::Parameter &arg)
 
 	if (type.pointer)
 	{
-		if (arg.write_count && arg.read_count)
+		// If we're passing around block types to function, we really mean reference in a pointer sense,
+		// but DXC does not like inout for mesh blocks, so workaround that. out is technically not correct,
+		// but it works in practice due to legalization. It's ... not great, but you gotta do what you gotta do.
+		// GLSL will never hit this case since it's not valid.
+		if (type.storage == StorageClassOutput && get_execution_model() == ExecutionModelMeshEXT &&
+		    has_decoration(type.self, DecorationBlock) && is_builtin_type(type) && arg.write_count)
+		{
+			direction = "out ";
+		}
+		else if (arg.write_count && arg.read_count)
 			direction = "inout ";
 		else if (arg.write_count)
 			direction = "out ";
@@ -15945,7 +15977,7 @@ string CompilerGLSL::image_type_glsl(const SPIRType &type, uint32_t id, bool /*m
 	case DimBuffer:
 		if (options.es && options.version < 320)
 			require_extension_internal("GL_EXT_texture_buffer");
-		else if (!options.es && options.version < 300)
+		else if (!options.es && options.version < 140)
 			require_extension_internal("GL_EXT_texture_buffer_object");
 		res += "Buffer";
 		break;
@@ -16488,6 +16520,8 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags)
 	{
 		auto &var = get<SPIRVariable>(v);
 		var.deferred_declaration = false;
+		if (var.storage == StorageClassTaskPayloadWorkgroupEXT)
+			continue;
 
 		if (variable_decl_is_remapped_storage(var, StorageClassWorkgroup))
 		{

+ 82 - 6
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -4775,13 +4775,13 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
 {
 	auto ops = stream(instruction);
 
-	auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
+	uint32_t result_type = ops[0];
+	uint32_t id = ops[1];
+	uint32_t ptr = ops[2];
+
+	auto *chain = maybe_get<SPIRAccessChain>(ptr);
 	if (chain)
 	{
-		uint32_t result_type = ops[0];
-		uint32_t id = ops[1];
-		uint32_t ptr = ops[2];
-
 		auto &type = get<SPIRType>(result_type);
 		bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
 
@@ -4819,7 +4819,36 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
 		}
 	}
 	else
-		CompilerGLSL::emit_instruction(instruction);
+	{
+		// Very special case where we cannot rely on IO lowering.
+		// Mesh shader clip/cull arrays ... Cursed.
+		auto &res_type = get<SPIRType>(result_type);
+		if (get_execution_model() == ExecutionModelMeshEXT &&
+		    has_decoration(ptr, DecorationBuiltIn) &&
+		    (get_decoration(ptr, DecorationBuiltIn) == BuiltInClipDistance ||
+		     get_decoration(ptr, DecorationBuiltIn) == BuiltInCullDistance) &&
+		    is_array(res_type) && !is_array(get<SPIRType>(res_type.parent_type)) &&
+		    to_array_size_literal(res_type) > 1)
+		{
+			track_expression_read(ptr);
+			string load_expr = "{ ";
+			uint32_t num_elements = to_array_size_literal(res_type);
+			for (uint32_t i = 0; i < num_elements; i++)
+			{
+				load_expr += join(to_expression(ptr), ".", index_to_swizzle(i));
+				if (i + 1 < num_elements)
+					load_expr += ", ";
+			}
+			load_expr += " }";
+			emit_op(result_type, id, load_expr, false);
+			register_read(id, ptr, false);
+			inherit_expression_dependencies(id, ptr);
+		}
+		else
+		{
+			CompilerGLSL::emit_instruction(instruction);
+		}
+	}
 }
 
 void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
@@ -6903,3 +6932,50 @@ bool CompilerHLSL::is_user_type_structured(uint32_t id) const
 	}
 	return false;
 }
+
+void CompilerHLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+{
+	// Loading a full array of ClipDistance needs special consideration in mesh shaders
+	// since we cannot lower them by wrapping the variables in global statics.
+	// Fortunately, clip/cull is a proper vector in HLSL so we can lower with simple rvalue casts.
+	if (get_execution_model() != ExecutionModelMeshEXT ||
+	    !has_decoration(target_id, DecorationBuiltIn) ||
+	    !is_array(expr_type))
+	{
+		CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
+		return;
+	}
+
+	auto builtin = BuiltIn(get_decoration(target_id, DecorationBuiltIn));
+	if (builtin != BuiltInClipDistance && builtin != BuiltInCullDistance)
+	{
+		CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
+		return;
+	}
+
+	// Array of array means one thread is storing clip distance for all vertices. Nonsensical?
+	if (is_array(get<SPIRType>(expr_type.parent_type)))
+		SPIRV_CROSS_THROW("Attempting to store all mesh vertices in one go. This is not supported.");
+
+	uint32_t num_clip = to_array_size_literal(expr_type);
+	if (num_clip > 4)
+		SPIRV_CROSS_THROW("Number of clip or cull distances exceeds 4, this will not work with mesh shaders.");
+
+	if (num_clip == 1)
+	{
+		// We already emit array here.
+		CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
+		return;
+	}
+
+	auto unrolled_expr = join("float", num_clip, "(");
+	for (uint32_t i = 0; i < num_clip; i++)
+	{
+		unrolled_expr += join(expr, "[", i, "]");
+		if (i + 1 < num_clip)
+			unrolled_expr += ", ";
+	}
+
+	unrolled_expr += ")";
+	expr = std::move(unrolled_expr);
+}

+ 2 - 0
3rdparty/spirv-cross/spirv_hlsl.hpp

@@ -408,6 +408,8 @@ private:
 	std::vector<TypeID> composite_selection_workaround_types;
 
 	std::string get_inner_entry_point_name() const;
+
+	void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
 };
 } // namespace SPIRV_CROSS_NAMESPACE
 

+ 615 - 35
3rdparty/spirv-cross/spirv_msl.cpp

@@ -202,6 +202,9 @@ uint32_t CompilerMSL::get_resource_array_size(const SPIRType &type, uint32_t id)
 {
 	uint32_t array_size = to_array_size_literal(type);
 
+	if (id == 0)
+		return array_size;
+
 	// If we have argument buffers, we need to honor the ABI by using the correct array size
 	// from the layout. Only use shader declared size if we're not using argument buffers.
 	uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
@@ -269,7 +272,7 @@ void CompilerMSL::build_implicit_builtins()
 	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
 	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
 	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
-	bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
+	bool need_local_invocation_index = (msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId)) || is_mesh_shader();
 	bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);
 	bool force_frag_depth_passthrough =
 	    get_execution_model() == ExecutionModelFragment && !uses_explicit_early_fragment_test() && need_subpass_input &&
@@ -278,7 +281,7 @@ void CompilerMSL::build_implicit_builtins()
 	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
 	    need_tese_params || need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
 	    needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || needs_helper_invocation ||
-		has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough)
+	    has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough || is_mesh_shader())
 	{
 		bool has_frag_coord = false;
 		bool has_sample_id = false;
@@ -325,6 +328,13 @@ void CompilerMSL::build_implicit_builtins()
 				}
 			}
 
+			if (builtin == BuiltInPrimitivePointIndicesEXT ||
+			    builtin == BuiltInPrimitiveLineIndicesEXT ||
+			    builtin == BuiltInPrimitiveTriangleIndicesEXT)
+			{
+				builtin_mesh_primitive_indices_id = var.self;
+			}
+
 			if (var.storage != StorageClassInput)
 				return;
 
@@ -1057,6 +1067,53 @@ void CompilerMSL::build_implicit_builtins()
 		set_decoration(var_id, DecorationBuiltIn, BuiltInPosition);
 		mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id);
 	}
+
+	if (is_mesh_shader())
+	{
+		uint32_t offset = ir.increase_bound_by(2);
+		uint32_t type_ptr_id = offset;
+		uint32_t var_id = offset + 1;
+
+		// Create variable to store meshlet size.
+		uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 2);
+		SPIRType uint_type_ptr = get<SPIRType>(type_id);
+		uint_type_ptr.op = OpTypePointer;
+		uint_type_ptr.pointer = true;
+		uint_type_ptr.pointer_depth++;
+		uint_type_ptr.parent_type = type_id;
+		uint_type_ptr.storage = StorageClassWorkgroup;
+
+		auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
+		ptr_type.self = type_id;
+		set<SPIRVariable>(var_id, type_ptr_id, StorageClassWorkgroup);
+		set_name(var_id, "spvMeshSizes");
+		builtin_mesh_sizes_id = var_id;
+	}
+
+	if (get_execution_model() == spv::ExecutionModelTaskEXT)
+	{
+		uint32_t offset = ir.increase_bound_by(3);
+		uint32_t type_id = offset;
+		uint32_t type_ptr_id = offset + 1;
+		uint32_t var_id = offset + 2;
+
+		SPIRType mesh_grid_type { OpTypeStruct };
+		mesh_grid_type.basetype = SPIRType::MeshGridProperties;
+		set<SPIRType>(type_id, mesh_grid_type);
+
+		SPIRType mesh_grid_type_ptr = mesh_grid_type;
+		mesh_grid_type_ptr.op = spv::OpTypePointer;
+		mesh_grid_type_ptr.pointer = true;
+		mesh_grid_type_ptr.pointer_depth++;
+		mesh_grid_type_ptr.parent_type = type_id;
+		mesh_grid_type_ptr.storage = StorageClassOutput;
+
+		auto &ptr_in_type = set<SPIRType>(type_ptr_id, mesh_grid_type_ptr);
+		ptr_in_type.self = type_id;
+		set<SPIRVariable>(var_id, type_ptr_id, StorageClassOutput);
+		set_name(var_id, "spvMgp");
+		builtin_task_grid_id = var_id;
+	}
 }
 
 // Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
@@ -1509,6 +1566,10 @@ void CompilerMSL::emit_entry_point_declarations()
 		statement(CompilerGLSL::variable_decl(var), ";");
 		var.deferred_declaration = false;
 	}
+
+	// Holds SetMeshOutputsEXT information. Threadgroup since first thread wins.
+	if (processing_entry_point && is_mesh_shader())
+		statement("threadgroup uint2 spvMeshSizes;");
 }
 
 string CompilerMSL::compile()
@@ -1544,6 +1605,8 @@ string CompilerMSL::compile()
 	backend.native_pointers = true;
 	backend.nonuniform_qualifier = "";
 	backend.support_small_type_sampling_result = true;
+	backend.force_merged_mesh_block = false;
+	backend.force_gl_in_out_block = get_execution_model() == ExecutionModelMeshEXT;
 	backend.supports_empty_struct = true;
 	backend.support_64bit_switch = true;
 	backend.boolean_in_struct_remapped_type = SPIRType::Short;
@@ -1559,6 +1622,9 @@ string CompilerMSL::compile()
 	capture_output_to_buffer = msl_options.capture_output_to_buffer;
 	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
 
+	if (is_mesh_shader() && !get_entry_point().flags.get(ExecutionModeOutputPoints))
+		msl_options.enable_point_size_builtin = false;
+
 	// Initialize array here rather than constructor, MSVC 2013 workaround.
 	for (auto &id : next_metal_resource_ids)
 		id = 0;
@@ -1566,6 +1632,11 @@ string CompilerMSL::compile()
 	fixup_anonymous_struct_names();
 	fixup_type_alias();
 	replace_illegal_names();
+	if (get_execution_model() == ExecutionModelMeshEXT)
+	{
+		// Emit proxy entry-point for the sake of copy-pass
+		emit_mesh_entry_point();
+	}
 	sync_entry_point_aliases_and_names();
 
 	build_function_control_flow_graphs_and_analyze();
@@ -1617,9 +1688,17 @@ string CompilerMSL::compile()
 	// Create structs to hold input, output and uniform variables.
 	// Do output first to ensure out. is declared at top of entry function.
 	qual_pos_var_name = "";
-	stage_out_var_id = add_interface_block(StorageClassOutput);
-	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
-	stage_in_var_id = add_interface_block(StorageClassInput);
+	if (is_mesh_shader())
+	{
+		fixup_implicit_builtin_block_names(get_execution_model());
+	}
+	else
+	{
+		stage_out_var_id = add_interface_block(StorageClassOutput);
+		patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
+		stage_in_var_id = add_interface_block(StorageClassInput);
+	}
+
 	if (is_tese_shader())
 		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);
 
@@ -1628,6 +1707,12 @@ string CompilerMSL::compile()
 	if (is_tessellation_shader())
 		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);
 
+	if (is_mesh_shader())
+	{
+		mesh_out_per_vertex = add_meshlet_block(false);
+		mesh_out_per_primitive = add_meshlet_block(true);
+	}
+
 	// Metal vertex functions that define no output must disable rasterization and return void.
 	if (!stage_out_var_id)
 		is_rasterization_disabled = true;
@@ -1762,12 +1847,18 @@ void CompilerMSL::localize_global_variables()
 	{
 		uint32_t v_id = *iter;
 		auto &var = get<SPIRVariable>(v_id);
-		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
+		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup ||
+		    var.storage == StorageClassTaskPayloadWorkgroupEXT)
 		{
 			if (!variable_is_lut(var))
 				entry_func.add_local_variable(v_id);
 			iter = global_variables.erase(iter);
 		}
+		else if (var.storage == StorageClassOutput && is_mesh_shader())
+		{
+			entry_func.add_local_variable(v_id);
+			iter = global_variables.erase(iter);
+		}
 		else
 			iter++;
 	}
@@ -2105,6 +2196,15 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				break;
 			}
 
+			case OpSetMeshOutputsEXT:
+			{
+				if (builtin_local_invocation_index_id != 0)
+					added_arg_ids.insert(builtin_local_invocation_index_id);
+				if (builtin_mesh_sizes_id != 0)
+					added_arg_ids.insert(builtin_mesh_sizes_id);
+				break;
+			}
+
 			default:
 				break;
 			}
@@ -2117,6 +2217,9 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 			// We should consider a more unified system here to reduce boiler-plate.
 			// This kind of analysis is done in several places ...
 		}
+
+		if (b.terminator == SPIRBlock::EmitMeshTasks && builtin_task_grid_id != 0)
+			added_arg_ids.insert(builtin_task_grid_id);
 	}
 
 	function_global_vars[func_id] = added_arg_ids;
@@ -2206,6 +2309,17 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
 				if (is_tese_shader() && msl_options.raw_buffer_tese_input && var.storage == StorageClassInput)
 					set_decoration(next_id, DecorationNonWritable);
 			}
+			else if (is_builtin && is_mesh_shader())
+			{
+				uint32_t next_id = ir.increase_bound_by(1);
+				func.add_parameter(type_id, next_id, true);
+				auto &v = set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
+				v.storage = StorageClassWorkgroup;
+
+				// Ensure the existing variable has a valid name and the new variable has all the same meta info
+				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
+				ir.meta[next_id] = ir.meta[arg_id];
+			}
 			else if (is_builtin && has_decoration(p_type->self, DecorationBlock))
 			{
 				// Get the pointee type
@@ -4490,6 +4604,42 @@ uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageCla
 	return ib_ptr_var_id;
 }
 
+uint32_t CompilerMSL::add_meshlet_block(bool per_primitive)
+{
+	// Accumulate the variables that should appear in the interface struct.
+	SmallVector<SPIRVariable *> vars;
+
+	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+		if (var.storage != StorageClassOutput || var.self == builtin_mesh_primitive_indices_id)
+			return;
+		if (is_per_primitive_variable(var) != per_primitive)
+			return;
+		vars.push_back(&var);
+	});
+
+	if (vars.empty())
+		return 0;
+
+	uint32_t next_id = ir.increase_bound_by(1);
+	auto &type = set<SPIRType>(next_id, SPIRType(OpTypeStruct));
+	type.basetype = SPIRType::Struct;
+
+	InterfaceBlockMeta meta;
+	for (auto *p_var : vars)
+	{
+		meta.strip_array = true;
+		meta.allow_local_declaration = false;
+		add_variable_to_interface_block(StorageClassOutput, "", type, *p_var, meta);
+	}
+
+	if (per_primitive)
+		set_name(type.self, "spvPerPrimitive");
+	else
+		set_name(type.self, "spvPerVertex");
+
+	return next_id;
+}
+
 // Ensure that the type is compatible with the builtin.
 // If it is, simply return the given type ID.
 // Otherwise, create a new type, and return it's ID.
@@ -5482,6 +5632,19 @@ void CompilerMSL::emit_custom_templates()
 			begin_scope();
 			statement("return elements[pos];");
 			end_scope();
+			if (get_execution_model() == spv::ExecutionModelMeshEXT ||
+			    get_execution_model() == spv::ExecutionModelTaskEXT)
+			{
+				statement("");
+				statement("object_data T& operator [] (size_t pos) object_data");
+				begin_scope();
+				statement("return elements[pos];");
+				end_scope();
+				statement("constexpr const object_data T& operator [] (size_t pos) const object_data");
+				begin_scope();
+				statement("return elements[pos];");
+				end_scope();
+			}
 			end_scope_decl();
 			statement("");
 			break;
@@ -7609,6 +7772,18 @@ void CompilerMSL::emit_custom_functions()
 			statement("");
 			break;
 
+		case SPVFuncImplSetMeshOutputsEXT:
+			statement("void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)");
+			begin_scope();
+			statement("if (gl_LocalInvocationIndex == 0)");
+			begin_scope();
+			statement("spvMeshSizes.x = vertexCount;");
+			statement("spvMeshSizes.y = primitiveCount;");
+			end_scope();
+			end_scope();
+			statement("");
+			break;
+
 		default:
 			break;
 		}
@@ -7710,6 +7885,23 @@ void CompilerMSL::emit_resources()
 	emit_interface_block(patch_stage_out_var_id);
 	emit_interface_block(stage_in_var_id);
 	emit_interface_block(patch_stage_in_var_id);
+
+	if (get_execution_model() == ExecutionModelMeshEXT)
+	{
+		auto &execution = get_entry_point();
+		const char *topology = "";
+		if (execution.flags.get(ExecutionModeOutputTrianglesEXT))
+			topology = "topology::triangle";
+		else if (execution.flags.get(ExecutionModeOutputLinesEXT))
+			topology = "topology::line";
+		else if (execution.flags.get(ExecutionModeOutputPoints))
+			topology = "topology::point";
+
+		const char *per_primitive = mesh_out_per_primitive ? "spvPerPrimitive" : "void";
+		statement("using spvMesh_t = mesh<", "spvPerVertex, ", per_primitive, ", ", execution.output_vertices, ", ",
+		          execution.output_primitives, ", ", topology, ">;");
+		statement("");
+	}
 }
 
 // Emit declarations for the specialization Metal function constants
@@ -7733,7 +7925,7 @@ void CompilerMSL::emit_specialization_constants_and_structs()
 			mark_scalar_layout_structs(type);
 	});
 
-	bool builtin_block_type_is_required = false;
+	bool builtin_block_type_is_required = is_mesh_shader();
 	// Very special case. If gl_PerVertex is initialized as an array (tessellation)
 	// we have to potentially emit the gl_PerVertex struct type so that we can emit a constant LUT.
 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
@@ -9925,6 +10117,14 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		break;
 	}
 
+	case OpSetMeshOutputsEXT:
+	{
+		flush_variable_declaration(builtin_mesh_primitive_indices_id);
+		add_spv_func_and_recompile(SPVFuncImplSetMeshOutputsEXT);
+		statement("spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, ", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");");
+		break;
+	}
+
 	default:
 		CompilerGLSL::emit_instruction(instruction);
 		break;
@@ -9966,8 +10166,13 @@ void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
 
 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
 {
-	if (get_execution_model() != ExecutionModelGLCompute && !is_tesc_shader())
+	auto model = get_execution_model();
+
+	if (model != ExecutionModelGLCompute && model != ExecutionModelTaskEXT &&
+	    model != ExecutionModelMeshEXT && !is_tesc_shader())
+	{
 		return;
+	}
 
 	uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
 	uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
@@ -11008,6 +11213,21 @@ void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
 			if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
 				set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
 		}
+
+		// add `taskPayloadSharedEXT` variable to entry-point arguments
+		for (auto &v : func.local_variables)
+		{
+			auto &var = get<SPIRVariable>(v);
+			if (var.storage != StorageClassTaskPayloadWorkgroupEXT)
+				continue;
+
+			add_local_variable_name(v);
+			SPIRFunction::Parameter arg = {};
+			arg.id = v;
+			arg.type = var.basetype;
+			arg.alias_global_variable = true;
+			decl += join(", ", argument_decl(arg), " [[payload]]");
+		}
 	}
 
 	for (auto &arg : func.arguments)
@@ -11325,7 +11545,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 		if (args.has_array_offsets)
 		{
 			forward = forward && should_forward(args.offset);
-			farg_str += ", " + to_expression(args.offset);
+			farg_str += ", " + to_unpacked_expression(args.offset);
 		}
 
 		// Const offsets gather or swizzled gather puts the component before the other args.
@@ -11338,7 +11558,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 
 	// Texture coordinates
 	forward = forward && should_forward(args.coord);
-	auto coord_expr = to_enclosed_expression(args.coord);
+	auto coord_expr = to_enclosed_unpacked_expression(args.coord);
 	auto &coord_type = expression_type(args.coord);
 	bool coord_is_fp = type_is_floating_point(coord_type);
 	bool is_cube_fetch = false;
@@ -11462,14 +11682,14 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 			if (type.basetype != SPIRType::UInt)
 				tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, args.offset), ", 0)");
 			else
-				tex_coords += join(" + uint2(", to_enclosed_expression(args.offset), ", 0)");
+				tex_coords += join(" + uint2(", to_enclosed_unpacked_expression(args.offset), ", 0)");
 		}
 		else
 		{
 			if (type.basetype != SPIRType::UInt)
 				tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.offset);
 			else
-				tex_coords += " + " + to_enclosed_expression(args.offset);
+				tex_coords += " + " + to_enclosed_unpacked_expression(args.offset);
 		}
 	}
 
@@ -11556,10 +11776,10 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 
 		string dref_expr;
 		if (args.base.is_proj)
-			dref_expr = join(to_enclosed_expression(args.dref), " / ",
+			dref_expr = join(to_enclosed_unpacked_expression(args.dref), " / ",
 			                 to_extract_component_expression(args.coord, alt_coord_component));
 		else
-			dref_expr = to_expression(args.dref);
+			dref_expr = to_unpacked_expression(args.dref);
 
 		if (sampling_type_needs_f32_conversion(dref_type))
 			dref_expr = convert_to_f32(dref_expr, 1);
@@ -11610,7 +11830,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 	if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
 	{
 		forward = forward && should_forward(bias);
-		farg_str += ", bias(" + to_expression(bias) + ")";
+		farg_str += ", bias(" + to_unpacked_expression(bias) + ")";
 	}
 
 	// Metal does not support LOD for 1D textures.
@@ -11619,7 +11839,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 		forward = forward && should_forward(lod);
 		if (args.base.is_fetch)
 		{
-			farg_str += ", " + to_expression(lod);
+			farg_str += ", " + to_unpacked_expression(lod);
 		}
 		else if (msl_options.sample_dref_lod_array_as_grad && args.dref && imgtype.image.arrayed)
 		{
@@ -11676,12 +11896,12 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 				extent = "float3(1.0)";
 				break;
 			}
-			farg_str += join(", ", grad_opt, "(", grad_coord, "exp2(", to_expression(lod), " - 0.5) / ", extent,
-			                 ", exp2(", to_expression(lod), " - 0.5) / ", extent, ")");
+			farg_str += join(", ", grad_opt, "(", grad_coord, "exp2(", to_unpacked_expression(lod), " - 0.5) / ", extent,
+			                 ", exp2(", to_unpacked_expression(lod), " - 0.5) / ", extent, ")");
 		}
 		else
 		{
-			farg_str += ", level(" + to_expression(lod) + ")";
+			farg_str += ", level(" + to_unpacked_expression(lod) + ")";
 		}
 	}
 	else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
@@ -11727,7 +11947,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 			grad_opt = "unsupported_gradient_dimension";
 			break;
 		}
-		farg_str += join(", ", grad_opt, "(", grad_coord, to_expression(grad_x), ", ", to_expression(grad_y), ")");
+		farg_str += join(", ", grad_opt, "(", grad_coord, to_unpacked_expression(grad_x), ", ", to_unpacked_expression(grad_y), ")");
 	}
 
 	if (args.min_lod)
@@ -11736,7 +11956,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 			SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up.");
 
 		forward = forward && should_forward(args.min_lod);
-		farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
+		farg_str += ", min_lod_clamp(" + to_unpacked_expression(args.min_lod) + ")";
 	}
 
 	// Add offsets
@@ -11745,7 +11965,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 	if (args.offset && !args.base.is_fetch && !args.has_array_offsets)
 	{
 		forward = forward && should_forward(args.offset);
-		offset_expr = to_expression(args.offset);
+		offset_expr = to_unpacked_expression(args.offset);
 		offset_type = &expression_type(args.offset);
 	}
 
@@ -11811,7 +12031,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool
 	{
 		forward = forward && should_forward(args.sample);
 		farg_str += ", ";
-		farg_str += to_expression(args.sample);
+		farg_str += to_unpacked_expression(args.sample);
 	}
 
 	*p_forward = forward;
@@ -12463,12 +12683,50 @@ string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_
 				((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
 				  variable_storage_requires_stage_io(StorageClassOutput)) ||
 				 (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
-				  variable_storage_requires_stage_io(StorageClassInput)));
+				  variable_storage_requires_stage_io(StorageClassInput))) ||
+				is_mesh_shader();
 		if (is_ib_in_out && is_member_builtin(type, index, &builtin))
 			is_using_builtin_array = true;
 		array_type = type_to_array_glsl(physical_type, orig_id);
 	}
 
+	if (is_mesh_shader())
+	{
+		BuiltIn builtin = BuiltInMax;
+		if (is_member_builtin(type, index, &builtin))
+		{
+			if (builtin == BuiltInPrimitiveShadingRateKHR)
+			{
+				// not supported in metal 3.0
+				is_using_builtin_array = false;
+				return "";
+			}
+
+			SPIRType metallic_type = *declared_type;
+			if (builtin == BuiltInCullPrimitiveEXT)
+				metallic_type.basetype = SPIRType::Boolean;
+			else if (builtin == BuiltInPrimitiveId || builtin == BuiltInLayer || builtin == BuiltInViewportIndex)
+				metallic_type.basetype = SPIRType::UInt;
+
+			is_using_builtin_array = true;
+			std::string result;
+			if (has_member_decoration(type.self, orig_id, DecorationBuiltIn))
+			{
+				// avoid '_RESERVED_IDENTIFIER_FIXUP_' in variable name
+				result = join(type_to_glsl(metallic_type, orig_id, false), " ", qualifier,
+				              builtin_to_glsl(builtin, StorageClassOutput), member_attribute_qualifier(type, index),
+				              array_type, ";");
+			}
+			else
+			{
+				result = join(type_to_glsl(metallic_type, orig_id, false), " ", qualifier,
+				              to_member_name(type, index), member_attribute_qualifier(type, index), array_type, ";");
+			}
+			is_using_builtin_array = false;
+			return result;
+		}
+	}
+
 	if (orig_id)
 	{
 		auto *data_type = declared_type;
@@ -12522,6 +12780,16 @@ void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_
 		statement("char _m", index, "_pad", "[", pad_len, "];");
 	}
 
+	BuiltIn builtin = BuiltInMax;
+	if (is_mesh_shader() && is_member_builtin(type, index, &builtin))
+	{
+		if (!has_active_builtin(builtin, StorageClassOutput) && !has_active_builtin(builtin, StorageClassInput))
+		{
+			// Do not emit unused builtins in mesh-output blocks
+			return;
+		}
+	}
+
 	// Handle HLSL-style 0-based vertex/instance index.
 	builtin_declaration = true;
 	statement(to_struct_member(type, member_type_id, index, qualifier));
@@ -12595,9 +12863,11 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
 	}
 
-	// Vertex and tessellation evaluation function outputs
-	if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) || is_tese_shader()) &&
-	    type.storage == StorageClassOutput)
+	bool use_semantic_stage_output = is_mesh_shader() || is_tese_shader() ||
+	                                 (execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation);
+
+	// Vertex, mesh and tessellation evaluation function outputs
+	if ((type.storage == StorageClassOutput || is_mesh_shader()) && use_semantic_stage_output)
 	{
 		if (is_builtin)
 		{
@@ -12616,6 +12886,9 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
 				/* fallthrough */
 			case BuiltInPosition:
 			case BuiltInLayer:
+			case BuiltInCullPrimitiveEXT:
+			case BuiltInPrimitiveShadingRateKHR:
+			case BuiltInPrimitiveId:
 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
 
 			case BuiltInClipDistance:
@@ -12790,7 +13063,11 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
 			{
 				if (!quals.empty())
 					quals += ", ";
-				if (has_member_decoration(type.self, index, DecorationNoPerspective) || builtin == BuiltInBaryCoordNoPerspKHR)
+
+				if (builtin == BuiltInBaryCoordNoPerspKHR || builtin == BuiltInBaryCoordKHR)
+					SPIRV_CROSS_THROW("Centroid interpolation not supported for barycentrics in MSL.");
+
+				if (has_member_decoration(type.self, index, DecorationNoPerspective))
 					quals += "centroid_no_perspective";
 				else
 					quals += "centroid_perspective";
@@ -12799,7 +13076,11 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in
 			{
 				if (!quals.empty())
 					quals += ", ";
-				if (has_member_decoration(type.self, index, DecorationNoPerspective) || builtin == BuiltInBaryCoordNoPerspKHR)
+
+				if (builtin == BuiltInBaryCoordNoPerspKHR || builtin == BuiltInBaryCoordKHR)
+					SPIRV_CROSS_THROW("Sample interpolation not supported for barycentrics in MSL.");
+
+				if (has_member_decoration(type.self, index, DecorationNoPerspective))
 					quals += "sample_no_perspective";
 				else
 					quals += "sample_perspective";
@@ -13078,6 +13359,12 @@ string CompilerMSL::func_type_decl(SPIRType &type)
 	case ExecutionModelKernel:
 		entry_type = "kernel";
 		break;
+	case ExecutionModelMeshEXT:
+		entry_type = "[[mesh]]";
+		break;
+	case ExecutionModelTaskEXT:
+		entry_type = "[[object]]";
+		break;
 	default:
 		entry_type = "unknown";
 		break;
@@ -13096,6 +13383,11 @@ bool CompilerMSL::is_tese_shader() const
 	return get_execution_model() == ExecutionModelTessellationEvaluation;
 }
 
+bool CompilerMSL::is_mesh_shader() const
+{
+	return get_execution_model() == spv::ExecutionModelMeshEXT;
+}
+
 bool CompilerMSL::uses_explicit_early_fragment_test()
 {
 	auto &ep_flags = get_entry_point().flags;
@@ -13211,6 +13503,16 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
 			if (!addr_space)
 				addr_space = "device";
 		}
+
+		if (is_mesh_shader())
+			addr_space = "threadgroup";
+		break;
+
+	case StorageClassTaskPayloadWorkgroupEXT:
+		if (is_mesh_shader())
+			addr_space = "const object_data";
+		else
+			addr_space = "object_data";
 		break;
 
 	default:
@@ -13612,6 +13914,20 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args)
 			                " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
 		}
 	}
+
+	if (is_mesh_shader())
+	{
+		if (!ep_args.empty())
+			ep_args += ", ";
+		ep_args += join("spvMesh_t spvMesh");
+	}
+
+	if (get_execution_model() == ExecutionModelTaskEXT)
+	{
+		if (!ep_args.empty())
+			ep_args += ", ";
+		ep_args += join("mesh_grid_properties spvMgp");
+	}
 }
 
 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
@@ -14050,6 +14366,14 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
 		});
 	}
 
+	if (is_mesh_shader())
+	{
+		// If shader doesn't call SetMeshOutputsEXT, nothing should be rendered.
+		// No need to barrier after this, because only thread 0 writes to this later.
+		entry_func.fixup_hooks_in.push_back([this]() { statement("if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;"); });
+		entry_func.fixup_hooks_out.push_back([this]() { emit_mesh_outputs(); });
+	}
+
 	// Look for sampled images and buffer. Add hooks to set up the swizzle constants or array lengths.
 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
 		auto &type = get_variable_data_type(var);
@@ -14850,7 +15174,8 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
 
 	if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
 		decl = join(cv_qualifier, type_to_glsl(type, arg.id));
-	else if (builtin)
+	else if (builtin && builtin_type != spv::BuiltInPrimitiveTriangleIndicesEXT &&
+	         builtin_type != spv::BuiltInPrimitiveLineIndicesEXT && builtin_type != spv::BuiltInPrimitivePointIndicesEXT)
 	{
 		// Only use templated array for Clip/Cull distance when feasible.
 		// In other scenarios, we need need to override array length for tess levels (if used as outputs),
@@ -15487,6 +15812,9 @@ string CompilerMSL::to_qualifiers_glsl(uint32_t id)
 	auto *var = maybe_get<SPIRVariable>(id);
 	auto &type = expression_type(id);
 
+	if (type.storage == StorageClassTaskPayloadWorkgroupEXT)
+		quals += "object_data ";
+
 	if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)))
 		quals += "threadgroup ";
 
@@ -15671,6 +15999,8 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member)
 		break;
 	case SPIRType::RayQuery:
 		return "raytracing::intersection_query<raytracing::instancing, raytracing::triangle_data>";
+	case SPIRType::MeshGridProperties:
+		return "mesh_grid_properties";
 
 	default:
 		return "unknown_type";
@@ -15785,6 +16115,9 @@ bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable
 			return true;
 		}
 
+		if (is_mesh_shader())
+			return variable.storage == StorageClassOutput;
+
 		return variable.storage == StorageClassOutput && is_tesc_shader() && is_stage_output_variable_masked(variable);
 	}
 	else if (storage == StorageClassStorageBuffer)
@@ -16554,6 +16887,8 @@ string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
 	case BuiltInLayer:
 		if (is_tesc_shader())
 			break;
+		if (is_mesh_shader())
+			break;
 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
 		    !is_stage_output_builtin_masked(builtin))
 			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
@@ -16611,6 +16946,9 @@ string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
 		// In SPIR-V 1.6 with Volatile HelperInvocation, we cannot emit a fixup early.
 		return "simd_is_helper_thread()";
 
+	case BuiltInPrimitiveId:
+		return "gl_PrimitiveID";
+
 	default:
 		break;
 	}
@@ -16644,6 +16982,8 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
 	// Vertex function out
 	case BuiltInClipDistance:
 		return "clip_distance";
+	case BuiltInCullDistance:
+		return "cull_distance";
 	case BuiltInPointSize:
 		return "point_size";
 	case BuiltInPosition:
@@ -16691,6 +17031,8 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
 			else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
 				SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
 			return "primitive_id";
+		case ExecutionModelMeshEXT:
+			return "primitive_id";
 		default:
 			SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
 		}
@@ -16720,7 +17062,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
 		// Shouldn't be reached.
 		SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
 	case BuiltInViewIndex:
-		if (execution.model != ExecutionModelFragment)
+		if (execution.model != ExecutionModelFragment && execution.model != ExecutionModelMeshEXT)
 			SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
 		// The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
 		// so we can get it from there.
@@ -16825,6 +17167,9 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
 		return "barycentric_coord";
 
+	case BuiltInCullPrimitiveEXT:
+		return "primitive_culled";
+
 	default:
 		return "unsupported-built-in";
 	}
@@ -16941,6 +17286,13 @@ string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
 	case BuiltInDeviceIndex:
 		return "int";
 
+	case BuiltInPrimitivePointIndicesEXT:
+		return "uint";
+	case BuiltInPrimitiveLineIndicesEXT:
+		return "uint2";
+	case BuiltInPrimitiveTriangleIndicesEXT:
+		return "uint3";
+
 	default:
 		return "unsupported-built-in-type";
 	}
@@ -18452,6 +18804,7 @@ void CompilerMSL::analyze_argument_buffers()
 
 		uint32_t member_index = 0;
 		uint32_t next_arg_buff_index = 0;
+		uint32_t prev_was_scalar_on_array_offset = 0;
 		for (auto &resource : resources)
 		{
 			auto &var = *resource.var;
@@ -18464,7 +18817,9 @@ void CompilerMSL::analyze_argument_buffers()
 			// member_index and next_arg_buff_index are incremented when padding members are added.
 			if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0)
 			{
-				auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
+				auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index - prev_was_scalar_on_array_offset);
+				rez_bind.count -= prev_was_scalar_on_array_offset;
+
 				while (resource.index > next_arg_buff_index)
 				{
 					switch (rez_bind.basetype)
@@ -18503,12 +18858,19 @@ void CompilerMSL::analyze_argument_buffers()
 
 					// After padding, retrieve the resource again. It will either be more padding, or the actual resource.
 					rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
+					prev_was_scalar_on_array_offset = 0;
 				}
 
+				uint32_t count = rez_bind.count;
+
+				// If the current resource is an array in the descriptor, but is a scalar
+				// in the shader, only the first element will be consumed. The next pass
+				// will add a padding member to consume the remaining array elements.
+				if (count > 1 && type.array.empty())
+					count = prev_was_scalar_on_array_offset = 1;
+
 				// Adjust the number of slots consumed by current member itself.
-				// Use the count value from the app, instead of the shader, in case the
-				// shader is only accessing part, or even one element, of the array.
-				next_arg_buff_index += resource.plane_count * rez_bind.count;
+				next_arg_buff_index += resource.plane_count * count;
 			}
 
 			string mbr_name = ensure_valid_name(resource.name, "m");
@@ -18799,6 +19161,224 @@ void CompilerMSL::emit_block_hints(const SPIRBlock &)
 {
 }
 
+void CompilerMSL::emit_mesh_entry_point()
+{
+	auto &ep = get_entry_point();
+	auto &f = get<SPIRFunction>(ir.default_entry_point);
+
+	const uint32_t func_id = ir.increase_bound_by(3);
+	const uint32_t block_id = func_id + 1;
+	const uint32_t ret_id = func_id + 2;
+	auto &wrapped_main = set<SPIRFunction>(func_id, f.return_type, f.function_type);
+
+	wrapped_main.blocks.push_back(block_id);
+	wrapped_main.entry_block = block_id;
+
+	auto &wrapped_entry = set<SPIRBlock>(block_id);
+	wrapped_entry.terminator = SPIRBlock::Return;
+
+	// Push call to original 'main'
+	Instruction ix = {};
+	ix.op = OpFunctionCall;
+	ix.offset = uint32_t(ir.spirv.size());
+	ix.length = 3;
+
+	ir.spirv.push_back(f.return_type);
+	ir.spirv.push_back(ret_id);
+	ir.spirv.push_back(ep.self);
+
+	wrapped_entry.ops.push_back(ix);
+
+	// relace entry-point for new one
+	SPIREntryPoint proxy_ep = ep;
+	proxy_ep.self = func_id;
+	ir.entry_points.insert(std::make_pair(func_id, proxy_ep));
+	ir.meta[func_id] = ir.meta[ir.default_entry_point];
+	ir.meta[ir.default_entry_point].decoration.alias.clear();
+
+	ir.default_entry_point = func_id;
+}
+
+void CompilerMSL::emit_mesh_outputs()
+{
+	auto &mode = get_entry_point();
+
+	// predefined thread count or zero, if specialization constant is in use
+	uint32_t num_invocations = 0;
+	if (mode.workgroup_size.id_x == 0 && mode.workgroup_size.id_y == 0 && mode.workgroup_size.id_z == 0)
+		num_invocations = mode.workgroup_size.x * mode.workgroup_size.y * mode.workgroup_size.z;
+
+	statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
+	statement("if (spvMeshSizes.y == 0)");
+	begin_scope();
+	statement("return;");
+	end_scope();
+	statement("spvMesh.set_primitive_count(spvMeshSizes.y);");
+
+	statement("const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);");
+
+	if (mesh_out_per_vertex != 0)
+	{
+		auto &type_vert = get<SPIRType>(mesh_out_per_vertex);
+
+		if (num_invocations < mode.output_vertices)
+		{
+			statement("for (uint spvVI = gl_LocalInvocationIndex; spvVI < spvMeshSizes.x; spvVI += spvThreadCount)");
+		}
+		else
+		{
+			statement("const uint spvVI = gl_LocalInvocationIndex;");
+			statement("if (gl_LocalInvocationIndex < spvMeshSizes.x)");
+		}
+
+		begin_scope();
+
+		statement("spvPerVertex spvV = {};");
+		for (uint32_t index = 0; index < uint32_t(type_vert.member_types.size()); ++index)
+		{
+			uint32_t orig_var = get_extended_member_decoration(type_vert.self, index, SPIRVCrossDecorationInterfaceOrigID);
+			uint32_t orig_id = get_extended_member_decoration(type_vert.self, index, SPIRVCrossDecorationInterfaceMemberIndex);
+
+			// Clip/cull distances are special-case
+			if (orig_var == 0 && orig_id == (~0u))
+				continue;
+
+			auto &orig = get<SPIRVariable>(orig_var);
+			auto &orig_type = get<SPIRType>(orig.basetype);
+
+			// FIXME: Need to deal with complex composite IO types. These may need extra unroll, etc.
+
+			BuiltIn builtin = BuiltInMax;
+			std::string access;
+			if (orig_type.basetype == SPIRType::Struct)
+			{
+				if (has_member_decoration(orig_type.self, orig_id, DecorationBuiltIn))
+					builtin = BuiltIn(get_member_decoration(orig_type.self, orig_id, DecorationBuiltIn));
+
+				switch (builtin)
+				{
+				case BuiltInPosition:
+				case BuiltInPointSize:
+				case BuiltInClipDistance:
+				case BuiltInCullDistance:
+					access = "." + builtin_to_glsl(builtin, StorageClassOutput);
+					break;
+				default:
+					access = "." + to_member_name(orig_type, orig_id);
+					break;
+				}
+
+				if (has_member_decoration(type_vert.self, index, DecorationIndex))
+				{
+					// Declare the Clip/CullDistance as [[user(clip/cullN)]].
+					const uint32_t orig_index = get_member_decoration(type_vert.self, index, DecorationIndex);
+					access += "[" + to_string(orig_index) + "]";
+					statement("spvV.", builtin_to_glsl(builtin, StorageClassOutput), "[", orig_index, "] = ", to_name(orig_var), "[spvVI]", access, ";");
+				}
+			}
+
+			statement("spvV.", to_member_name(type_vert, index), " = ", to_name(orig_var), "[spvVI]", access, ";");
+			if (options.vertex.flip_vert_y && builtin == BuiltInPosition)
+			{
+				statement("spvV.", to_member_name(type_vert, index), ".y = -(", "spvV.",
+				          to_member_name(type_vert, index), ".y);", "    // Invert Y-axis for Metal");
+			}
+		}
+		statement("spvMesh.set_vertex(spvVI, spvV);");
+		end_scope();
+	}
+
+	if (mesh_out_per_primitive != 0 || builtin_mesh_primitive_indices_id != 0)
+	{
+		if (num_invocations < mode.output_primitives)
+		{
+			statement("for (uint spvPI = gl_LocalInvocationIndex; spvPI < spvMeshSizes.y; spvPI += spvThreadCount)");
+		}
+		else
+		{
+			statement("const uint spvPI = gl_LocalInvocationIndex;");
+			statement("if (gl_LocalInvocationIndex < spvMeshSizes.y)");
+		}
+
+		// FIXME: Need to deal with complex composite IO types. These may need extra unroll, etc.
+
+		begin_scope();
+
+		if (builtin_mesh_primitive_indices_id != 0)
+		{
+			if (mode.flags.get(ExecutionModeOutputTrianglesEXT))
+			{
+				statement("spvMesh.set_index(spvPI * 3u + 0u, gl_PrimitiveTriangleIndicesEXT[spvPI].x);");
+				statement("spvMesh.set_index(spvPI * 3u + 1u, gl_PrimitiveTriangleIndicesEXT[spvPI].y);");
+				statement("spvMesh.set_index(spvPI * 3u + 2u, gl_PrimitiveTriangleIndicesEXT[spvPI].z);");
+			}
+			else if (mode.flags.get(ExecutionModeOutputLinesEXT))
+			{
+				statement("spvMesh.set_index(spvPI * 2u + 0u, gl_PrimitiveLineIndicesEXT[spvPI].x);");
+				statement("spvMesh.set_index(spvPI * 2u + 1u, gl_PrimitiveLineIndicesEXT[spvPI].y);");
+			}
+			else
+			{
+				statement("spvMesh.set_index(spvPI, gl_PrimitivePointIndicesEXT[spvPI]);");
+			}
+		}
+
+		if (mesh_out_per_primitive != 0)
+		{
+			auto &type_prim = get<SPIRType>(mesh_out_per_primitive);
+			statement("spvPerPrimitive spvP = {};");
+			for (uint32_t index = 0; index < uint32_t(type_prim.member_types.size()); ++index)
+			{
+				uint32_t orig_var =
+						get_extended_member_decoration(type_prim.self, index, SPIRVCrossDecorationInterfaceOrigID);
+				uint32_t orig_id =
+						get_extended_member_decoration(type_prim.self, index, SPIRVCrossDecorationInterfaceMemberIndex);
+				auto &orig = get<SPIRVariable>(orig_var);
+				auto &orig_type = get<SPIRType>(orig.basetype);
+
+				BuiltIn builtin = BuiltInMax;
+				std::string access;
+				if (orig_type.basetype == SPIRType::Struct)
+				{
+					if (has_member_decoration(orig_type.self, orig_id, DecorationBuiltIn))
+						builtin = BuiltIn(get_member_decoration(orig_type.self, orig_id, DecorationBuiltIn));
+
+					switch (builtin)
+					{
+					case BuiltInPrimitiveId:
+					case BuiltInLayer:
+					case BuiltInViewportIndex:
+					case BuiltInCullPrimitiveEXT:
+					case BuiltInPrimitiveShadingRateKHR:
+						access = "." + builtin_to_glsl(builtin, StorageClassOutput);
+						break;
+					default:
+						access = "." + to_member_name(orig_type, orig_id);
+					}
+				}
+				statement("spvP.", to_member_name(type_prim, index), " = ", to_name(orig_var), "[spvPI]", access, ";");
+			}
+			statement("spvMesh.set_primitive(spvPI, spvP);");
+		}
+
+		end_scope();
+	}
+}
+
+void CompilerMSL::emit_mesh_tasks(SPIRBlock &block)
+{
+	// GLSL: Once this instruction is called, the workgroup must be terminated immediately, and the mesh shaders are launched.
+	// TODO: find relieble and clean of terminating shader.
+	flush_variable_declaration(builtin_task_grid_id);
+	statement("spvMgp.set_threadgroups_per_grid(uint3(", to_unpacked_expression(block.mesh.groups[0]), ", ",
+	          to_unpacked_expression(block.mesh.groups[1]), ", ", to_unpacked_expression(block.mesh.groups[2]), "));");
+	// This is correct if EmitMeshTasks is called in the entry function for shader.
+	// Only viable solutions would be:
+	// - Caller ensures the SPIR-V is inlined, then this always holds true.
+	// - Pass down a "should terminate" bool to leaf functions and chain return (horrible and disgusting, let's not).
+	statement("return;");
+}
+
 string CompilerMSL::additional_fixed_sample_mask_str() const
 {
 	char print_buffer[32];

+ 13 - 0
3rdparty/spirv-cross/spirv_msl.hpp

@@ -840,6 +840,7 @@ protected:
 		SPVFuncImplImageFence,
 		SPVFuncImplTextureCast,
 		SPVFuncImplMulExtended,
+		SPVFuncImplSetMeshOutputsEXT,
 	};
 
 	// If the underlying resource has been used for comparison then duplicate loads of that resource must be too
@@ -868,6 +869,9 @@ protected:
 	std::string type_to_glsl(const SPIRType &type, uint32_t id, bool member);
 	std::string type_to_glsl(const SPIRType &type, uint32_t id = 0) override;
 	void emit_block_hints(const SPIRBlock &block) override;
+	void emit_mesh_entry_point();
+	void emit_mesh_outputs();
+	void emit_mesh_tasks(SPIRBlock &block) override;
 
 	// Allow Metal to use the array<T> template to make arrays a value type
 	std::string type_to_array_glsl(const SPIRType &type, uint32_t variable_id) override;
@@ -919,6 +923,7 @@ protected:
 
 	bool is_tesc_shader() const;
 	bool is_tese_shader() const;
+	bool is_mesh_shader() const;
 
 	void preprocess_op_codes();
 	void localize_global_variables();
@@ -933,6 +938,7 @@ protected:
 	                                            std::unordered_set<uint32_t> &processed_func_ids);
 	uint32_t add_interface_block(spv::StorageClass storage, bool patch = false);
 	uint32_t add_interface_block_pointer(uint32_t ib_var_id, spv::StorageClass storage);
+	uint32_t add_meshlet_block(bool per_primitive);
 
 	struct InterfaceBlockMeta
 	{
@@ -1104,12 +1110,17 @@ protected:
 	uint32_t builtin_stage_input_size_id = 0;
 	uint32_t builtin_local_invocation_index_id = 0;
 	uint32_t builtin_workgroup_size_id = 0;
+	uint32_t builtin_mesh_primitive_indices_id = 0;
+	uint32_t builtin_mesh_sizes_id = 0;
+	uint32_t builtin_task_grid_id = 0;
 	uint32_t builtin_frag_depth_id = 0;
 	uint32_t swizzle_buffer_id = 0;
 	uint32_t buffer_size_buffer_id = 0;
 	uint32_t view_mask_buffer_id = 0;
 	uint32_t dynamic_offsets_buffer_id = 0;
 	uint32_t uint_type_id = 0;
+	uint32_t shared_uint_type_id = 0;
+	uint32_t meshlet_type_id = 0;
 	uint32_t argument_buffer_padding_buffer_type_id = 0;
 	uint32_t argument_buffer_padding_image_type_id = 0;
 	uint32_t argument_buffer_padding_sampler_type_id = 0;
@@ -1174,6 +1185,8 @@ protected:
 	VariableID stage_out_ptr_var_id = 0;
 	VariableID tess_level_inner_var_id = 0;
 	VariableID tess_level_outer_var_id = 0;
+	VariableID mesh_out_per_vertex = 0;
+	VariableID mesh_out_per_primitive = 0;
 	VariableID stage_out_masked_builtin_type_id = 0;
 
 	// Handle HLSL-style 0-based vertex/instance index.