
Updated spirv-cross.

Бранимир Караџић 4 years ago
parent
commit
d2a16db33a

+ 92 - 84
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -686,15 +686,9 @@ string CompilerGLSL::compile()
 		statement("void main()");
 		begin_scope();
 		statement("// Interlocks were used in a way not compatible with GLSL, this is very slow.");
-		if (options.es)
-			statement("beginInvocationInterlockNV();");
-		else
-			statement("beginInvocationInterlockARB();");
+		statement("SPIRV_Cross_beginInvocationInterlock();");
 		statement("spvMainInterlockedBody();");
-		if (options.es)
-			statement("endInvocationInterlockNV();");
-		else
-			statement("endInvocationInterlockARB();");
+		statement("SPIRV_Cross_endInvocationInterlock();");
 		end_scope();
 	}
 
@@ -784,10 +778,12 @@ void CompilerGLSL::emit_header()
 		require_extension_internal("GL_ARB_post_depth_coverage");
 
 	// Needed for: layout({pixel,sample}_interlock_[un]ordered) in;
-	if (execution.flags.get(ExecutionModePixelInterlockOrderedEXT) ||
-	    execution.flags.get(ExecutionModePixelInterlockUnorderedEXT) ||
-	    execution.flags.get(ExecutionModeSampleInterlockOrderedEXT) ||
-	    execution.flags.get(ExecutionModeSampleInterlockUnorderedEXT))
+	bool interlock_used = execution.flags.get(ExecutionModePixelInterlockOrderedEXT) ||
+	                      execution.flags.get(ExecutionModePixelInterlockUnorderedEXT) ||
+	                      execution.flags.get(ExecutionModeSampleInterlockOrderedEXT) ||
+	                      execution.flags.get(ExecutionModeSampleInterlockUnorderedEXT);
+
+	if (interlock_used)
 	{
 		if (options.es)
 		{
@@ -876,6 +872,24 @@ void CompilerGLSL::emit_header()
 			statement("#define SPIRV_CROSS_LOOP");
 			statement("#endif");
 		}
+		else if (ext == "GL_NV_fragment_shader_interlock")
+		{
+			statement("#extension GL_NV_fragment_shader_interlock : require");
+			statement("#define SPIRV_Cross_beginInvocationInterlock() beginInvocationInterlockNV()");
+			statement("#define SPIRV_Cross_endInvocationInterlock() endInvocationInterlockNV()");
+		}
+		else if (ext == "GL_ARB_fragment_shader_interlock")
+		{
+			statement("#ifdef GL_ARB_fragment_shader_interlock");
+			statement("#extension GL_ARB_fragment_shader_interlock : enable");
+			statement("#define SPIRV_Cross_beginInvocationInterlock() beginInvocationInterlockARB()");
+			statement("#define SPIRV_Cross_endInvocationInterlock() endInvocationInterlockARB()");
+			statement("#elif defined(GL_INTEL_fragment_shader_ordering)");
+			statement("#extension GL_INTEL_fragment_shader_ordering : enable");
+			statement("#define SPIRV_Cross_beginInvocationInterlock() beginFragmentShaderOrderingINTEL()");
+			statement("#define SPIRV_Cross_endInvocationInterlock()");
+			statement("#endif");
+		}
 		else
 			statement("#extension ", ext, " : require");
 	}
@@ -1056,14 +1070,24 @@ void CompilerGLSL::emit_header()
 		if (execution.flags.get(ExecutionModePostDepthCoverage))
 			inputs.push_back("post_depth_coverage");
 
+		if (interlock_used)
+			statement("#if defined(GL_ARB_fragment_shader_interlock)");
+
 		if (execution.flags.get(ExecutionModePixelInterlockOrderedEXT))
-			inputs.push_back("pixel_interlock_ordered");
+			statement("layout(pixel_interlock_ordered) in;");
 		else if (execution.flags.get(ExecutionModePixelInterlockUnorderedEXT))
-			inputs.push_back("pixel_interlock_unordered");
+			statement("layout(pixel_interlock_unordered) in;");
 		else if (execution.flags.get(ExecutionModeSampleInterlockOrderedEXT))
-			inputs.push_back("sample_interlock_ordered");
+			statement("layout(sample_interlock_ordered) in;");
 		else if (execution.flags.get(ExecutionModeSampleInterlockUnorderedEXT))
-			inputs.push_back("sample_interlock_unordered");
+			statement("layout(sample_interlock_unordered) in;");
+
+		if (interlock_used)
+		{
+			statement("#elif !defined(GL_INTEL_fragment_shader_ordering)");
+			statement("#error Fragment Shader Interlock/Ordering extension missing!");
+			statement("#endif");
+		}
 
 		if (!options.es && execution.flags.get(ExecutionModeDepthGreater))
 			statement("layout(depth_greater) out float gl_FragDepth;");
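Note: taken together, the hunks above stop hard-coding beginInvocationInterlockNV/ARB() and route everything through a SPIRV_Cross_beginInvocationInterlock()/SPIRV_Cross_endInvocationInterlock() macro pair defined in the emitted header. As a rough sketch (assembled from the string literals above, for a non-ES shader declaring pixel_interlock_ordered; exact placement among the other header lines varies), the generated GLSL preamble now contains:

	#ifdef GL_ARB_fragment_shader_interlock
	#extension GL_ARB_fragment_shader_interlock : enable
	#define SPIRV_Cross_beginInvocationInterlock() beginInvocationInterlockARB()
	#define SPIRV_Cross_endInvocationInterlock() endInvocationInterlockARB()
	#elif defined(GL_INTEL_fragment_shader_ordering)
	#extension GL_INTEL_fragment_shader_ordering : enable
	#define SPIRV_Cross_beginInvocationInterlock() beginFragmentShaderOrderingINTEL()
	#define SPIRV_Cross_endInvocationInterlock()
	#endif
	#if defined(GL_ARB_fragment_shader_interlock)
	layout(pixel_interlock_ordered) in;
	#elif !defined(GL_INTEL_fragment_shader_ordering)
	#error Fragment Shader Interlock/Ordering extension missing!
	#endif

so the choice between the ARB and INTEL entry points is deferred to GLSL compile time instead of being decided by options.es; the ES path still requires GL_NV_fragment_shader_interlock and maps the same macros to the NV functions.
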
@@ -4375,19 +4399,7 @@ string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expressio
 
 string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_expression_read)
 {
-	// If we need to transpose, it will also take care of unpacking rules.
-	auto *e = maybe_get<SPIRExpression>(id);
-	bool need_transpose = e && e->need_transpose;
-	bool is_remapped = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID);
-	bool is_packed = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked);
-	if (!need_transpose && (is_remapped || is_packed))
-	{
-		return unpack_expression_type(to_expression(id, register_expression_read), expression_type(id),
-		                              get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID),
-		                              has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), false);
-	}
-	else
-		return to_enclosed_expression(id, register_expression_read);
+	return enclose_expression(to_unpacked_expression(id, register_expression_read));
 }
 
 string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read)
@@ -6263,48 +6275,49 @@ bool CompilerGLSL::to_trivial_mix_op(const SPIRType &type, string &op, uint32_t
 	if (!backend.use_constructor_splatting && value_type.vecsize != lerptype.vecsize)
 		return false;
 
+	// Only valid way in SPIR-V 1.4 to use matrices in select is a scalar select.
+	// matrix(scalar) constructor fills in diagnonals, so gets messy very quickly.
+	// Just avoid this case.
+	if (value_type.columns > 1)
+		return false;
+
 	// If our bool selects between 0 and 1, we can cast from bool instead, making our trivial constructor.
 	bool ret = true;
-	for (uint32_t col = 0; col < value_type.columns; col++)
+	for (uint32_t row = 0; ret && row < value_type.vecsize; row++)
 	{
-		for (uint32_t row = 0; row < value_type.vecsize; row++)
+		switch (type.basetype)
 		{
-			switch (type.basetype)
-			{
-			case SPIRType::Short:
-			case SPIRType::UShort:
-				ret = cleft->scalar_u16(col, row) == 0 && cright->scalar_u16(col, row) == 1;
-				break;
-
-			case SPIRType::Int:
-			case SPIRType::UInt:
-				ret = cleft->scalar(col, row) == 0 && cright->scalar(col, row) == 1;
-				break;
+		case SPIRType::Short:
+		case SPIRType::UShort:
+			ret = cleft->scalar_u16(0, row) == 0 && cright->scalar_u16(0, row) == 1;
+			break;
 
-			case SPIRType::Half:
-				ret = cleft->scalar_f16(col, row) == 0.0f && cright->scalar_f16(col, row) == 1.0f;
-				break;
+		case SPIRType::Int:
+		case SPIRType::UInt:
+			ret = cleft->scalar(0, row) == 0 && cright->scalar(0, row) == 1;
+			break;
 
-			case SPIRType::Float:
-				ret = cleft->scalar_f32(col, row) == 0.0f && cright->scalar_f32(col, row) == 1.0f;
-				break;
+		case SPIRType::Half:
+			ret = cleft->scalar_f16(0, row) == 0.0f && cright->scalar_f16(0, row) == 1.0f;
+			break;
 
-			case SPIRType::Double:
-				ret = cleft->scalar_f64(col, row) == 0.0 && cright->scalar_f64(col, row) == 1.0;
-				break;
+		case SPIRType::Float:
+			ret = cleft->scalar_f32(0, row) == 0.0f && cright->scalar_f32(0, row) == 1.0f;
+			break;
 
-			case SPIRType::Int64:
-			case SPIRType::UInt64:
-				ret = cleft->scalar_u64(col, row) == 0 && cright->scalar_u64(col, row) == 1;
-				break;
+		case SPIRType::Double:
+			ret = cleft->scalar_f64(0, row) == 0.0 && cright->scalar_f64(0, row) == 1.0;
+			break;
 
-			default:
-				return false;
-			}
-		}
+		case SPIRType::Int64:
+		case SPIRType::UInt64:
+			ret = cleft->scalar_u64(0, row) == 0 && cright->scalar_u64(0, row) == 1;
+			break;
 
-		if (!ret)
+		default:
+			ret = false;
 			break;
+		}
 	}
 
 	if (ret)
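Note: the restructured loop only changes how the constants are scanned (a single column, early-out through ret); what the trivial-mix path emits is unchanged. A hedged GLSL sketch of the result for a vector OpSelect whose constant operands are all 0 and 1 (variable names are hypothetical):

	bvec4 cond = lessThan(a, b);
	vec4 r = vec4(cond); // emitted instead of a component-wise select/mix between vec4(0.0) and vec4(1.0)

The new early-out simply refuses this path for matrix value types, where a matrix(scalar) constructor would only fill the diagonal.
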
@@ -8524,7 +8537,7 @@ const char *CompilerGLSL::index_to_swizzle(uint32_t index)
 	case 3:
 		return "w";
 	default:
-		SPIRV_CROSS_THROW("Swizzle index out of range");
+		return "x";		// Don't crash, but engage the "undefined behavior" described for out-of-bounds logical addressing in spec.
 	}
 }
 
@@ -8540,7 +8553,7 @@ void CompilerGLSL::access_chain_internal_append_index(std::string &expr, uint32_
 	if (index_is_literal)
 		expr += convert_to_string(index);
 	else
-		expr += to_expression(index, register_expression_read);
+		expr += to_unpacked_expression(index, register_expression_read);
 
 	expr += "]";
 }
@@ -8748,13 +8761,16 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 
 			access_chain_is_arrayed = true;
 		}
-		// For structs, the index refers to a constant, which indexes into the members.
+		// For structs, the index refers to a constant, which indexes into the members, possibly through a redirection mapping.
 		// We also check if this member is a builtin, since we then replace the entire expression with the builtin one.
 		else if (type->basetype == SPIRType::Struct)
 		{
 			if (!is_literal)
 				index = evaluate_constant_u32(index);
 
+			if (index < uint32_t(type->member_type_index_redirection.size()))
+				index = type->member_type_index_redirection[index];
+
 			if (index >= type->member_types.size())
 				SPIRV_CROSS_THROW("Member index is out of bounds!");
 
@@ -8804,7 +8820,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 			if (is_literal)
 				expr += convert_to_string(index);
 			else
-				expr += to_expression(index, register_expression_read);
+				expr += to_unpacked_expression(index, register_expression_read);
 			expr += "]";
 
 			type_id = type->parent_type;
@@ -8886,7 +8902,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 			else
 			{
 				expr += "[";
-				expr += to_expression(index, register_expression_read);
+				expr += to_unpacked_expression(index, register_expression_read);
 				expr += "]";
 			}
 
@@ -9878,7 +9894,7 @@ void CompilerGLSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_ex
 				convert_non_uniform_expression(lhs, lhs_expression);
 
 			// We might need to cast in order to store to a builtin.
-			cast_to_builtin_store(lhs_expression, rhs, expression_type(rhs_expression));
+			cast_to_variable_store(lhs_expression, rhs, expression_type(rhs_expression));
 
 			// Tries to optimize assignments like "<lhs> = <lhs> op expr".
 			// While this is purely cosmetic, this is important for legacy ESSL where loop
@@ -10043,7 +10059,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 			expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
 
 		// We might need to cast in order to load from a builtin.
-		cast_from_builtin_load(ptr, expr, type);
+		cast_from_variable_load(ptr, expr, type);
 
 		// We might be trying to load a gl_Position[N], where we should be
 		// doing float4[](gl_in[i].gl_Position, ...) instead.
@@ -12675,11 +12691,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		// If the interlock is complex, we emit this elsewhere.
 		if (!interlocked_is_complex)
 		{
-			if (options.es)
-				statement("beginInvocationInterlockNV();");
-			else
-				statement("beginInvocationInterlockARB();");
-
+			statement("SPIRV_Cross_beginInvocationInterlock();");
 			flush_all_active_variables();
 			// Make sure forwarding doesn't propagate outside interlock region.
 		}
@@ -12689,11 +12701,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 		// If the interlock is complex, we emit this elsewhere.
 		if (!interlocked_is_complex)
 		{
-			if (options.es)
-				statement("endInvocationInterlockNV();");
-			else
-				statement("endInvocationInterlockARB();");
-
+			statement("SPIRV_Cross_endInvocationInterlock();");
 			flush_all_active_variables();
 			// Make sure forwarding doesn't propagate outside interlock region.
 		}
@@ -13097,7 +13105,7 @@ string CompilerGLSL::argument_decl(const SPIRFunction::Parameter &arg)
 
 string CompilerGLSL::to_initializer_expression(const SPIRVariable &var)
 {
-	return to_expression(var.initializer);
+	return to_unpacked_expression(var.initializer);
 }
 
 string CompilerGLSL::to_zero_initialized_expression(uint32_t type_id)
@@ -13145,7 +13153,7 @@ string CompilerGLSL::variable_decl(const SPIRVariable &variable)
 	{
 		uint32_t expr = variable.static_expression;
 		if (ir.ids[expr].get_type() != TypeUndef)
-			res += join(" = ", to_expression(variable.static_expression));
+			res += join(" = ", to_unpacked_expression(variable.static_expression));
 		else if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
 			res += join(" = ", to_zero_initialized_expression(get_variable_data_type_id(variable)));
 	}
@@ -14930,7 +14938,7 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block)
 		else
 		{
 			emit_block_hints(block);
-			statement("switch (", to_expression(block.condition), ")");
+			statement("switch (", to_unpacked_expression(block.condition), ")");
 		}
 		begin_scope();
 
@@ -15066,7 +15074,7 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block)
 			{
 				// OpReturnValue can return Undef, so don't emit anything for this case.
 				if (ir.ids[block.return_value].get_type() != TypeUndef)
-					statement("return ", to_expression(block.return_value), ";");
+					statement("return ", to_unpacked_expression(block.return_value), ";");
 			}
 		}
 		else if (!cfg.node_terminates_control_flow_in_sub_graph(current_function->entry_block, block.self) ||
@@ -15085,7 +15093,7 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block)
 	case SPIRBlock::Kill:
 		statement(backend.discard_literal, ";");
 		if (block.return_value)
-			statement("return ", to_expression(block.return_value), ";");
+			statement("return ", to_unpacked_expression(block.return_value), ";");
 		break;
 
 	case SPIRBlock::Unreachable:
@@ -15380,7 +15388,7 @@ void CompilerGLSL::unroll_array_from_complex_load(uint32_t target_id, uint32_t s
 	}
 }
 
-void CompilerGLSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
+void CompilerGLSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
 {
 	// We will handle array cases elsewhere.
 	if (!expr_type.array.empty())
@@ -15439,7 +15447,7 @@ void CompilerGLSL::cast_from_builtin_load(uint32_t source_id, std::string &expr,
 		expr = bitcast_expression(expr_type, expected_type, expr);
 }
 
-void CompilerGLSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+void CompilerGLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
 {
 	auto *var = maybe_get_backing_variable(target_id);
 	if (var)

+ 2 - 2
3rdparty/spirv-cross/spirv_glsl.hpp

@@ -903,8 +903,8 @@ protected:
 	// Builtins in GLSL are always specific signedness, but the SPIR-V can declare them
 	// as either unsigned or signed.
 	// Sometimes we will need to automatically perform casts on load and store to make this work.
-	virtual void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
-	virtual void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
+	virtual void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
+	virtual void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
 	void unroll_array_from_complex_load(uint32_t target_id, uint32_t source_id, std::string &expr);
 	bool unroll_array_to_complex_store(uint32_t target_id, uint32_t source_id);
 	void convert_non_uniform_expression(std::string &expr, uint32_t ptr_id);

+ 171 - 26
3rdparty/spirv-cross/spirv_msl.cpp

@@ -1512,6 +1512,14 @@ void CompilerMSL::preprocess_op_codes()
 	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
 	                          (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
 		needs_sample_id = true;
+
+	if (is_intersection_query())
+	{
+		add_header_line("#if __METAL_VERSION__ >= 230");
+		add_header_line("#include <metal_raytracing>");
+		add_header_line("using namespace metal::raytracing;");
+		add_header_line("#endif");
+	}
 }
 
 // Move the Private and Workgroup global variables to the entry function.
@@ -8142,7 +8150,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		expr += "(";
 		for (uint32_t col = 0; col < type.columns; col++)
 		{
-			expr += to_enclosed_expression(a);
+			expr += to_enclosed_unpacked_expression(a);
 			expr += " * ";
 			expr += to_extract_component_expression(b, col);
 			if (col + 1 < type.columns)
@@ -8247,19 +8255,19 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		auto &res_type = get<SPIRType>(type.member_types[1]);
 		if (opcode == OpIAddCarry)
 		{
-			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
-			          to_enclosed_expression(op1), ";");
+			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
+					  to_enclosed_unpacked_expression(op0), " + ", to_enclosed_unpacked_expression(op1), ";");
 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
-			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
-			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
+			          "(1), ", type_to_glsl(res_type), "(0), ", to_unpacked_expression(result_id), ".", to_member_name(type, 0),
+			          " >= max(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), "));");
 		}
 		else
 		{
-			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
-			          to_enclosed_expression(op1), ";");
+			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_unpacked_expression(op0), " - ",
+			          to_enclosed_unpacked_expression(op1), ";");
 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
-			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
-			          " >= ", to_enclosed_expression(op1), ");");
+			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_unpacked_expression(op0),
+			          " >= ", to_enclosed_unpacked_expression(op1), ");");
 		}
 		break;
 	}
@@ -8274,10 +8282,10 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		auto &type = get<SPIRType>(result_type);
 		emit_uninitialized_temporary_expression(result_type, result_id);
 
-		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
-		          to_enclosed_expression(op1), ";");
-		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
-		          to_expression(op1), ");");
+		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
+				  to_enclosed_unpacked_expression(op0), " * ", to_enclosed_unpacked_expression(op1), ";");
+		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(",
+				  to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ");");
 		break;
 	}
 
@@ -8332,8 +8340,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t id = ops[1];
 		uint32_t a = ops[2], b = ops[3];
 		bool forward = should_forward(a) && should_forward(b);
-		emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
-		        forward);
+		emit_op(result_type, id, join("int(short(", to_unpacked_expression(a), ")) * int(short(", to_unpacked_expression(b), "))"), forward);
 		inherit_expression_dependencies(id, a);
 		inherit_expression_dependencies(id, b);
 		break;
@@ -8345,8 +8352,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 		uint32_t id = ops[1];
 		uint32_t a = ops[2], b = ops[3];
 		bool forward = should_forward(a) && should_forward(b);
-		emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
-		        forward);
+		emit_op(result_type, id, join("uint(ushort(", to_unpacked_expression(a), ")) * uint(ushort(", to_unpacked_expression(b), "))"), forward);
 		inherit_expression_dependencies(id, a);
 		inherit_expression_dependencies(id, b);
 		break;
@@ -8373,6 +8379,98 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
 			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
 		break; // Nothing to do in the body
 
+	case OpConvertUToAccelerationStructureKHR:
+		SPIRV_CROSS_THROW("ConvertUToAccelerationStructure is not supported in MSL.");
+	case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
+		SPIRV_CROSS_THROW("BindingTableRecordOffset is not supported in MSL.");
+
+	case OpRayQueryInitializeKHR:
+	{
+		flush_variable_declaration(ops[0]);
+
+		statement(to_expression(ops[0]), ".reset(", "ray(", to_expression(ops[4]), ", ", to_expression(ops[6]), ", ",
+		          to_expression(ops[5]), ", ", to_expression(ops[7]), "), ", to_expression(ops[1]),
+		          ", intersection_params());");
+		break;
+	}
+	case OpRayQueryProceedKHR:
+	{
+		flush_variable_declaration(ops[0]);
+		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".next()"), false);
+		break;
+	}
+#define MSL_RAY_QUERY_IS_CANDIDATE get<SPIRConstant>(ops[3]).scalar_i32() == 0
+
+#define MSL_RAY_QUERY_GET_OP(op, msl_op)                                                   \
+	case OpRayQueryGet##op##KHR:                                                           \
+		flush_variable_declaration(ops[2]);                                                \
+		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_" #msl_op "()"), false); \
+		break
+
+#define MSL_RAY_QUERY_OP_INNER2(op, msl_prefix, msl_op)                                                          \
+	case OpRayQueryGet##op##KHR:                                                                                 \
+		flush_variable_declaration(ops[2]);                                                                      \
+		if (MSL_RAY_QUERY_IS_CANDIDATE)                                                                          \
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_candidate_" #msl_op "()"), false); \
+		else                                                                                                     \
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_committed_" #msl_op "()"), false); \
+		break
+
+#define MSL_RAY_QUERY_GET_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .get, msl_op)
+#define MSL_RAY_QUERY_IS_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .is, msl_op)
+
+		MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance);
+		MSL_RAY_QUERY_GET_OP(WorldRayOrigin, world_space_ray_direction);
+		MSL_RAY_QUERY_GET_OP(WorldRayDirection, world_space_ray_origin);
+		MSL_RAY_QUERY_GET_OP2(IntersectionInstanceId, instance_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex, user_instance_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics, triangle_barycentric_coord);
+		MSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex, primitive_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex, geometry_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin, ray_origin);
+		MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection, ray_direction);
+		MSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld, object_to_world_transform);
+		MSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject, world_to_object_transform);
+		MSL_RAY_QUERY_IS_OP2(IntersectionFrontFace, triangle_front_facing);
+
+	case OpRayQueryGetIntersectionTypeKHR:
+		flush_variable_declaration(ops[2]);
+		if (MSL_RAY_QUERY_IS_CANDIDATE)
+			emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_candidate_intersection_type()) - 1"),
+			        false);
+		else
+			emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_committed_intersection_type())"), false);
+		break;
+	case OpRayQueryGetIntersectionTKHR:
+		flush_variable_declaration(ops[2]);
+		if (MSL_RAY_QUERY_IS_CANDIDATE)
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_candidate_triangle_distance()"), false);
+		else
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_committed_distance()"), false);
+		break;
+	case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
+	{
+		flush_variable_declaration(ops[0]);
+		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".is_candidate_non_opaque_bounding_box()"), false);
+		break;
+	}
+	case OpRayQueryConfirmIntersectionKHR:
+		flush_variable_declaration(ops[0]);
+		statement(to_expression(ops[0]), ".commit_triangle_intersection();");
+		break;
+	case OpRayQueryGenerateIntersectionKHR:
+		flush_variable_declaration(ops[0]);
+		statement(to_expression(ops[0]), ".commit_bounding_box_intersection(", to_expression(ops[1]), ");");
+		break;
+	case OpRayQueryTerminateKHR:
+		flush_variable_declaration(ops[0]);
+		statement(to_expression(ops[0]), ".abort();");
+		break;
+#undef MSL_RAY_QUERY_GET_OP
+#undef MSL_RAY_QUERY_IS_CANDIDATE
+#undef MSL_RAY_QUERY_IS_OP2
+#undef MSL_RAY_QUERY_GET_OP2
+#undef MSL_RAY_QUERY_OP_INNER2
 	default:
 		CompilerGLSL::emit_instruction(instruction);
 		break;
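Note: for orientation, a rough sketch of the MSL the new ray-query cases emit for a minimal query; identifiers such as rq, tlas, origin, direction, t_min and t_max stand in for the translated SPIR-V ids and are hypothetical:

	intersection_query<instancing, triangle_data> rq;       // OpVariable of OpTypeRayQueryKHR
	rq.reset(ray(origin, direction, t_min, t_max), tlas,
	         intersection_params());                        // OpRayQueryInitializeKHR
	bool running = rq.next();                                // OpRayQueryProceedKHR
	uint hit = uint(rq.get_committed_intersection_type());   // OpRayQueryGetIntersectionTypeKHR (committed path)
	rq.abort();                                              // OpRayQueryTerminateKHR

The candidate path of OpRayQueryGetIntersectionTypeKHR additionally subtracts 1 to match the SPIR-V candidate enumeration, as the case above shows.
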
@@ -11295,6 +11393,12 @@ bool CompilerMSL::is_sample_rate() const
 	        (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
 }
 
+bool CompilerMSL::is_intersection_query() const
+{
+	auto &caps = get_declared_capabilities();
+	return std::find(caps.begin(), caps.end(), CapabilityRayQueryKHR) != caps.end();
+}
+
 void CompilerMSL::entry_point_args_builtin(string &ep_args)
 {
 	// Builtin variables
@@ -11773,6 +11877,10 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
 			}
 			break;
 		}
+		case SPIRType::AccelerationStructure:
+			ep_args += ", " + type_to_glsl(type, var_id) + " " + r.name;
+			ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
+			break;
 		default:
 			if (!ep_args.empty())
 				ep_args += ", ";
@@ -13126,9 +13234,6 @@ void CompilerMSL::sync_entry_point_aliases_and_names()
 
 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
 {
-	if (index < uint32_t(type.member_type_index_redirection.size()))
-		index = type.member_type_index_redirection[index];
-
 	auto *var = maybe_get<SPIRVariable>(base);
 	// If this is a buffer array, we have to dereference the buffer pointers.
 	// Otherwise, if this is a pointer expression, dereference it.
@@ -13243,8 +13348,23 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
 
 	// Scalars
 	case SPIRType::Boolean:
-		type_name = "bool";
+	{
+		auto *var = maybe_get_backing_variable(id);
+		if (var && var->basevariable)
+			var = &get<SPIRVariable>(var->basevariable);
+
+		// Need to special-case threadgroup booleans. They are supposed to be logical
+		// storage, but MSL compilers will sometimes crash if you use threadgroup bool.
+		// Workaround this by using 16-bit types instead and fixup on load-store to this data.
+		// FIXME: We have no sane way of working around this problem if a struct member is boolean
+		// and that struct is used as a threadgroup variable, but ... sigh.
+		if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup)
+			type_name = "short";
+		else
+			type_name = "bool";
 		break;
+	}
+
 	case SPIRType::Char:
 	case SPIRType::SByte:
 		type_name = "char";
@@ -13283,6 +13403,16 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
 	case SPIRType::Double:
 		type_name = "double"; // Currently unsupported
 		break;
+	case SPIRType::AccelerationStructure:
+		if (msl_options.supports_msl_version(2, 4))
+			type_name = "acceleration_structure<instancing>";
+		else if (msl_options.supports_msl_version(2, 3))
+			type_name = "instance_acceleration_structure";
+		else
+			SPIRV_CROSS_THROW("Acceleration Structure Type is supported in MSL 2.3 and above.");
+		break;
+	case SPIRType::RayQuery:
+		return "intersection_query<instancing, triangle_data>";
 
 	default:
 		return "unknown_type";
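Note: combined with the SPIRType::AccelerationStructure case added to entry_point_args_discrete_descriptors() earlier in this file, a bound acceleration structure becomes an ordinary buffer-indexed entry-point argument. A hedged sketch of the signature fragment this produces (resource name and buffer index are hypothetical):

	// MSL 2.3:  instance_acceleration_structure tlas [[buffer(0)]]
	// MSL 2.4:  acceleration_structure<instancing> tlas [[buffer(0)]]
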
@@ -13327,6 +13457,7 @@ string CompilerMSL::type_to_array_glsl(const SPIRType &type)
 	{
 	case SPIRType::AtomicCounter:
 	case SPIRType::ControlPointArray:
+	case SPIRType::RayQuery:
 	{
 		return CompilerGLSL::type_to_array_glsl(type);
 	}
@@ -15240,11 +15371,13 @@ void CompilerMSL::MemberSorter::sort()
 		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
 	}
 
+	// If we're sorting by Offset, this might affect user code which accesses a buffer block.
+	// We will need to redirect member indices from defined index to sorted index using reverse lookup.
 	if (sort_aspect == SortAspect::Offset)
 	{
-		// If we're sorting by Offset, this might affect user code which accesses a buffer block.
-		// We will need to redirect member indices from one index to sorted index.
-		type.member_type_index_redirection = std::move(mbr_idxs);
+		type.member_type_index_redirection.resize(mbr_cnt);
+		for (uint32_t map_idx = 0; map_idx < mbr_cnt; map_idx++)
+			type.member_type_index_redirection[mbr_idxs[map_idx]] = map_idx;
 	}
 }
 
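Note: a small worked example of the reverse lookup built here, using a hypothetical member layout (not taken from the commit):

	// Three members declared in SPIR-V order 0, 1, 2 with offsets 8, 0, 4.
	// Sorting by Offset yields mbr_idxs = { 1, 2, 0 }  (sorted position -> original index),
	// so the loop above produces
	//     member_type_index_redirection = { 2, 0, 1 }  (original index -> sorted position),
	// which is the mapping access_chain_internal() in spirv_glsl.cpp now applies before
	// indexing member_types, replacing the lookup removed from to_member_reference().
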
@@ -15294,12 +15427,16 @@ void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t
 	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
 }
 
-void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
+void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
 {
 	auto *var = maybe_get_backing_variable(source_id);
 	if (var)
 		source_id = var->self;
 
+	// Type fixups for workgroup variables if they are booleans.
+	if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
+		expr = join(type_to_glsl(expr_type), "(", expr, ")");
+
 	// Only interested in standalone builtin variables.
 	if (!has_decoration(source_id, DecorationBuiltIn))
 		return;
@@ -15386,12 +15523,20 @@ void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr,
 	}
 }
 
-void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
 {
 	auto *var = maybe_get_backing_variable(target_id);
 	if (var)
 		target_id = var->self;
 
+	// Type fixups for workgroup variables if they are booleans.
+	if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
+	{
+		auto short_type = expr_type;
+		short_type.basetype = SPIRType::Short;
+		expr = join(type_to_glsl(short_type), "(", expr, ")");
+	}
+
 	// Only interested in standalone builtin variables.
 	if (!has_decoration(target_id, DecorationBuiltIn))
 		return;
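Note: together with the SPIRType::Boolean change in type_to_glsl() above, the workgroup-boolean workaround roughly yields MSL like the following sketch (variable names are hypothetical):

	threadgroup short flag;       // a SPIR-V Workgroup-storage bool is declared as short
	flag = short(predicate);      // stores get wrapped by cast_to_variable_store()
	bool b = bool(flag);          // loads are cast back by cast_from_variable_load()
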

+ 3 - 2
3rdparty/spirv-cross/spirv_msl.hpp

@@ -864,6 +864,7 @@ protected:
 	std::string to_swizzle_expression(uint32_t id);
 	std::string to_buffer_size_expression(uint32_t id);
 	bool is_sample_rate() const;
+	bool is_intersection_query() const;
 	bool is_direct_input_builtin(spv::BuiltIn builtin);
 	std::string builtin_qualifier(spv::BuiltIn builtin);
 	std::string builtin_type_decl(spv::BuiltIn builtin, uint32_t id = 0);
@@ -959,8 +960,8 @@ protected:
 
 	bool does_shader_write_sample_mask = false;
 
-	void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
-	void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
+	void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
+	void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
 	void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) override;
 
 	void analyze_sampled_image_usage();