Sfoglia il codice sorgente

Add specialization constant support in GL and Vulkan backends

Panagiotis Christopoulos Charitos 7 anni fa
parent
commit
47e4216d58

+ 1 - 0
src/anki/gr/Common.h

@@ -41,6 +41,7 @@ const U MAX_VERTEX_ATTRIBUTES = 8;
 const U MAX_COLOR_ATTACHMENTS = 4;
 const U MAX_MIPMAPS = 16;
 const U MAX_TEXTURE_LAYERS = 32;
+const U MAX_SPECIALIZED_CONSTS = 64;
 
 const U MAX_TEXTURE_BINDINGS = 8;
 const U MAX_UNIFORM_BUFFER_BINDINGS = 5;

+ 36 - 0
src/anki/gr/Shader.h

@@ -66,12 +66,48 @@ void writeShaderBlockMemory(ShaderVariableDataType type,
 	void* buffBegin,
 	const void* buffEnd);
 
+/// Specialization constant value.
+class ShaderSpecializationConstValue
+{
+public:
+	union
+	{
+		F32 m_float;
+		I32 m_int;
+	};
+
+	ShaderVariableDataType m_dataType;
+
+	ShaderSpecializationConstValue()
+		: m_int(0)
+		, m_dataType(ShaderVariableDataType::NONE)
+	{
+	}
+
+	explicit ShaderSpecializationConstValue(F32 f)
+		: m_float(f)
+		, m_dataType(ShaderVariableDataType::FLOAT)
+	{
+	}
+
+	explicit ShaderSpecializationConstValue(I32 i)
+		: m_int(i)
+		, m_dataType(ShaderVariableDataType::INT)
+	{
+	}
+
+	ShaderSpecializationConstValue(const ShaderSpecializationConstValue&) = default;
+
+	ShaderSpecializationConstValue& operator=(const ShaderSpecializationConstValue&) = default;
+};
+
 /// Shader init info.
 class ShaderInitInfo : public GrBaseInitInfo
 {
 public:
 	ShaderType m_shaderType = ShaderType::COUNT;
 	ConstWeakArray<U8> m_binary = {};
+	ConstWeakArray<ShaderSpecializationConstValue> m_constValues;
 
 	ShaderInitInfo()
 	{

+ 7 - 0
src/anki/gr/ShaderCompiler.cpp

@@ -40,6 +40,7 @@ static const char* SHADER_HEADER = R"(#version 450 core
 #	define ANKI_SS_BINDING(set_, binding_) binding = set_ * %u + binding_
 #	define ANKI_TEX_BINDING(set_, binding_) binding = set_ * %u + binding_
 #	define ANKI_IMAGE_BINDING(set_, binding_) binding = set_ * %u + binding_
+#	define ANKI_SPEC_CONST(binding_, type_, name_) const type_ name_ = _anki_spec_const_ ## binding_
 #else
 #	define gl_VertexID gl_VertexIndex
 #	define gl_InstanceID gl_InstanceIndex
@@ -47,6 +48,7 @@ static const char* SHADER_HEADER = R"(#version 450 core
 #	define ANKI_UBO_BINDING(set_, binding_) set = set_, binding = %u + binding_
 #	define ANKI_SS_BINDING(set_, binding_) set = set_, binding = %u + binding_
 #	define ANKI_IMAGE_BINDING(set_, binding_) set = set_, binding = %u + binding_
+#	define ANKI_SPEC_CONST(binding_, type_, name_) layout(constant_id = binding_) const type_ name_ = type_(0)
 #endif
 
 #if %u
@@ -306,6 +308,7 @@ Error ShaderCompiler::compile(CString source, const ShaderCompilerOptions& optio
 	// Compile
 	if(options.m_outLanguage == ShaderLanguage::GLSL)
 	{
+#if 0
 		std::vector<unsigned int> spv;
 		err = genSpirv(ctx, spv);
 		if(!err)
@@ -316,6 +319,10 @@ Error ShaderCompiler::compile(CString source, const ShaderCompilerOptions& optio
 			bin.resize(newSrc.length() + 1);
 			memcpy(&bin[0], &newSrc[0], bin.getSize());
 		}
+#else
+		bin.resize(fullSrc.getLength() + 1);
+		memcpy(&bin[0], &fullSrc[0], bin.getSize());
+#endif
 	}
 	else
 	{

+ 14 - 3
src/anki/gr/gl/Shader.cpp

@@ -18,19 +18,30 @@ Shader* Shader::newInstance(GrManager* manager, const ShaderInitInfo& init)
 	public:
 		ShaderPtr m_shader;
 		StringAuto m_source;
+		DynamicArrayAuto<ShaderSpecializationConstValue> m_constValues;
 
-		ShaderCreateCommand(Shader* shader, ConstWeakArray<U8> bin, const CommandBufferAllocator<U8>& alloc)
+		ShaderCreateCommand(Shader* shader,
+			ConstWeakArray<U8> bin,
+			ConstWeakArray<ShaderSpecializationConstValue> constValues,
+			const CommandBufferAllocator<U8>& alloc)
 			: m_shader(shader)
 			, m_source(alloc)
+			, m_constValues(alloc)
 		{
 			m_source.create(reinterpret_cast<const char*>(&bin[0]));
+
+			if(constValues.getSize())
+			{
+				m_constValues.create(constValues.getSize());
+				memcpy(&m_constValues[0], &constValues[0], m_constValues.getByteSize());
+			}
 		}
 
 		Error operator()(GlState&)
 		{
 			ShaderImpl& impl = static_cast<ShaderImpl&>(*m_shader);
 
-			Error err = impl.init(m_source.toCString());
+			Error err = impl.init(m_source.toCString(), m_constValues);
 
 			GlObject::State oldState =
 				impl.setStateAtomically((err) ? GlObject::State::ERROR : GlObject::State::CREATED);
@@ -54,7 +65,7 @@ Shader* Shader::newInstance(GrManager* manager, const ShaderInitInfo& init)
 	CommandBufferImpl& cmdbimpl = static_cast<CommandBufferImpl&>(*cmdb);
 	CommandBufferAllocator<U8> alloc = cmdbimpl.getInternalAllocator();
 
-	cmdbimpl.pushBackNewCommand<ShaderCreateCommand>(impl, init.m_binary, alloc);
+	cmdbimpl.pushBackNewCommand<ShaderCreateCommand>(impl, init.m_binary, init.m_constValues, alloc);
 	cmdbimpl.flush();
 
 	return impl;

+ 41 - 3
src/anki/gr/gl/ShaderImpl.cpp

@@ -33,7 +33,7 @@ ShaderImpl::~ShaderImpl()
 	destroyDeferred(getManager(), deleteShaders);
 }
 
-Error ShaderImpl::init(const CString& source)
+Error ShaderImpl::init(CString source, ConstWeakArray<ShaderSpecializationConstValue> constValues)
 {
 	ANKI_ASSERT(source);
 	ANKI_ASSERT(!isCreated());
@@ -47,8 +47,46 @@ Error ShaderImpl::init(const CString& source)
 
 	m_glType = gltype[U(m_shaderType)];
 
-	// 2) Gen name, create, compile and link
-	//
+	// Create a new shader with spec consts if needed
+	StringAuto newSrc(getAllocator());
+	if(constValues.getSize())
+	{
+		// Create const str
+		StringListAuto constStrLines(getAllocator());
+		U count = 0;
+		for(const ShaderSpecializationConstValue& constVal : constValues)
+		{
+			if(constVal.m_dataType == ShaderVariableDataType::INT)
+			{
+				constStrLines.pushBackSprintf("#define _anki_spec_const_%u %i", count, constVal.m_int);
+			}
+			else
+			{
+				ANKI_ASSERT(constVal.m_dataType == ShaderVariableDataType::FLOAT);
+				constStrLines.pushBackSprintf("#define _anki_spec_const_%u %f", count, constVal.m_float);
+			}
+
+			++count;
+		}
+		StringAuto constStr(getAllocator());
+		constStrLines.join("\n", constStr);
+
+		// Break the old source
+		StringListAuto lines(getAllocator());
+		lines.splitString(source, '\n');
+		ANKI_ASSERT(lines.getFront().find("#version") == 0);
+		lines.popFront();
+
+		// Append the const values
+		lines.pushFront(constStr.toCString());
+		lines.pushFront("#version 450 core");
+
+		// Create the new string
+		lines.join("\n", newSrc);
+		source = newSrc.toCString();
+	}
+
+	// Gen name, create and compile
 	const char* sourceStrs[1] = {nullptr};
 	sourceStrs[0] = &source[0];
 	m_glName = glCreateShader(m_glType);

+ 1 - 1
src/anki/gr/gl/ShaderImpl.h

@@ -35,7 +35,7 @@ public:
 		m_shaderType = init.m_shaderType;
 	}
 
-	ANKI_USE_RESULT Error init(const CString& source);
+	ANKI_USE_RESULT Error init(CString source, ConstWeakArray<ShaderSpecializationConstValue> constValues);
 };
 /// @}
 

+ 46 - 3
src/anki/gr/vulkan/ShaderImpl.cpp

@@ -6,7 +6,6 @@
 #include <anki/gr/vulkan/ShaderImpl.h>
 #include <anki/gr/vulkan/GrManagerImpl.h>
 #include <anki/gr/common/Misc.h>
-#include <anki/core/Trace.h>
 #include <SPIRV-Cross/spirv_cross.hpp>
 
 #define ANKI_DUMP_SHADERS ANKI_EXTRA_CHECKS
@@ -30,6 +29,18 @@ ShaderImpl::~ShaderImpl()
 	{
 		vkDestroyShaderModule(getDevice(), m_handle, nullptr);
 	}
+
+	if(m_specConstInfo.pMapEntries)
+	{
+		getAllocator().deleteArray(
+			const_cast<VkSpecializationMapEntry*>(m_specConstInfo.pMapEntries), m_specConstInfo.mapEntryCount);
+	}
+
+	if(m_specConstInfo.pData)
+	{
+		getAllocator().deleteArray(
+			static_cast<I32*>(const_cast<void*>(m_specConstInfo.pData)), m_specConstInfo.dataSize / sizeof(I32));
+	}
 }
 
 Error ShaderImpl::init(const ShaderInitInfo& inf)
@@ -67,12 +78,41 @@ Error ShaderImpl::init(const ShaderInitInfo& inf)
 	ANKI_VK_CHECK(vkCreateShaderModule(getDevice(), &ci, nullptr, &m_handle));
 
 	// Get reflection info
-	doReflection(inf.m_binary);
+	std::vector<spirv_cross::SpecializationConstant> specConstIds;
+	doReflection(inf.m_binary, specConstIds);
+
+	// Set spec info
+	if(specConstIds.size())
+	{
+		const U constCount = specConstIds.size();
+
+		m_specConstInfo.mapEntryCount = constCount;
+		m_specConstInfo.pMapEntries = getAllocator().newArray<VkSpecializationMapEntry>(constCount);
+		m_specConstInfo.dataSize = constCount * sizeof(I32);
+		m_specConstInfo.pData = getAllocator().newArray<I32>(constCount);
+
+		U count = 0;
+		for(const spirv_cross::SpecializationConstant& sconst : specConstIds)
+		{
+			// Set the entry
+			VkSpecializationMapEntry& entry = const_cast<VkSpecializationMapEntry&>(m_specConstInfo.pMapEntries[count]);
+			entry.constantID = sconst.constant_id;
+			entry.offset = count * sizeof(I32);
+			entry.size = sizeof(I32);
+
+			// Copy the data
+			U8* data = static_cast<U8*>(const_cast<void*>(m_specConstInfo.pData));
+			data += entry.offset;
+			*reinterpret_cast<I32*>(data) = inf.m_constValues[sconst.constant_id].m_int;
+
+			++count;
+		}
+	}
 
 	return Error::NONE;
 }
 
-void ShaderImpl::doReflection(ConstWeakArray<U8> spirv)
+void ShaderImpl::doReflection(ConstWeakArray<U8> spirv, std::vector<spirv_cross::SpecializationConstant>& specConstIds)
 {
 	spirv_cross::Compiler spvc(reinterpret_cast<const uint32_t*>(&spirv[0]), spirv.getSize() / sizeof(unsigned int));
 	spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
@@ -143,6 +183,9 @@ void ShaderImpl::doReflection(ConstWeakArray<U8> spirv)
 			m_attributeMask.set(location);
 		}
 	}
