Browse Source

Add support for multiple miss shaders

Panagiotis Christopoulos Charitos 5 years ago
parent
commit
c6984197a8

+ 3 - 3
src/anki/gr/ShaderProgram.h

@@ -27,7 +27,7 @@ class RayTracingShaders
 {
 public:
 	ShaderPtr m_rayGenShader;
-	ShaderPtr m_missShader;
+	WeakArray<ShaderPtr> m_missShaders;
 	WeakArray<RayTracingHitGroup> m_hitGroups;
 };
 
@@ -96,9 +96,9 @@ public:
 			rtMask |= ShaderTypeBit::RAY_GEN;
 		}
 
-		if(m_rayTracingShaders.m_missShader)
+		for(const ShaderPtr& s : m_rayTracingShaders.m_missShaders)
 		{
-			if(m_rayTracingShaders.m_missShader->getShaderType() != ShaderType::MISS)
+			if(s->getShaderType() != ShaderType::MISS)
 			{
 				return false;
 			}

+ 7 - 1
src/anki/gr/vulkan/GrManagerImpl.cpp

@@ -380,7 +380,13 @@ Error GrManagerImpl::initInstance(const GrManagerInitInfo& init)
 	count = 1;
 	ANKI_VK_CHECK(vkEnumeratePhysicalDevices(m_instance, &count, &m_physicalDevice));
 
-	vkGetPhysicalDeviceProperties(m_physicalDevice, &m_devProps);
+	m_rtProps.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PROPERTIES_KHR;
+	VkPhysicalDeviceProperties2 props = {};
+	props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
+	props.pNext = &m_rtProps;
+
+	vkGetPhysicalDeviceProperties2(m_physicalDevice, &props);
+	m_devProps = props.properties;
 
 	// Find vendor
 	switch(m_devProps.vendorID)

+ 11 - 5
src/anki/gr/vulkan/GrManagerImpl.h

@@ -57,6 +57,11 @@ public:
 		return m_devProps;
 	}
 
+	const VkPhysicalDeviceRayTracingPropertiesKHR& getPhysicalDeviceRayTracingProperties() const
+	{
+		return m_rtProps;
+	}
+
 	TexturePtr acquireNextPresentableTexture();
 
 	void endFrame();
@@ -246,11 +251,12 @@ private:
 	VkQueue m_queue = VK_NULL_HANDLE;
 	Mutex m_globalMtx;
 
-	VkPhysicalDeviceProperties m_devProps{};
-	VkPhysicalDeviceFeatures m_devFeatures{};
-	VkPhysicalDeviceDescriptorIndexingFeatures m_descriptorIndexingFeatures{};
-	VkPhysicalDeviceBufferDeviceAddressFeatures m_bufferDeviceAddressFeatures{};
-	VkPhysicalDeviceRayTracingFeaturesKHR m_rtFeatures{};
+	VkPhysicalDeviceProperties m_devProps = {};
+	VkPhysicalDeviceRayTracingPropertiesKHR m_rtProps = {};
+	VkPhysicalDeviceFeatures m_devFeatures = {};
+	VkPhysicalDeviceDescriptorIndexingFeatures m_descriptorIndexingFeatures = {};
+	VkPhysicalDeviceBufferDeviceAddressFeatures m_bufferDeviceAddressFeatures = {};
+	VkPhysicalDeviceRayTracingFeaturesKHR m_rtFeatures = {};
 
 	PFN_vkDebugMarkerSetObjectNameEXT m_pfnDebugMarkerSetObjectNameEXT = nullptr;
 	PFN_vkCmdDebugMarkerBeginEXT m_pfnCmdDebugMarkerBeginEXT = nullptr;

+ 43 - 12
src/anki/gr/vulkan/ShaderProgramImpl.cpp

