Przeglądaj źródła

Updated spirv-cross.

Бранимир Караџић 5 lat temu
rodzic
commit
847d79dc11

+ 143 - 1
3rdparty/spirv-cross/spirv_cross.cpp

@@ -1652,6 +1652,148 @@ size_t Compiler::get_declared_struct_size_runtime_array(const SPIRType &type, si
 	return size;
 	return size;
 }
 }
 
 
+uint32_t Compiler::evaluate_spec_constant_u32(const SPIRConstantOp &spec) const
+{
+	auto &result_type = get<SPIRType>(spec.basetype);
+	if (result_type.basetype != SPIRType::UInt && result_type.basetype != SPIRType::Int && result_type.basetype != SPIRType::Boolean)
+		SPIRV_CROSS_THROW("Only 32-bit integers and booleans are currently supported when evaluating specialization constants.\n");
+	if (!is_scalar(result_type))
+		SPIRV_CROSS_THROW("Spec constant evaluation must be a scalar.\n");
+
+	uint32_t value = 0;
+
+	const auto eval_u32 = [&](uint32_t id) -> uint32_t {
+		auto &type = expression_type(id);
+		if (type.basetype != SPIRType::UInt && type.basetype != SPIRType::Int && type.basetype != SPIRType::Boolean)
+			SPIRV_CROSS_THROW("Only 32-bit integers and booleans are currently supported when evaluating specialization constants.\n");
+		if (!is_scalar(type))
+			SPIRV_CROSS_THROW("Spec constant evaluation must be a scalar.\n");
+		if (const auto *c = this->maybe_get<SPIRConstant>(id))
+			return c->scalar();
+		else
+			return evaluate_spec_constant_u32(this->get<SPIRConstantOp>(id));
+	};
+
+#define binary_spec_op(op, binary_op) \
+	case Op##op: value = eval_u32(spec.arguments[0]) binary_op eval_u32(spec.arguments[1]); break
+#define binary_spec_op_cast(op, binary_op, type) \
+	case Op##op: value = uint32_t(type(eval_u32(spec.arguments[0])) binary_op type(eval_u32(spec.arguments[1]))); break
+
+	// Support the basic opcodes which are typically used when computing array sizes.
+	switch (spec.opcode)
+	{
+	binary_spec_op(IAdd, +);
+	binary_spec_op(ISub, -);
+	binary_spec_op(IMul, *);
+	binary_spec_op(BitwiseAnd, &);
+	binary_spec_op(BitwiseOr, |);
+	binary_spec_op(BitwiseXor, ^);
+	binary_spec_op(LogicalAnd, &);
+	binary_spec_op(LogicalOr, |);
+	binary_spec_op(ShiftLeftLogical, <<);
+	binary_spec_op(ShiftRightLogical, >>);
+	binary_spec_op_cast(ShiftRightArithmetic, >>, int32_t);
+	binary_spec_op(LogicalEqual, ==);
+	binary_spec_op(LogicalNotEqual, !=);
+	binary_spec_op(IEqual, ==);
+	binary_spec_op(INotEqual, !=);
+	binary_spec_op(ULessThan, <);
+	binary_spec_op(ULessThanEqual, <=);
+	binary_spec_op(UGreaterThan, >);
+	binary_spec_op(UGreaterThanEqual, >=);
+	binary_spec_op_cast(SLessThan, <, int32_t);
+	binary_spec_op_cast(SLessThanEqual, <=, int32_t);
+	binary_spec_op_cast(SGreaterThan, >, int32_t);
+	binary_spec_op_cast(SGreaterThanEqual, >=, int32_t);
+#undef binary_spec_op
+#undef binary_spec_op_cast
+
+	case OpLogicalNot:
+		value = uint32_t(!eval_u32(spec.arguments[0]));
+		break;
+
+	case OpNot:
+		value = ~eval_u32(spec.arguments[0]);
+		break;
+
+	case OpSNegate:
+		value = -eval_u32(spec.arguments[0]);
+		break;
+
+	case OpSelect:
+		value = eval_u32(spec.arguments[0]) ? eval_u32(spec.arguments[1]) : eval_u32(spec.arguments[2]);
+		break;
+
+	case OpUMod:
+	{
+		uint32_t a = eval_u32(spec.arguments[0]);
+		uint32_t b = eval_u32(spec.arguments[1]);
+		if (b == 0)
+			SPIRV_CROSS_THROW("Undefined behavior in UMod, b == 0.\n");
+		value = a % b;
+		break;
+	}
+
+	case OpSRem:
+	{
+		auto a = int32_t(eval_u32(spec.arguments[0]));
+		auto b = int32_t(eval_u32(spec.arguments[1]));
+		if (b == 0)
+			SPIRV_CROSS_THROW("Undefined behavior in SRem, b == 0.\n");
+		value = a % b;
+		break;
+	}
+
+	case OpSMod:
+	{
+		auto a = int32_t(eval_u32(spec.arguments[0]));
+		auto b = int32_t(eval_u32(spec.arguments[1]));
+		if (b == 0)
+			SPIRV_CROSS_THROW("Undefined behavior in SMod, b == 0.\n");
+		auto v = a % b;
+
+		// Makes sure we match the sign of b, not a.
+		if ((b < 0 && v > 0) || (b > 0 && v < 0))
+			v += b;
+		value = v;
+		break;
+	}
+
+	case OpUDiv:
+	{
+		uint32_t a = eval_u32(spec.arguments[0]);
+		uint32_t b = eval_u32(spec.arguments[1]);
+		if (b == 0)
+			SPIRV_CROSS_THROW("Undefined behavior in UDiv, b == 0.\n");
+		value = a / b;
+		break;
+	}
+
+	case OpSDiv:
+	{
+		auto a = int32_t(eval_u32(spec.arguments[0]));
+		auto b = int32_t(eval_u32(spec.arguments[1]));
+		if (b == 0)
+			SPIRV_CROSS_THROW("Undefined behavior in SDiv, b == 0.\n");
+		value = a / b;
+		break;
+	}
+
+	default:
+		SPIRV_CROSS_THROW("Unsupported spec constant opcode for evaluation.\n");
+	}
+
+	return value;
+}
+
+uint32_t Compiler::evaluate_constant_u32(uint32_t id) const
+{
+	if (const auto *c = maybe_get<SPIRConstant>(id))
+		return c->scalar();
+	else
+		return evaluate_spec_constant_u32(get<SPIRConstantOp>(id));
+}
+
 size_t Compiler::get_declared_struct_member_size(const SPIRType &struct_type, uint32_t index) const
 size_t Compiler::get_declared_struct_member_size(const SPIRType &struct_type, uint32_t index) const
 {
 {
 	if (struct_type.member_types.empty())
 	if (struct_type.member_types.empty())
@@ -1686,7 +1828,7 @@ size_t Compiler::get_declared_struct_member_size(const SPIRType &struct_type, ui
 	{
 	{
 		// For arrays, we can use ArrayStride to get an easy check.
 		// For arrays, we can use ArrayStride to get an easy check.
 		bool array_size_literal = type.array_size_literal.back();
 		bool array_size_literal = type.array_size_literal.back();
-		uint32_t array_size = array_size_literal ? type.array.back() : get<SPIRConstant>(type.array.back()).scalar();
+		uint32_t array_size = array_size_literal ? type.array.back() : evaluate_constant_u32(type.array.back());
 		return type_struct_member_array_stride(struct_type, index) * array_size;
 		return type_struct_member_array_stride(struct_type, index) * array_size;
 	}
 	}
 	else if (type.basetype == SPIRType::Struct)
 	else if (type.basetype == SPIRType::Struct)

+ 3 - 0
3rdparty/spirv-cross/spirv_cross.hpp

@@ -1060,6 +1060,9 @@ protected:
 
 
 	bool flush_phi_required(BlockID from, BlockID to) const;
 	bool flush_phi_required(BlockID from, BlockID to) const;
 
 
+	uint32_t evaluate_spec_constant_u32(const SPIRConstantOp &spec) const;
+	uint32_t evaluate_constant_u32(uint32_t id) const;
+
 private:
 private:
 	// Used only to implement the old deprecated get_entry_point() interface.
 	// Used only to implement the old deprecated get_entry_point() interface.
 	const SPIREntryPoint &get_first_entry_point(const std::string &name) const;
 	const SPIREntryPoint &get_first_entry_point(const std::string &name) const;

+ 24 - 59
3rdparty/spirv-cross/spirv_glsl.cpp

@@ -6755,7 +6755,7 @@ void CompilerGLSL::emit_subgroup_op(const Instruction &i)
 	uint32_t result_type = ops[0];
 	uint32_t result_type = ops[0];
 	uint32_t id = ops[1];
 	uint32_t id = ops[1];
 
 
-	auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
+	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
 	if (scope != ScopeSubgroup)
 	if (scope != ScopeSubgroup)
 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
 
 
@@ -6889,7 +6889,7 @@ case OpGroupNonUniform##op: \
 
 
 	case OpGroupNonUniformQuadSwap:
 	case OpGroupNonUniformQuadSwap:
 	{
 	{
-		uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
+		uint32_t direction = evaluate_constant_u32(ops[4]);
 		if (direction == 0)
 		if (direction == 0)
 			emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapHorizontal");
 			emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapHorizontal");
 		else if (direction == 1)
 		else if (direction == 1)
@@ -7635,7 +7635,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
 		else if (type->basetype == SPIRType::Struct)
 		else if (type->basetype == SPIRType::Struct)
 		{
 		{
 			if (!is_literal)
 			if (!is_literal)
-				index = get<SPIRConstant>(index).scalar();
+				index = evaluate_constant_u32(index);
 
 
 			if (index >= type->member_types.size())
 			if (index >= type->member_types.size())
 				SPIRV_CROSS_THROW("Member index is out of bounds!");
 				SPIRV_CROSS_THROW("Member index is out of bounds!");
@@ -8156,7 +8156,7 @@ std::pair<std::string, uint32_t> CompilerGLSL::flattened_access_chain_offset(
 		// We also check if this member is a builtin, since we then replace the entire expression with the builtin one.
 		// We also check if this member is a builtin, since we then replace the entire expression with the builtin one.
 		else if (type->basetype == SPIRType::Struct)
 		else if (type->basetype == SPIRType::Struct)
 		{
 		{
-			index = get<SPIRConstant>(index).scalar();
+			index = evaluate_constant_u32(index);
 
 
 			if (index >= type->member_types.size())
 			if (index >= type->member_types.size())
 				SPIRV_CROSS_THROW("Member index is out of bounds!");
 				SPIRV_CROSS_THROW("Member index is out of bounds!");
@@ -8184,7 +8184,7 @@ std::pair<std::string, uint32_t> CompilerGLSL::flattened_access_chain_offset(
 			auto *constant = maybe_get<SPIRConstant>(index);
 			auto *constant = maybe_get<SPIRConstant>(index);
 			if (constant)
 			if (constant)
 			{
 			{
-				index = get<SPIRConstant>(index).scalar();
+				index = evaluate_constant_u32(index);
 				offset += index * (row_major_matrix_needs_conversion ? (type->width / 8) : matrix_stride);
 				offset += index * (row_major_matrix_needs_conversion ? (type->width / 8) : matrix_stride);
 			}
 			}
 			else
 			else
@@ -8213,7 +8213,7 @@ std::pair<std::string, uint32_t> CompilerGLSL::flattened_access_chain_offset(
 			auto *constant = maybe_get<SPIRConstant>(index);
 			auto *constant = maybe_get<SPIRConstant>(index);
 			if (constant)
 			if (constant)
 			{
 			{
-				index = get<SPIRConstant>(index).scalar();
+				index = evaluate_constant_u32(index);
 				offset += index * (row_major_matrix_needs_conversion ? matrix_stride : (type->width / 8));
 				offset += index * (row_major_matrix_needs_conversion ? matrix_stride : (type->width / 8));
 			}
 			}
 			else
 			else
@@ -10805,14 +10805,14 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 
 
 		if (opcode == OpMemoryBarrier)
 		if (opcode == OpMemoryBarrier)
 		{
 		{
-			memory = get<SPIRConstant>(ops[0]).scalar();
-			semantics = get<SPIRConstant>(ops[1]).scalar();
+			memory = evaluate_constant_u32(ops[0]);
+			semantics = evaluate_constant_u32(ops[1]);
 		}
 		}
 		else
 		else
 		{
 		{
-			execution_scope = get<SPIRConstant>(ops[0]).scalar();
-			memory = get<SPIRConstant>(ops[1]).scalar();
-			semantics = get<SPIRConstant>(ops[2]).scalar();
+			execution_scope = evaluate_constant_u32(ops[0]);
+			memory = evaluate_constant_u32(ops[1]);
+			semantics = evaluate_constant_u32(ops[2]);
 		}
 		}
 
 
 		if (execution_scope == ScopeSubgroup || memory == ScopeSubgroup)
 		if (execution_scope == ScopeSubgroup || memory == ScopeSubgroup)
@@ -10841,8 +10841,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
 			if (next && next->op == OpControlBarrier)
 			if (next && next->op == OpControlBarrier)
 			{
 			{
 				auto *next_ops = stream(*next);
 				auto *next_ops = stream(*next);
-				uint32_t next_memory = get<SPIRConstant>(next_ops[1]).scalar();
-				uint32_t next_semantics = get<SPIRConstant>(next_ops[2]).scalar();
+				uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
+				uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
 				next_semantics = mask_relevant_memory_semantics(next_semantics);
 				next_semantics = mask_relevant_memory_semantics(next_semantics);
 
 
 				bool memory_scope_covered = false;
 				bool memory_scope_covered = false;
@@ -11795,15 +11795,7 @@ uint32_t CompilerGLSL::to_array_size_literal(const SPIRType &type, uint32_t inde
 	{
 	{
 		// Use the default spec constant value.
 		// Use the default spec constant value.
 		// This is the best we can do.
 		// This is the best we can do.
-		uint32_t array_size_id = type.array[index];
-
-		// Explicitly check for this case. The error message you would get (bad cast) makes no sense otherwise.
-		if (ir.ids[array_size_id].get_type() == TypeConstantOp)
-			SPIRV_CROSS_THROW("An array size was found to be an OpSpecConstantOp. This is not supported since "
-			                  "SPIRV-Cross cannot deduce the actual size here.");
-
-		uint32_t array_size = get<SPIRConstant>(array_size_id).scalar();
-		return array_size;
+		return evaluate_constant_u32(type.array[index]);
 	}
 	}
 }
 }
 
 
@@ -12740,64 +12732,37 @@ void CompilerGLSL::branch(BlockID from, uint32_t cond, BlockID true_block, Block
 	auto &from_block = get<SPIRBlock>(from);
 	auto &from_block = get<SPIRBlock>(from);
 	BlockID merge_block = from_block.merge == SPIRBlock::MergeSelection ? from_block.next_block : BlockID(0);
 	BlockID merge_block = from_block.merge == SPIRBlock::MergeSelection ? from_block.next_block : BlockID(0);
 
 
-	// If we branch directly to a selection merge target, we don't need a code path.
-	// This covers both merge out of if () / else () as well as a break for switch blocks.
-	bool true_sub = !is_conditional(true_block);
-	bool false_sub = !is_conditional(false_block);
+	// If we branch directly to our selection merge target, we don't need a code path.
+	bool true_block_needs_code = true_block != merge_block || flush_phi_required(from, true_block);
+	bool false_block_needs_code = false_block != merge_block || flush_phi_required(from, false_block);
 
 
-	bool true_block_is_selection_merge = true_block == merge_block;
-	bool false_block_is_selection_merge = false_block == merge_block;
+	if (!true_block_needs_code && !false_block_needs_code)
+		return;
 
 
-	if (true_sub)
+	emit_block_hints(get<SPIRBlock>(from));
+
+	if (true_block_needs_code)
 	{
 	{
-		emit_block_hints(get<SPIRBlock>(from));
 		statement("if (", to_expression(cond), ")");
 		statement("if (", to_expression(cond), ")");
 		begin_scope();
 		begin_scope();
 		branch(from, true_block);
 		branch(from, true_block);
 		end_scope();
 		end_scope();
 
 
-		// If we merge to continue, we handle that explicitly in emit_block_chain(),
-		// so there is no need to branch to it directly here.
-		// break; is required to handle ladder fallthrough cases, so keep that in for now, even
-		// if we could potentially handle it in emit_block_chain().
-		if (false_sub || (!false_block_is_selection_merge && is_continue(false_block)) || is_break(false_block))
+		if (false_block_needs_code)
 		{
 		{
 			statement("else");
 			statement("else");
 			begin_scope();
 			begin_scope();
 			branch(from, false_block);
 			branch(from, false_block);
 			end_scope();
 			end_scope();
 		}
 		}
-		else if (flush_phi_required(from, false_block))
-		{
-			statement("else");
-			begin_scope();
-			flush_phi(from, false_block);
-			end_scope();
-		}
 	}
 	}
-	else if (false_sub)
+	else if (false_block_needs_code)
 	{
 	{
 		// Only need false path, use negative conditional.
 		// Only need false path, use negative conditional.
-		emit_block_hints(get<SPIRBlock>(from));
 		statement("if (!", to_enclosed_expression(cond), ")");
 		statement("if (!", to_enclosed_expression(cond), ")");
 		begin_scope();
 		begin_scope();
 		branch(from, false_block);
 		branch(from, false_block);
 		end_scope();
 		end_scope();
-
-		if ((!true_block_is_selection_merge && is_continue(true_block)) || is_break(true_block))
-		{
-			statement("else");
-			begin_scope();
-			branch(from, true_block);
-			end_scope();
-		}
-		else if (flush_phi_required(from, true_block))
-		{
-			statement("else");
-			begin_scope();
-			flush_phi(from, true_block);
-			end_scope();
-		}
 	}
 	}
 }
 }
 
 

+ 10 - 10
3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -790,7 +790,7 @@ uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
 			if (type.array_size_literal[i])
 			if (type.array_size_literal[i])
 				array_multiplier *= type.array[i];
 				array_multiplier *= type.array[i];
 			else
 			else
-				array_multiplier *= get<SPIRConstant>(type.array[i]).scalar();
+				array_multiplier *= evaluate_constant_u32(type.array[i]);
 		}
 		}
 		elements += array_multiplier * type.columns;
 		elements += array_multiplier * type.columns;
 	}
 	}
@@ -2860,7 +2860,7 @@ void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
 			}
 			}
 			else if (gather)
 			else if (gather)
 			{
 			{
-				uint32_t comp_num = get<SPIRConstant>(comp).scalar();
+				uint32_t comp_num = evaluate_constant_u32(comp);
 				if (hlsl_options.shader_model >= 50)
 				if (hlsl_options.shader_model >= 50)
 				{
 				{
 					switch (comp_num)
 					switch (comp_num)
@@ -4454,7 +4454,7 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
 	uint32_t result_type = ops[0];
 	uint32_t result_type = ops[0];
 	uint32_t id = ops[1];
 	uint32_t id = ops[1];
 
 
-	auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
+	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
 	if (scope != ScopeSubgroup)
 	if (scope != ScopeSubgroup)
 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
 
 
@@ -4611,7 +4611,7 @@ case OpGroupNonUniform##op: \
 
 
 	case OpGroupNonUniformQuadSwap:
 	case OpGroupNonUniformQuadSwap:
 	{
 	{
-		uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
+		uint32_t direction = evaluate_constant_u32(ops[4]);
 		if (direction == 0)
 		if (direction == 0)
 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
 		else if (direction == 1)
 		else if (direction == 1)
@@ -5269,13 +5269,13 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 
 
 		if (opcode == OpMemoryBarrier)
 		if (opcode == OpMemoryBarrier)
 		{
 		{
-			memory = get<SPIRConstant>(ops[0]).scalar();
-			semantics = get<SPIRConstant>(ops[1]).scalar();
+			memory = evaluate_constant_u32(ops[0]);
+			semantics = evaluate_constant_u32(ops[1]);
 		}
 		}
 		else
 		else
 		{
 		{
-			memory = get<SPIRConstant>(ops[1]).scalar();
-			semantics = get<SPIRConstant>(ops[2]).scalar();
+			memory = evaluate_constant_u32(ops[1]);
+			semantics = evaluate_constant_u32(ops[2]);
 		}
 		}
 
 
 		if (memory == ScopeSubgroup)
 		if (memory == ScopeSubgroup)
@@ -5295,8 +5295,8 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 			if (next && next->op == OpControlBarrier)
 			if (next && next->op == OpControlBarrier)
 			{
 			{
 				auto *next_ops = stream(*next);
 				auto *next_ops = stream(*next);
-				uint32_t next_memory = get<SPIRConstant>(next_ops[1]).scalar();
-				uint32_t next_semantics = get<SPIRConstant>(next_ops[2]).scalar();
+				uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
+				uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
 				next_semantics = mask_relevant_memory_semantics(next_semantics);
 				next_semantics = mask_relevant_memory_semantics(next_semantics);
 
 
 				// There is no "just execution barrier" in HLSL.
 				// There is no "just execution barrier" in HLSL.

+ 9 - 15
3rdparty/spirv-cross/spirv_msl.cpp

@@ -7175,8 +7175,8 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
 	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
 	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
 		return;
 		return;
 
 
-	uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
-	uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
+	uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
+	uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
 	// Use the wider of the two scopes (smaller value)
 	// Use the wider of the two scopes (smaller value)
 	exe_scope = min(exe_scope, mem_scope);
 	exe_scope = min(exe_scope, mem_scope);
 
 
@@ -7187,7 +7187,7 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
 		bar_stmt = "threadgroup_barrier";
 		bar_stmt = "threadgroup_barrier";
 	bar_stmt += "(";
 	bar_stmt += "(";
 
 
-	uint32_t mem_sem = id_mem_sem ? get<SPIRConstant>(id_mem_sem).scalar() : uint32_t(MemorySemanticsMaskNone);
+	uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
 
 
 	// Use the | operator to combine flags if we can.
 	// Use the | operator to combine flags if we can.
 	if (msl_options.supports_msl_version(1, 2))
 	if (msl_options.supports_msl_version(1, 2))
@@ -8534,13 +8534,7 @@ string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
 // The ID must be a scalar constant.
 // The ID must be a scalar constant.
 string CompilerMSL::to_component_argument(uint32_t id)
 string CompilerMSL::to_component_argument(uint32_t id)
 {
 {
-	if (ir.ids[id].get_type() != TypeConstant)
-	{
-		SPIRV_CROSS_THROW("ID " + to_string(id) + " is not an OpConstant.");
-		return "component::x";
-	}
-
-	uint32_t component_index = get<SPIRConstant>(id).scalar();
+	uint32_t component_index = evaluate_constant_u32(id);
 	switch (component_index)
 	switch (component_index)
 	{
 	{
 	case 0:
 	case 0:
@@ -11820,7 +11814,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
 	uint32_t result_type = ops[0];
 	uint32_t result_type = ops[0];
 	uint32_t id = ops[1];
 	uint32_t id = ops[1];
 
 
-	auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
+	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
 	if (scope != ScopeSubgroup)
 	if (scope != ScopeSubgroup)
 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
 
 
@@ -11920,7 +11914,7 @@ case OpGroupNonUniform##op: \
 		else if (operation == GroupOperationClusteredReduce) \
 		else if (operation == GroupOperationClusteredReduce) \
 		{ \
 		{ \
 			/* Only cluster sizes of 4 are supported. */ \
 			/* Only cluster sizes of 4 are supported. */ \
-			uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
+			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
 			if (cluster_size != 4) \
 			if (cluster_size != 4) \
 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
@@ -11949,7 +11943,7 @@ case OpGroupNonUniform##op: \
 		else if (operation == GroupOperationClusteredReduce) \
 		else if (operation == GroupOperationClusteredReduce) \
 		{ \
 		{ \
 			/* Only cluster sizes of 4 are supported. */ \
 			/* Only cluster sizes of 4 are supported. */ \
-			uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
+			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
 			if (cluster_size != 4) \
 			if (cluster_size != 4) \
 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
@@ -11972,7 +11966,7 @@ case OpGroupNonUniform##op: \
 		else if (operation == GroupOperationClusteredReduce) \
 		else if (operation == GroupOperationClusteredReduce) \
 		{ \
 		{ \
 			/* Only cluster sizes of 4 are supported. */ \
 			/* Only cluster sizes of 4 are supported. */ \
-			uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
+			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
 			if (cluster_size != 4) \
 			if (cluster_size != 4) \
 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
 			emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
 			emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
@@ -12010,7 +12004,7 @@ case OpGroupNonUniform##op: \
 		// n 2  | 3   0   1
 		// n 2  | 3   0   1
 		// e 3  | 2   1   0
 		// e 3  | 2   1   0
 		// Notice that target = source ^ (direction + 1).
 		// Notice that target = source ^ (direction + 1).
-		uint32_t mask = get<SPIRConstant>(ops[4]).scalar() + 1;
+		uint32_t mask = evaluate_constant_u32(ops[4]) + 1;
 		uint32_t mask_id = ir.increase_bound_by(1);
 		uint32_t mask_id = ir.increase_bound_by(1);
 		set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
 		set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
 		emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
 		emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");