+
+	// Spec consts
+	specConstIds = spvc.get_specialization_constants();
 }
 
 } // end namespace anki

+ 14 - 4
src/anki/gr/vulkan/ShaderImpl.h

@@ -9,7 +9,13 @@
 #include <anki/gr/vulkan/VulkanObject.h>
 #include <anki/gr/vulkan/DescriptorSet.h>
 #include <anki/util/BitSet.h>
-#include <vector>
+#include <iosfwd>
+
+// Forward
+namespace spirv_cross
+{
+struct SpecializationConstant;
+} // end namespace spirv_cross
 
 namespace anki
 {
@@ -38,11 +44,15 @@ public:
 
 	ANKI_USE_RESULT Error init(const ShaderInitInfo& init);
 
+	const VkSpecializationInfo* getSpecConstInfo() const
+	{
+		return (m_specConstInfo.mapEntryCount) ? &m_specConstInfo : nullptr;
+	}
+
 private:
-	/// Generate SPIRV from GLSL.
-	ANKI_USE_RESULT Error genSpirv(const CString& source, std::vector<unsigned int>& spirv);
+	VkSpecializationInfo m_specConstInfo = {};
 
-	void doReflection(ConstWeakArray<U8> spirv);
+	void doReflection(ConstWeakArray<U8> spirv, std::vector<spirv_cross::SpecializationConstant>& specConstIds);
 };
 /// @}
 

