ソースを参照

Add more shader reflection code

Panagiotis Christopoulos Charitos 6 年 前
コミット
752125e5f4

+ 78 - 3
src/anki/shader_compiler/ShaderProgramCompiler.cpp

@@ -152,6 +152,8 @@ private:
 
 	ANKI_USE_RESULT Error opaqueReflection(
 		const spirv_cross::Resource& res, DynamicArrayAuto<ShaderProgramBinaryOpaque>& opaques) const;
+
+	ANKI_USE_RESULT Error constsReflection(DynamicArrayAuto<ShaderProgramBinaryConstant>& consts) const;
 };
 
 Error SpirvReflector::blockReflection(
@@ -173,7 +175,7 @@ Error SpirvReflector::blockReflection(
 		const std::string name = (!res.name.empty()) ? res.name : get_fallback_name(fallbackId);
 		if(name.length() == 0 || name.length() > MAX_SHADER_BINARY_NAME_LENGTH)
 		{
-			ANKI_SHADER_COMPILER_LOGE("Too big of a name: %s", name.c_str());
+			ANKI_SHADER_COMPILER_LOGE("Wrong name length: %s", name.length() ? name.c_str() : " ");
 			return Error::USER_DATA;
 		}
 		memcpy(newBlock.m_name.getBegin(), name.c_str(), name.length() + 1);
@@ -284,7 +286,7 @@ Error SpirvReflector::opaqueReflection(
 	const std::string name = (!res.name.empty()) ? res.name : get_fallback_name(fallbackId);
 	if(name.length() == 0 || name.length() > MAX_SHADER_BINARY_NAME_LENGTH)
 	{
-		ANKI_SHADER_COMPILER_LOGE("Too big of a name: %s", name.c_str());
+		ANKI_SHADER_COMPILER_LOGE("Wrong name length: %s", name.length() ? name.c_str() : " ");
 		return Error::USER_DATA;
 	}
 	memcpy(newOpaque.m_name.getBegin(), name.c_str(), name.length() + 1);
@@ -350,6 +352,72 @@ Error SpirvReflector::opaqueReflection(
 	return Error::NONE;
 }
 
+Error SpirvReflector::constsReflection(DynamicArrayAuto<ShaderProgramBinaryConstant>& consts) const
+{
+	spirv_cross::SmallVector<spirv_cross::SpecializationConstant> specConsts = get_specialization_constants();
+	for(const spirv_cross::SpecializationConstant& c : specConsts)
+	{
+		ShaderProgramBinaryConstant newConst;
+
+		const spirv_cross::SPIRConstant cc = get<spirv_cross::SPIRConstant>(c.id);
+		const spirv_cross::SPIRType type = get<spirv_cross::SPIRType>(cc.constant_type);
+
+		const std::string name = get_name(c.id);
+		if(name.length() == 0 || name.length() > MAX_SHADER_BINARY_NAME_LENGTH)
+		{
+			ANKI_SHADER_COMPILER_LOGE("Wrong name length: %s", name.length() ? name.c_str() : " ");
+			return Error::USER_DATA;
+		}
+		memcpy(newConst.m_name.getBegin(), name.c_str(), name.length() + 1);
+
+		newConst.m_constantId = c.constant_id;
+
+		switch(type.basetype)
+		{
+		case spirv_cross::SPIRType::UInt:
+		case spirv_cross::SPIRType::Int:
+			newConst.m_type = ShaderVariableDataType::INT;
+			break;
+		case spirv_cross::SPIRType::Float:
+			newConst.m_type = ShaderVariableDataType::FLOAT;
+			break;
+		default:
+			ANKI_SHADER_COMPILER_LOGE("Can't determine the type of the spec constant: %s", name.c_str());
+			return Error::USER_DATA;
+		}
+
+		// Add it
+		Bool found = false;
+		for(const ShaderProgramBinaryConstant& other : consts)
+		{
+			const Bool nameSame = strcmp(other.m_name.getBegin(), newConst.m_name.getBegin()) == 0;
+			const Bool typeSame = other.m_type == newConst.m_type;
+			const Bool idSame = other.m_constantId == newConst.m_constantId;
+
+			const Bool err0 = nameSame && (!typeSame || !idSame);
+			const Bool err1 = idSame && (!nameSame || !typeSame);
+			if(err0 || err1)
+			{
+				ANKI_SHADER_COMPILER_LOGE("Linking error");
+				return Error::USER_DATA;
+			}
+
+			if(idSame)
+			{
+				found = true;
+				break;
+			}
+		}
+
+		if(!found)
+		{
+			consts.emplaceBack(newConst);
+		}
+	}
+
+	return Error::NONE;
+}
+
 Error SpirvReflector::performSpirvReflection(ShaderProgramBinaryReflection& refl,
 	Array<ConstWeakArray<U8, PtrSize>, U32(ShaderType::COUNT)> spirv,
 	GenericMemoryPoolAllocator<U8> tmpAlloc,
@@ -411,6 +479,9 @@ Error SpirvReflector::performSpirvReflection(ShaderProgramBinaryReflection& refl
 			ANKI_CHECK(compiler.opaqueReflection(res, opaques));
 		}
 
+		// Spec consts
+		ANKI_CHECK(compiler.constsReflection(specializationConstants));
+
 		// TODO
 	}
 
@@ -432,6 +503,10 @@ Error SpirvReflector::performSpirvReflection(ShaderProgramBinaryReflection& refl
 	opaques.moveAndReset(firstOpaque, size, storage);
 	refl.m_opaques.setArray(firstOpaque, size);
 
+	ShaderProgramBinaryConstant* firstConst;
+	specializationConstants.moveAndReset(firstConst, size, storage);
+	refl.m_specializationConstants.setArray(firstConst, size);
+
 	return Error::NONE;
 }
 
@@ -821,7 +896,7 @@ void disassembleShaderProgramBinary(const ShaderProgramBinary& binary, StringAut
 			lines.pushBackSprintf(ANKI_TAB ANKI_TAB "Specialization constants\n");
 			for(const ShaderProgramBinaryConstant& c : variant.m_reflection.m_specializationConstants)
 			{
-				lines.pushBackSprintf(ANKI_TAB ANKI_TAB ANKI_TAB "%16s type %4u id %4u\n",
+				lines.pushBackSprintf(ANKI_TAB ANKI_TAB ANKI_TAB "%-32s type %4u id %4u\n",
 					c.m_name.getBegin(),
 					U32(c.m_type),
 					c.m_constantId);

+ 2 - 0
src/anki/shader_compiler/ShaderProgramParser.cpp

@@ -79,6 +79,8 @@ static const char* SHADER_HEADER = R"(#version 450 core
 	layout(constant_id = id) const I32 ANKI_CONCATENATE(x, _0) = defltVal[0]; \
 	layout(constant_id = id + 1) const I32 ANKI_CONCATENATE(x, _1) = defltVal[1]; \
 	layout(constant_id = id + 2) const I32 ANKI_CONCATENATE(x, _2) = defltVal[2]
+
+#define ANKI_SPECIALIZATION_CONSTANT_F32(x, id, defltVal) layout(constant_id = id) const F32 x = defltVal
 )";
 
 ShaderProgramParser::ShaderProgramParser(CString fname,

+ 4 - 2
tests/shader_compiler/ShaderProgramCompiler.cpp

@@ -27,7 +27,7 @@ layout(set = 1, binding = 0) uniform u_
 #endif
 };
 
-layout(set = 1, binding = 1) uniform u2_
+layout(set = 1, binding = 1) buffer u2_
 {
 	Mat4 u_mvp2[INSTANCE_COUNT];
 #if PASS > 1
@@ -41,6 +41,8 @@ layout(set = 0, binding = 0) uniform texture2D u_tex[3];
 layout(set = 0, binding = 1) uniform sampler u_sampler;
 
 #pragma anki start vert
+ANKI_SPECIALIZATION_CONSTANT_F32(specConst, 1, 0);
+
 out gl_PerVertex
 {
 	Vec4 gl_Position;
@@ -48,7 +50,7 @@ out gl_PerVertex
 
 void main()
 {
-	gl_Position = u_mvp[gl_InstanceID] * u_mvp2[gl_InstanceID] * Vec4(gl_VertexID);
+	gl_Position = u_mvp[gl_InstanceID] * u_mvp2[gl_InstanceID] * Vec4(gl_VertexID) * Vec4(specConst);
 }
 #pragma anki end