|
|
@@ -26,6 +26,7 @@ ShaderProgramImpl::~ShaderProgramImpl()
|
|
|
}
|
|
|
|
|
|
m_shaders.destroy(getAllocator());
|
|
|
+ m_rt.m_allHandles.destroy(getAllocator());
|
|
|
}
|
|
|
|
|
|
Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
@@ -34,6 +35,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
|
|
|
// Create the shader references
|
|
|
//
|
|
|
+ HashMapAuto<U64, U32> shaderUuidToMShadersIdx(getAllocator()); // Shader UUID to m_shaders idx
|
|
|
if(inf.m_computeShader)
|
|
|
{
|
|
|
m_shaders.emplaceBack(getAllocator(), inf.m_computeShader);
|
|
|
@@ -52,8 +54,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
{
|
|
|
// Ray tracing
|
|
|
|
|
|
- m_shaders.resizeStorage(getAllocator(), 1 + inf.m_rayTracingShaders.m_missShaders.getSize()
|
|
|
- + inf.m_rayTracingShaders.m_hitGroups.getSize());
|
|
|
+ m_shaders.resizeStorage(getAllocator(), 1 + inf.m_rayTracingShaders.m_missShaders.getSize());
|
|
|
|
|
|
m_shaders.emplaceBack(getAllocator(), inf.m_rayTracingShaders.m_rayGenShader);
|
|
|
|
|
|
@@ -66,12 +67,22 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
{
|
|
|
if(group.m_anyHitShader)
|
|
|
{
|
|
|
- m_shaders.emplaceBack(getAllocator(), group.m_anyHitShader);
|
|
|
+ auto it = shaderUuidToMShadersIdx.find(group.m_anyHitShader->getUuid());
|
|
|
+ if(it == shaderUuidToMShadersIdx.getEnd())
|
|
|
+ {
|
|
|
+ shaderUuidToMShadersIdx.emplace(group.m_anyHitShader->getUuid(), m_shaders.getSize());
|
|
|
+ m_shaders.emplaceBack(getAllocator(), group.m_anyHitShader);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
if(group.m_closestHitShader)
|
|
|
{
|
|
|
- m_shaders.emplaceBack(getAllocator(), group.m_closestHitShader);
|
|
|
+ auto it = shaderUuidToMShadersIdx.find(group.m_closestHitShader->getUuid());
|
|
|
+ if(it == shaderUuidToMShadersIdx.getEnd())
|
|
|
+ {
|
|
|
+ shaderUuidToMShadersIdx.emplace(group.m_closestHitShader->getUuid(), m_shaders.getSize());
|
|
|
+ m_shaders.emplaceBack(getAllocator(), group.m_closestHitShader);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -263,13 +274,11 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
|
|
|
// Miss
|
|
|
U32 groupCount = 1;
|
|
|
- U32 shaderCount = 1;
|
|
|
for(U32 i = 0; i < inf.m_rayTracingShaders.m_missShaders.getSize(); ++i)
|
|
|
{
|
|
|
groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
|
|
|
- groups[groupCount].generalShader = shaderCount;
|
|
|
+ groups[groupCount].generalShader = 0;
|
|
|
++groupCount;
|
|
|
- ++shaderCount;
|
|
|
}
|
|
|
|
|
|
// The rest of the groups are hit
|
|
|
@@ -278,20 +287,20 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
|
|
|
if(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader)
|
|
|
{
|
|
|
- groups[groupCount].anyHitShader = shaderCount;
|
|
|
- ++shaderCount;
|
|
|
+ groups[groupCount].anyHitShader =
|
|
|
+ *shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader->getUuid());
|
|
|
}
|
|
|
|
|
|
if(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader)
|
|
|
{
|
|
|
- ANKI_ASSERT(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader);
|
|
|
- groups[groupCount].closestHitShader = shaderCount;
|
|
|
- ++shaderCount;
|
|
|
+ groups[groupCount].closestHitShader =
|
|
|
+ *shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader->getUuid());
|
|
|
}
|
|
|
+
|
|
|
+ ++groupCount;
|
|
|
}
|
|
|
|
|
|
ANKI_ASSERT(groupCount == groups.getSize());
|
|
|
- ANKI_ASSERT(shaderCount == m_shaders.getSize());
|
|
|
|
|
|
VkRayTracingPipelineCreateInfoKHR ci = {};
|
|
|
ci.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
|
|
|
@@ -309,15 +318,15 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
|
|
|
nullptr, &m_rt.m_rtPpline));
|
|
|
}
|
|
|
|
|
|
- createStb();
|
|
|
+ // Get RT handles
|
|
|
+ const U32 handleArraySize =
|
|
|
+ getGrManagerImpl().getPhysicalDeviceRayTracingProperties().shaderGroupHandleSize * groupCount;
|
|
|
+ m_rt.m_allHandles.create(getAllocator(), handleArraySize, 0);
|
|
|
+ ANKI_VK_CHECK(vkGetRayTracingShaderGroupHandlesKHR(getDevice(), m_rt.m_rtPpline, 0, groupCount, handleArraySize,
|
|
|
+ &m_rt.m_allHandles[0]));
|
|
|
}
|
|
|
|
|
|
return Error::NONE;
|
|
|
}
|
|
|
|
|
|
-void ShaderProgramImpl::createStb()
|
|
|
-{
|
|
|
- // TODO
|
|
|
-}
|
|
|
-
|
|
|
} // end namespace anki
|