+ 8 - 2
src/anki/gr/vulkan/ShaderProgramImpl.cpp

@@ -130,12 +130,15 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 				continue;
 			}
 
+			const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[stype]);
+
 			VkPipelineShaderStageCreateInfo& inf = m_shaderCreateInfos[m_shaderCreateInfoCount++];
 			inf = {};
 			inf.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
 			inf.stage = convertShaderTypeBit(static_cast<ShaderTypeBit>(1 << stype));
 			inf.pName = "main";
-			inf.module = scast<const ShaderImpl*>(m_shaders[stype].get())->m_handle;
+			inf.module = shaderImpl.m_handle;
+			inf.pSpecializationInfo = shaderImpl.getSpecConstInfo();
 		}
 	}
 
@@ -152,6 +155,8 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 	//
 	if(!graphicsProg)
 	{
+		const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[ShaderType::COMPUTE]);
+
 		VkComputePipelineCreateInfo ci = {};
 		ci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
 		ci.layout = m_pplineLayout.getHandle();
@@ -159,7 +164,8 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		ci.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
 		ci.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
 		ci.stage.pName = "main";
-		ci.stage.module = scast<const ShaderImpl*>(m_shaders[ShaderType::COMPUTE].get())->m_handle;
+		ci.stage.module = shaderImpl.m_handle;
+		ci.stage.pSpecializationInfo = shaderImpl.getSpecConstInfo();
 
 		ANKI_VK_CHECK(vkCreateComputePipelines(
 			getDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr, &m_computePpline));