@@ -52,10 +52,15 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 	{
 		// Ray tracing
 
-		m_shaders.resizeStorage(getAllocator(), 2 + inf.m_rayTracingShaders.m_hitGroups.getSize());
+		m_shaders.resizeStorage(getAllocator(), 1 + inf.m_rayTracingShaders.m_missShaders.getSize()
+													+ inf.m_rayTracingShaders.m_hitGroups.getSize());
 
 		m_shaders.emplaceBack(getAllocator(), inf.m_rayTracingShaders.m_rayGenShader);
-		m_shaders.emplaceBack(getAllocator(), inf.m_rayTracingShaders.m_missShader);
+
+		for(const ShaderPtr& s : inf.m_rayTracingShaders.m_missShaders)
+		{
+			m_shaders.emplaceBack(getAllocator(), s);
+		}
 
 		for(const RayTracingHitGroup& group : inf.m_rayTracingShaders.m_hitGroups)
 		{
@@ -214,6 +219,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		ci.stage.module = shaderImpl.m_handle;
 		ci.stage.pSpecializationInfo = shaderImpl.getSpecConstInfo();
 
+		ANKI_TRACE_SCOPED_EVENT(VK_PIPELINE_CREATE);
 		ANKI_VK_CHECK(vkCreateComputePipelines(getDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr,
 											   &m_compute.m_ppline));
 		getGrManagerImpl().printPipelineShaderInfo(m_compute.m_ppline, getName(), ShaderTypeBit::COMPUTE);
@@ -247,31 +253,46 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		defaultGroup.intersectionShader = VK_SHADER_UNUSED_KHR;
 
 		DynamicArrayAuto<VkRayTracingShaderGroupCreateInfoKHR> groups(
-			getAllocator(), 2 + inf.m_rayTracingShaders.m_hitGroups.getSize(), defaultGroup);
+			getAllocator(),
+			1 + inf.m_rayTracingShaders.m_missShaders.getSize() + inf.m_rayTracingShaders.m_hitGroups.getSize(),
+			defaultGroup);
 
 		// 1st group is the ray gen
 		groups[0].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
 		groups[0].generalShader = 0;
 
-		// 2nd group is the miss
-		groups[1].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
-		groups[1].generalShader = 1;
+		// Miss
+		U32 groupCount = 1;
+		U32 shaderCount = 1;
+		for(U32 i = 0; i < inf.m_rayTracingShaders.m_missShaders.getSize(); ++i)
+		{
+			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
+			groups[groupCount].generalShader = shaderCount;
+			++groupCount;
+			++shaderCount;
+		}
 
 		// The rest of the groups are hit
 		for(U32 i = 0; i < inf.m_rayTracingShaders.m_hitGroups.getSize(); ++i)
 		{
-			groups[i + 2].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
+			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
 			if(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader)
 			{
-				groups[i + 2].anyHitShader = i + 2;
+				groups[groupCount].anyHitShader = shaderCount;
+				++shaderCount;
 			}
-			else
+
+			if(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader)
 			{
 				ANKI_ASSERT(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader);
-				groups[i + 2].closestHitShader = i + 2;
+				groups[groupCount].closestHitShader = shaderCount;
+				++shaderCount;
 			}
 		}
 
+		ANKI_ASSERT(groupCount == groups.getSize());
+		ANKI_ASSERT(shaderCount == m_shaders.getSize());
+
 		VkRayTracingPipelineCreateInfoKHR ci = {};
 		ci.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
 		ci.stageCount = stages.getSize();
@@ -282,11 +303,21 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 		ci.libraries.sType = VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR;
 		ci.layout = m_pplineLayout.getHandle();
 
-		ANKI_VK_CHECK(vkCreateRayTracingPipelinesKHR(getDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci,
-													 nullptr, &m_rt.m_rtPpline));
+		{
+			ANKI_TRACE_SCOPED_EVENT(VK_PIPELINE_CREATE);
+			ANKI_VK_CHECK(vkCreateRayTracingPipelinesKHR(getDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci,
+														 nullptr, &m_rt.m_rtPpline));
+		}
+
+		createStb();
 	}
 
 	return Error::NONE;
 }
 
+void ShaderProgramImpl::createStb()
+{
+	// TODO
+}
+
 } // end namespace anki

+ 3 - 0
src/anki/gr/vulkan/ShaderProgramImpl.h

@@ -115,7 +115,10 @@ private:
 	{
 	public:
 		VkPipeline m_rtPpline = VK_NULL_HANDLE;
+		BufferPtr m_stb;
 	} m_rt;
+
+	void createStb();
 };
 /// @}