Browse Source

Minor RT ppline fixes

Panagiotis Christopoulos Charitos 5 years ago
parent
commit
f985b4d185
2 changed files with 29 additions and 21 deletions
  1. 28 19
      src/anki/gr/vulkan/ShaderProgramImpl.cpp
  2. 1 2
      src/anki/gr/vulkan/ShaderProgramImpl.h

+ 28 - 19
src/anki/gr/vulkan/ShaderProgramImpl.cpp

@@ -26,6 +26,7 @@ ShaderProgramImpl::~ShaderProgramImpl()
 	}
 
 	m_shaders.destroy(getAllocator());
+	m_rt.m_allHandles.destroy(getAllocator());
 }
 
 Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
@@ -34,6 +35,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 
 	// Create the shader references
 	//
+	HashMapAuto<U64, U32> shaderUuidToMShadersIdx(getAllocator()); // Shader UUID to m_shaders idx
 	if(inf.m_computeShader)
 	{
 		m_shaders.emplaceBack(getAllocator(), inf.m_computeShader);
@@ -52,8 +54,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 	{
 		// Ray tracing
 
-		m_shaders.resizeStorage(getAllocator(), 1 + inf.m_rayTracingShaders.m_missShaders.getSize()
-													+ inf.m_rayTracingShaders.m_hitGroups.getSize());
+		m_shaders.resizeStorage(getAllocator(), 1 + inf.m_rayTracingShaders.m_missShaders.getSize());
 
 		m_shaders.emplaceBack(getAllocator(), inf.m_rayTracingShaders.m_rayGenShader);
 
@@ -66,12 +67,22 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		{
 			if(group.m_anyHitShader)
 			{
-				m_shaders.emplaceBack(getAllocator(), group.m_anyHitShader);
+				auto it = shaderUuidToMShadersIdx.find(group.m_anyHitShader->getUuid());
+				if(it == shaderUuidToMShadersIdx.getEnd())
+				{
+					shaderUuidToMShadersIdx.emplace(group.m_anyHitShader->getUuid(), m_shaders.getSize());
+					m_shaders.emplaceBack(getAllocator(), group.m_anyHitShader);
+				}
 			}
 
 			if(group.m_closestHitShader)
 			{
-				m_shaders.emplaceBack(getAllocator(), group.m_closestHitShader);
+				auto it = shaderUuidToMShadersIdx.find(group.m_closestHitShader->getUuid());
+				if(it == shaderUuidToMShadersIdx.getEnd())
+				{
+					shaderUuidToMShadersIdx.emplace(group.m_closestHitShader->getUuid(), m_shaders.getSize());
+					m_shaders.emplaceBack(getAllocator(), group.m_closestHitShader);
+				}
 			}
 		}
 	}
@@ -263,13 +274,11 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 
 		// Miss
 		U32 groupCount = 1;
-		U32 shaderCount = 1;
 		for(U32 i = 0; i < inf.m_rayTracingShaders.m_missShaders.getSize(); ++i)
 		{
 			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
-			groups[groupCount].generalShader = shaderCount;
+			groups[groupCount].generalShader = 0;
 			++groupCount;
-			++shaderCount;
 		}
 
 		// The rest of the groups are hit
@@ -278,20 +287,20 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
 			if(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader)
 			{
-				groups[groupCount].anyHitShader = shaderCount;
-				++shaderCount;
+				groups[groupCount].anyHitShader =
+					*shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader->getUuid());
 			}
 
 			if(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader)
 			{
-				ANKI_ASSERT(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader);
-				groups[groupCount].closestHitShader = shaderCount;
-				++shaderCount;
+				groups[groupCount].closestHitShader =
+					*shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader->getUuid());
 			}
+
+			++groupCount;
 		}
 
 		ANKI_ASSERT(groupCount == groups.getSize());
-		ANKI_ASSERT(shaderCount == m_shaders.getSize());
 
 		VkRayTracingPipelineCreateInfoKHR ci = {};
 		ci.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
@@ -309,15 +318,15 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 														 nullptr, &m_rt.m_rtPpline));
 		}
 
-		createStb();
+		// Get RT handles
+		const U32 handleArraySize =
+			getGrManagerImpl().getPhysicalDeviceRayTracingProperties().shaderGroupHandleSize * groupCount;
+		m_rt.m_allHandles.create(getAllocator(), handleArraySize, 0);
+		ANKI_VK_CHECK(vkGetRayTracingShaderGroupHandlesKHR(getDevice(), m_rt.m_rtPpline, 0, groupCount, handleArraySize,
+														   &m_rt.m_allHandles[0]));
 	}
 
 	return Error::NONE;
 }
 
-void ShaderProgramImpl::createStb()
-{
-	// TODO
-}
-
 } // end namespace anki

+ 1 - 2
src/anki/gr/vulkan/ShaderProgramImpl.h

@@ -116,9 +116,8 @@ private:
 	public:
 		VkPipeline m_rtPpline = VK_NULL_HANDLE;
 		BufferPtr m_stb;
+		DynamicArray<U8> m_allHandles;
 	} m_rt;
-
-	void createStb();
 };
 /// @}