+ 5 - 0
src/anki/util/BitSet.h

@@ -130,6 +130,11 @@ public:
 		return !getAny();
 	}
 
+	operator bool() const
+	{
+		return getAny();
+	}
+
 	/// Set or unset a bit at the given position.
 	template<typename TInt>
 	void set(TInt pos, Bool setBits = true)

+ 23 - 0
src/anki/util/StringList.h

@@ -74,6 +74,16 @@ public:
 		Base::getBack() = std::move(str);
 	}
 
+	/// Push front plain CString
+	void pushFront(Allocator alloc, CString cstr)
+	{
+		String str;
+		str.create(alloc, cstr);
+
+		Base::emplaceFront(alloc);
+		Base::getFront() = std::move(str);
+	}
+
 	/// Split a string using a separator (@a separator) and return these strings in a string list.
 	void splitString(Allocator alloc, const CString& s, const Char separator, Bool keepEmpty = false);
 };
@@ -130,6 +140,19 @@ public:
 		Base::pushBack(m_alloc, cstr);
 	}
 
+	/// Push front plain CString.
+	void pushFront(CString cstr)
+	{
+		Base::pushFront(m_alloc, cstr);
+	}
+
+	/// Pop front element.
+	void popFront()
+	{
+		getFront().destroy(m_alloc);
+		Base::popFront(m_alloc);
+	}
+
 	/// Split a string using a separator (@a separator) and return these strings in a string list.
 	void splitString(const CString& s, const Base::Char separator, Bool keepEmpty = false)
 	{

+ 112 - 2
tests/gr/Gr.cpp

@@ -352,7 +352,8 @@ static void* setStorage(PtrSize size, CommandBufferPtr& cmdb, U set, U binding)
 
 const PixelFormat DS_FORMAT = PixelFormat(ComponentFormat::D24S8, TransformFormat::UNORM);
 
-static ShaderPtr createShader(CString src, ShaderType type, GrManager& gr)
+static ShaderPtr createShader(
+	CString src, ShaderType type, GrManager& gr, ConstWeakArray<ShaderSpecializationConstValue> specVals = {})
 {
 	HeapAllocator<U8> alloc(allocAligned, nullptr);
 	ShaderCompiler comp(alloc);
@@ -364,7 +365,10 @@ static ShaderPtr createShader(CString src, ShaderType type, GrManager& gr)
 	DynamicArrayAuto<U8> bin(alloc);
 	ANKI_TEST_EXPECT_NO_ERR(comp.compile(src, options, bin));
 
-	return gr.newShader({type, WeakArray<U8>(&bin[0], bin.getSize())});
+	ShaderInitInfo initInf(type, WeakArray<U8>(&bin[0], bin.getSize()));
+	initInf.m_constValues = specVals;
+
+	return gr.newShader(initInf);
 }
 
 static ShaderProgramPtr createProgram(CString vertSrc, CString fragSrc, GrManager& gr)
@@ -1944,4 +1948,110 @@ void main()
 	COMMON_END()
 }
 
