Przeglądaj źródła

Complete the RT work in the shader program resource

Panagiotis Christopoulos Charitos 5 lat temu
rodzic
commit
33bc6bea5b

+ 6 - 0
anki/resource/ResourceManager.h

@@ -210,6 +210,12 @@ public:
 	/// Get the total number of completed async tasks.
 	ANKI_INTERNAL U64 getAsyncTaskCompletedCount() const;
 
+	/// Return the container of program libraries.
+	const ShaderProgramResourceSystem& getShaderProgramResourceSystem() const
+	{
+		return *m_shaderProgramSystem;
+	}
+
 private:
 	GrManager* m_gr = nullptr;
 	PhysicsWorld* m_physics = nullptr;

+ 43 - 3
anki/resource/ShaderProgramResource.cpp

@@ -5,6 +5,7 @@
 
 #include <anki/resource/ShaderProgramResource.h>
 #include <anki/resource/ResourceManager.h>
+#include <anki/resource/ShaderProgramResourceSystem.h>
 #include <anki/gr/ShaderProgram.h>
 #include <anki/gr/GrManager.h>
 #include <anki/util/Filesystem.h>
@@ -269,16 +270,17 @@ void ShaderProgramResource::initVariant(const ShaderProgramResourceVariantInitIn
 
 	// Get the binary program variant
 	const ShaderProgramBinaryVariant* binaryVariant = nullptr;
+	U64 mutationHash = 0;
 	if(m_mutators.getSize())
 	{
 		// Create the mutation hash
-		const U64 hash = computeHash(info.m_mutation.getBegin(), m_mutators.getSize() * sizeof(info.m_mutation[0]));
+		mutationHash = computeHash(info.m_mutation.getBegin(), m_mutators.getSize() * sizeof(info.m_mutation[0]));
 
 		// Search for the mutation in the binary
 		// TODO optimize the search
 		for(const ShaderProgramBinaryMutation& mutation : binary.m_mutations)
 		{
-			if(mutation.m_hash == hash)
+			if(mutation.m_hash == mutationHash)
 			{
 				binaryVariant = &binary.m_variants[mutation.m_variantIndex];
 				break;
@@ -410,7 +412,45 @@ void ShaderProgramResource::initVariant(const ShaderProgramResourceVariantInitIn
 	else
 	{
 		ANKI_ASSERT(!!(m_shaderStages & ShaderTypeBit::ALL_RAY_TRACING));
-		ANKI_ASSERT(!"TODO");
+
+		// Find the library
+		CString libName = &binary.m_libraryName[0];
+		ANKI_ASSERT(libName.getLength() > 0);
+
+		const ShaderProgramResourceSystem& progSystem = getManager().getShaderProgramResourceSystem();
+		const ShaderProgramRaytracingLibrary* foundLib = nullptr;
+		for(const ShaderProgramRaytracingLibrary& lib : progSystem.getRayTracingLibraries())
+		{
+			if(lib.getLibraryName() == libName)
+			{
+				foundLib = &lib;
+				break;
+			}
+		}
+		ANKI_ASSERT(foundLib);
+
+		variant.m_prog = foundLib->getShaderProgram();
+
+		// Set the group handle
+		const U32 groupHandleSize = getManager().getGrManager().getDeviceCapabilities().m_shaderGroupHandleSize;
+		variant.m_hitShaderGroupHandleSize = U8(groupHandleSize);
+
+		ANKI_ASSERT(sizeof(variant.m_hitShaderGroupHandle) >= groupHandleSize);
+		WeakArray<U8> handle(&variant.m_hitShaderGroupHandle[0], groupHandleSize);
+		if(m_shaderStages == ShaderTypeBit::RAY_GEN)
+		{
+			foundLib->getRayGenShaderGroupHandle(handle);
+		}
+		else if(m_shaderStages == ShaderTypeBit::MISS)
+		{
+			const U32 rayType = binary.m_rayType;
+			foundLib->getMissShaderGroupHandle(rayType, handle);
+		}
+		else
+		{
+			ANKI_ASSERT(!!(m_shaderStages & (ShaderTypeBit::ANY_HIT | ShaderTypeBit::CLOSEST_HIT)));
+			foundLib->getHitShaderGroupHandle(getFilename(), mutationHash, handle);
+		}
 	}
 }
 

+ 4 - 3
anki/resource/ShaderProgramResource.h

@@ -93,8 +93,8 @@ public:
 	/// Only for hit ray tracing programs.
 	ConstWeakArray<U8> getHitShaderGroupHandle() const
 	{
-		ANKI_ASSERT(m_hitShaderGroupHandle.getSize() > 0);
-		return m_hitShaderGroupHandle;
+		ANKI_ASSERT(m_hitShaderGroupHandleSize > 0);
+		return ConstWeakArray<U8>(&m_hitShaderGroupHandle[0], m_hitShaderGroupHandleSize);
 	}
 
 private:
@@ -102,7 +102,8 @@ private:
 	const ShaderProgramBinaryVariant* m_binaryVariant = nullptr;
 	BitSet<128, U64> m_activeConsts = {false};
 	Array<U32, 3> m_workgroupSizes;
-	DynamicArray<U8> m_hitShaderGroupHandle; ///< Hit shaders group handle.
+	Array<U8, 32> m_hitShaderGroupHandle = {}; ///< Cache the handle here.
+	U8 m_hitShaderGroupHandleSize = 0;
 };
 
 /// The value of a constant.

+ 3 - 2
anki/resource/ShaderProgramResourceSystem.cpp

@@ -460,6 +460,7 @@ Error ShaderProgramResourceSystem::createRayTracingPrograms(CString cacheDir, Gr
 			}
 
 			// Iterate all mutations
+			// TODO What if there are no mutation?
 			for(U32 m = 0; m < binary.m_mutations.getSize(); ++m)
 			{
 				const ShaderProgramBinaryMutation& mutation = binary.m_mutations[m];
@@ -467,7 +468,7 @@ Error ShaderProgramResourceSystem::createRayTracingPrograms(CString cacheDir, Gr
 
 				// Generate the hash
 				const U64 hitGroupHash =
-					ShaderProgramRaytracingLibrary::generateHitGroupHash(filename, mutation.m_values);
+					ShaderProgramRaytracingLibrary::generateHitGroupHash(filename, mutation.m_hash);
 
 				HitGroup hitGroup;
 				hitGroup.m_hitGroupHash = hitGroupHash;
@@ -595,7 +596,7 @@ Error ShaderProgramResourceSystem::createRayTracingPrograms(CString cacheDir, Gr
 					const HitGroup& inHitGroup = inRayType.m_hitGroups[hitGroupIdx];
 
 					outLib.m_groupHashToGroupIndex.emplace(alloc, inHitGroup.m_hitGroupHash,
-														   initInfoHitGroups.getSize());
+														   initInfoHitGroups.getSize() + outLib.m_rayTypeCount + 1);
 
 					RayTracingHitGroup* infoHitGroup = initInfoHitGroups.emplaceBack();
 					if(inHitGroup.m_ahit != MAX_U32)

+ 9 - 6
anki/resource/ShaderProgramResourceSystem.h

@@ -39,10 +39,9 @@ public:
 	}
 
 	/// Given the filename of a program (that contains hit shaders) and a specific mutation get the group handle.
-	void getHitShaderGroupHandle(CString resourceFilename, ConstWeakArray<MutatorValue> mutation,
-								 WeakArray<U8>& handle) const
+	void getHitShaderGroupHandle(CString resourceFilename, U64 mutationHash, WeakArray<U8>& handle) const
 	{
-		const U32 hitGroupIndex = getHitGroupIndex(generateHitGroupHash(resourceFilename, mutation));
+		const U32 hitGroupIndex = getHitGroupIndex(generateHitGroupHash(resourceFilename, mutationHash));
 		getShaderGroupHandle(hitGroupIndex, handle);
 	}
 
@@ -52,6 +51,11 @@ public:
 		getShaderGroupHandle(rayType + 1, handle);
 	}
 
+	void getRayGenShaderGroupHandle(WeakArray<U8>& handle) const
+	{
+		getShaderGroupHandle(0, handle);
+	}
+
 private:
 	String m_libraryName;
 	U32 m_rayTypeCount = MAX_U32;
@@ -59,11 +63,10 @@ private:
 	HashMap<U64, U32> m_groupHashToGroupIndex;
 
 	/// Given the filename of a program (that contains hit shaders) and a specific mutation get a hash back.
-	static U64 generateHitGroupHash(CString resourceFilename, ConstWeakArray<MutatorValue> mutation)
+	static U64 generateHitGroupHash(CString resourceFilename, U64 mutationHash)
 	{
 		ANKI_ASSERT(resourceFilename.getLength() > 0);
-		U64 hash = computeHash(resourceFilename.cstr(), resourceFilename.getLength());
-		hash = appendHash(mutation.getBegin(), mutation.getSizeInBytes(), hash);
+		const U64 hash = appendHash(resourceFilename.cstr(), resourceFilename.getLength(), mutationHash);
 		return hash;
 	}