+ANKI_TEST(Gr, SpecConsts)
+{
+	COMMON_BEGIN()
+
+	static const char* VERT_SRC = R"(
+ANKI_SPEC_CONST(0, int, const0);
+ANKI_SPEC_CONST(2, float, const1);
+
+out gl_PerVertex
+{
+	vec4 gl_Position;
+};
+
+layout(location = 0) flat out int out_const0;
+layout(location = 1) flat out float out_const1;
+
+void main()
+{
+	vec2 uv = vec2(gl_VertexID & 1, gl_VertexID >> 1) * 2.0;
+	vec2 pos = uv * 2.0 - 1.0;
+
+	gl_Position = vec4(pos, 0.0, 1.0);
+
+	out_const0 = const0;
+	out_const1 = const1;
+}
+)";
+
+	static const char* FRAG_SRC = R"(
+ANKI_SPEC_CONST(0, int, const0);
+ANKI_SPEC_CONST(1, float, const1);
+
+layout(location = 0) flat in int in_const0;
+layout(location = 1) flat in float in_const1;
+
+layout(location = 0) out vec4 out_color;
+
+layout(ANKI_SS_BINDING(0, 0)) buffer s_
+{
+	uvec4 u_result;
+};
+
+void main()
+{
+	out_color = vec4(1.0);
+
+	if(gl_FragCoord.x == 0.5 && gl_FragCoord.y == 0.5)
+	{
+		if(in_const0 != 2147483647 || in_const1 != 1234.5678 || const0 != -2147483647 || const1 != -1.0)
+		{
+			u_result = uvec4(1u);
+		}
+		else
+		{
+			u_result = uvec4(2u);
+		}
+	}
+}
+)";
+
+	ShaderPtr vert = createShader(VERT_SRC,
+		ShaderType::VERTEX,
+		*gr,
+		Array<ShaderSpecializationConstValue, 3>{{ShaderSpecializationConstValue(2147483647),
+			ShaderSpecializationConstValue(-1.0f),
+			ShaderSpecializationConstValue(1234.5678f)}});
+	ShaderPtr frag = createShader(FRAG_SRC,
+		ShaderType::FRAGMENT,
+		*gr,
+		Array<ShaderSpecializationConstValue, 2>{
+			{ShaderSpecializationConstValue(-2147483647), ShaderSpecializationConstValue(-1.0f)}});
+	ShaderProgramPtr prog = gr->newShaderProgram(ShaderProgramInitInfo(vert, frag));
+
+	// Create the result buffer
+	BufferPtr resultBuff =
+		gr->newBuffer(BufferInitInfo(sizeof(UVec4), BufferUsageBit::STORAGE_COMPUTE_WRITE, BufferMapAccessBit::READ));
+
+	// Draw
+	gr->beginFrame();
+
+	CommandBufferInitInfo cinit;
+	cinit.m_flags = CommandBufferFlag::GRAPHICS_WORK;
+	CommandBufferPtr cmdb = gr->newCommandBuffer(cinit);
+
+	cmdb->setViewport(0, 0, WIDTH, HEIGHT);
+	cmdb->bindShaderProgram(prog);
+	cmdb->bindStorageBuffer(0, 0, resultBuff, 0, resultBuff->getSize());
+	cmdb->beginRenderPass(createDefaultFb(*gr), {}, {});
+	cmdb->drawArrays(PrimitiveTopology::TRIANGLES, 3);
+	cmdb->endRenderPass();
+	cmdb->flush();
+
+	gr->swapBuffers();
+	gr->finish();
+
+	// Get the result
+	UVec4* result = static_cast<UVec4*>(resultBuff->map(0, resultBuff->getSize(), BufferMapAccessBit::READ));
+	ANKI_TEST_EXPECT_EQ(result->x(), 2);
+	ANKI_TEST_EXPECT_EQ(result->y(), 2);
+	ANKI_TEST_EXPECT_EQ(result->z(), 2);
+	ANKI_TEST_EXPECT_EQ(result->w(), 2);
+	resultBuff->unmap();
+
+	COMMON_END()
+}
+
 } // end namespace anki