Browse Source

Write an RTP unit test

Panagiotis Christopoulos Charitos 4 months ago
parent
commit
8c1d5a904d
3 changed files with 317 additions and 4 deletions
  1. 3 3
      AnKi/Gr/D3D/D3DShaderProgram.cpp
  2. 1 1
      AnKi/Gr/D3D/D3DShaderProgram.h
  3. 313 0
      Tests/Gr/Gr.cpp

+ 3 - 3
AnKi/Gr/D3D/D3DShaderProgram.cpp

@@ -74,7 +74,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 
 	ANKI_ASSERT(m_shaders.getSize() > 0);
 
-	for(ShaderPtr& shader : m_shaders)
+	for(ShaderInternalPtr& shader : m_shaders)
 	{
 		m_shaderTypes |= ShaderTypeBit(1 << shader->getShaderType());
 	}
@@ -87,7 +87,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 	// Link reflection
 	ShaderReflection refl;
 	Bool firstLink = true;
-	for(ShaderPtr& shader : m_shaders)
+	for(ShaderInternalPtr& shader : m_shaders)
 	{
 		const ShaderImpl& simpl = static_cast<const ShaderImpl&>(*shader);
 		if(firstLink)
@@ -186,7 +186,7 @@ Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
 	}
 
 	// Get shader sizes and a few other things
-	for(const ShaderPtr& s : m_shaders)
+	for(const ShaderInternalPtr& s : m_shaders)
 	{
 		if(!s.isCreated())
 		{

+ 1 - 1
AnKi/Gr/D3D/D3DShaderProgram.h

@@ -52,7 +52,7 @@ public:
 	Error init(const ShaderProgramInitInfo& inf);
 
 private:
-	GrDynamicArray<ShaderPtr> m_shaders;
+	GrDynamicArray<ShaderInternalPtr> m_shaders;
 };
 /// @}
 

+ 313 - 0
Tests/Gr/Gr.cpp

@@ -2511,6 +2511,319 @@ float4 main(VertOut input) : SV_TARGET0
 	commonDestroy();
 }
 
+ANKI_TEST(Gr, RayTracingPipeline)
+{
+	g_rayTracingCVar = true;
+	commonInit();
+
+	{
+		if(!GrManager::getSingleton().getDeviceCapabilities().m_rayTracingEnabled)
+		{
+			ANKI_TEST_LOGF("Test can't run without ray tracing");
+		}
+
+		// Index buffer
+		BufferPtr idxBuffer;
+		{
+			Array<U16, 3> indices = {0, 1, 2};
+			BufferInitInfo init("IdxBuffer");
+			init.m_mapAccess = BufferMapAccessBit::kWrite;
+			init.m_usage = BufferUsageBit::kVertexOrIndex | BufferUsageBit::kAccelerationStructureBuild;
+			init.m_size = sizeof(indices);
+			idxBuffer = GrManager::getSingleton().newBuffer(init);
+
+			void* addr = idxBuffer->map(0, kMaxPtrSize, BufferMapAccessBit::kWrite);
+			memcpy(addr, &indices[0], sizeof(indices));
+			idxBuffer->unmap();
+		}
+
+		// Position buffer (add some padding to complicate things a bit)
+		BufferPtr vertBuffer;
+		{
+			Array<Vec4, 3> verts = {{{-1.0f, 0.0f, 0.0f, 100.0f}, {1.0f, 0.0f, 0.0f, 100.0f}, {0.0f, 2.0f, 0.0f, 100.0f}}};
+
+			BufferInitInfo init("VertBuffer");
+			init.m_mapAccess = BufferMapAccessBit::kWrite;
+			init.m_usage = BufferUsageBit::kVertexOrIndex | BufferUsageBit::kAccelerationStructureBuild;
+			init.m_size = sizeof(verts);
+			vertBuffer = GrManager::getSingleton().newBuffer(init);
+
+			void* addr = vertBuffer->map(0, kMaxPtrSize, BufferMapAccessBit::kWrite);
+			memcpy(addr, &verts[0], sizeof(verts));
+			vertBuffer->unmap();
+		}
+
+		// BLAS
+		AccelerationStructurePtr blas;
+		{
+			AccelerationStructureInitInfo init;
+			init.m_type = AccelerationStructureType::kBottomLevel;
+			init.m_bottomLevel.m_indexBuffer = BufferView(idxBuffer.get());
+			init.m_bottomLevel.m_indexCount = 3;
+			init.m_bottomLevel.m_indexType = IndexType::kU16;
+			init.m_bottomLevel.m_positionBuffer = BufferView(vertBuffer.get());
+			init.m_bottomLevel.m_positionCount = 3;
+			init.m_bottomLevel.m_positionsFormat = Format::kR32G32B32_Sfloat;
+			init.m_bottomLevel.m_positionStride = 4 * 4;
+
+			blas = GrManager::getSingleton().newAccelerationStructure(init);
+		}
+
+		// TLAS
+		AccelerationStructurePtr tlas;
+		{
+			Array<AccelerationStructureInstance, 2> inst = {};
+			inst[0].m_accelerationStructureAddress = blas->getGpuAddress();
+			inst[0].m_transform = Mat3x4::getIdentity();
+			inst[0].m_mask = 0xFF;
+			inst[0].m_flags = kAccellerationStructureFlagForceOpaque;
+			inst[1] = inst[0];
+			inst[1].m_transform = Mat3x4(Vec3(0.0f, -2.0f, 0.0f), Mat3::getIdentity(), Vec3(1.0f));
+			inst[1].m_instanceCustomIndex = 1;
+			inst[1].m_instanceShaderBindingTableRecordOffset = 1;
+			BufferPtr instBuff = createBuffer(BufferUsageBit::kAll, inst, 1);
+
+			AccelerationStructureInitInfo init;
+			init.m_type = AccelerationStructureType::kTopLevel;
+			init.m_topLevel.m_instancesBuffer = BufferView(instBuff.get());
+			init.m_topLevel.m_instanceCount = 2;
+
+			tlas = GrManager::getSingleton().newAccelerationStructure(init);
+		}
+
+		// Program
+		ShaderProgramPtr prog;
+		{
+			constexpr const Char* kCHit = R"(
+struct Barycentrics
+{
+	float2 m_value;
+};
+
+struct [raypayload] Payload
+{
+	float3 m_color : write(closesthit, miss, caller) : read(caller);
+};
+
+[shader("closesthit")] void main(inout Payload payload : SV_RayPayload, in Barycentrics barycentrics : SV_IntersectionAttributes)
+{
+	const float3 bary = float3(1.0 - barycentrics.m_value.x - barycentrics.m_value.y, barycentrics.m_value.x, barycentrics.m_value.y);
+	payload.m_color = bary * COLOR_SCALE;
+})";
+
+			constexpr const Char* kMiss = R"(
+struct [raypayload] Payload
+{
+	float3 m_color : write(closesthit, miss, caller) : read(caller);
+};
+
+[shader("miss")] void main(inout Payload payload : SV_RayPayload)
+{
+	payload.m_color = float3(0.2, 0.0, 0.2);
+})";
+
+			constexpr const Char* kRaygen = R"(
+struct Consts
+{
+	float4x4 m_invViewProj;
+	float3 m_cameraPos;
+	float m_padding0;
+	float2 m_viewport;
+	float2 m_padding1;
+};
+
+struct [raypayload] Payload
+{
+	float3 m_color : write(closesthit, miss, caller) : read(caller);
+};
+
+#if defined(__spirv__)
+[[vk::push_constant]] ConstantBuffer<Consts> g_consts;
+#else
+ConstantBuffer<Consts> g_consts : register(b0, space3000);
+#endif
+
+RaytracingAccelerationStructure g_tlas : register(t0);
+RWTexture2D<float4> g_uav : register(u0);
+
+[shader("raygeneration")] void main()
+{
+	// Unproject
+	const float2 uv = (DispatchRaysIndex().xy + 0.5) / g_consts.m_viewport;
+	float2 ndc = uv * 2.0 - 1.0;
+	ndc.y *= -1;
+	const float4 p4 = mul(g_consts.m_invViewProj, float4(ndc, 1.0, 1.0));
+	const float3 p3 = p4.xyz / p4.w;
+
+	const float3 rayDir = normalize(p3 - g_consts.m_cameraPos);
+	const float3 rayOrigin = g_consts.m_cameraPos;
+
+	Payload payload;
+	payload.m_color = 0.0;
+
+	const uint cullMask = 0xFFu;
+	const uint traceFlags = RAY_FLAG_FORCE_OPAQUE | RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH;
+	const uint missIndex = 0u;
+	const uint sbtRecordOffset = 0u;
+	const uint sbtRecordStride = 0u;
+	RayDesc ray;
+	ray.Origin = rayOrigin;
+	ray.TMin = 0.01;
+	ray.Direction = rayDir;
+	ray.TMax = 100.0;
+	TraceRay(g_tlas, traceFlags, cullMask, sbtRecordOffset, sbtRecordStride, missIndex, ray, payload);
+
+	g_uav[DispatchRaysIndex().xy] = float4(payload.m_color + COLOR_BIAS * ((DispatchRaysIndex().x / 32) & 1), 0.0);
+})";
+
+			ShaderPtr chit1Shader = createShader(kCHit, ShaderType::kClosestHit, Array<CString, 1>{"-DCOLOR_SCALE=float4(1, 0, 0, 0)"});
+			ShaderPtr chit2Shader = createShader(kCHit, ShaderType::kClosestHit, Array<CString, 1>{"-DCOLOR_SCALE=float4(0, 0, 1, 0)"});
+			ShaderPtr missShader = createShader(kMiss, ShaderType::kMiss);
+			ShaderPtr raygen1Shader = createShader(kRaygen, ShaderType::kRayGen, Array<CString, 1>{"-DCOLOR_BIAS=float4(0.5, 0, 0, 0)"});
+			ShaderPtr raygen2Shader = createShader(kRaygen, ShaderType::kRayGen, Array<CString, 1>{"-DCOLOR_BIAS=float4(0, 0.5, 0, 0)"});
+
+			ShaderProgramInitInfo inf;
+			Array<Shader*, 2> raygenArr = {raygen1Shader.get(), raygen2Shader.get()};
+			inf.m_rayTracingShaders.m_rayGenShaders = raygenArr;
+			Shader* pMiss = missShader.get();
+			inf.m_rayTracingShaders.m_missShaders = {&pMiss, 1};
+			Array<RayTracingHitGroup, 2> hits = {{{chit1Shader.get(), nullptr}, {chit2Shader.get(), nullptr}}};
+			inf.m_rayTracingShaders.m_hitGroups = hits;
+			prog = GrManager::getSingleton().newShaderProgram(inf);
+		}
+
+		// Build SBT
+		BufferPtr sbt;
+		U32 sbtRecordSize;
+		{
+			const U32 handleSize = GrManager::getSingleton().getDeviceCapabilities().m_shaderGroupHandleSize;
+			sbtRecordSize = getAlignedRoundUp(GrManager::getSingleton().getDeviceCapabilities().m_sbtRecordAlignment, handleSize);
+
+			ConstWeakArray<U8> handles = prog->getShaderGroupHandles();
+
+			const Array<U32, 4> copyHandles = {1, 2, 3, 4};
+
+			DynamicArray<U8> sbtData;
+			sbtData.resize(copyHandles.getSize() * sbtRecordSize, 0);
+
+			U32 count = 0;
+			for(U32 handleIdx : copyHandles)
+			{
+				memcpy(sbtData.getBegin() + sbtRecordSize * count, handles.getBegin() + handleSize * handleIdx, handleSize);
+				++count;
+			}
+
+			sbt = createBuffer(BufferUsageBit::kShaderBindingTable, ConstWeakArray(sbtData));
+		}
+
+		// Build AS
+		{
+			CommandBufferInitInfo cinit;
+			cinit.m_flags = CommandBufferFlag::kGeneralWork | CommandBufferFlag::kSmallBatch;
+			CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(cinit);
+
+			AccelerationStructureBarrierInfo barr;
+			barr.m_as = blas.get();
+			barr.m_previousUsage = AccelerationStructureUsageBit::kNone;
+			barr.m_nextUsage = AccelerationStructureUsageBit::kBuild;
+
+			cmdb->setPipelineBarrier({}, {}, {&barr, 1});
+			BufferInitInfo scratchInit;
+			scratchInit.m_size = blas->getBuildScratchBufferSize();
+			scratchInit.m_usage = BufferUsageBit::kAccelerationStructureBuildScratch;
+			BufferPtr scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
+			cmdb->buildAccelerationStructure(blas.get(), BufferView(scratchBuff.get()));
+
+			Array<AccelerationStructureBarrierInfo, 2> barr2;
+			barr2[0].m_as = blas.get();
+			barr2[0].m_previousUsage = AccelerationStructureUsageBit::kBuild;
+			barr2[0].m_nextUsage = AccelerationStructureUsageBit::kAttach;
+			barr2[1].m_as = tlas.get();
+			barr2[1].m_previousUsage = AccelerationStructureUsageBit::kNone;
+			barr2[1].m_nextUsage = AccelerationStructureUsageBit::kBuild;
+
+			cmdb->setPipelineBarrier({}, {}, barr2);
+
+			scratchInit.m_size = tlas->getBuildScratchBufferSize();
+			scratchBuff = GrManager::getSingleton().newBuffer(scratchInit);
+			cmdb->buildAccelerationStructure(tlas.get(), BufferView(scratchBuff.get()));
+
+			AccelerationStructureBarrierInfo barr3;
+			barr3.m_as = tlas.get();
+			barr3.m_previousUsage = AccelerationStructureUsageBit::kBuild;
+			barr3.m_nextUsage = AccelerationStructureUsageBit::kComputeSrv;
+
+			cmdb->setPipelineBarrier({}, {}, {&barr3, 1});
+
+			cmdb->endRecording();
+			GrManager::getSingleton().submit(cmdb.get());
+		}
+
+		// Draw
+		constexpr U32 kIterations = 200;
+		for(U i = 0; i < kIterations; ++i)
+		{
+			HighRezTimer timer;
+			timer.start();
+
+			GrManager::getSingleton().beginFrame();
+
+			const Vec4 cameraPos(0.0f, 0.0f, 3.0f, 0.0f);
+			const Mat4 viewMat = Mat4(cameraPos.xyz(), Mat3::getIdentity(), Vec3(1.0f)).invert();
+			const Mat4 projMat = Mat4::calculatePerspectiveProjectionMatrix(toRad(90.0f), toRad(90.0f), 0.01f, 1000.0f);
+
+			CommandBufferInitInfo cinit;
+			cinit.m_flags = CommandBufferFlag::kGeneralWork | CommandBufferFlag::kSmallBatch;
+			CommandBufferPtr cmdb = GrManager::getSingleton().newCommandBuffer(cinit);
+
+			cmdb->bindShaderProgram(prog.get());
+			struct Consts
+			{
+				Mat4 m_invViewProj;
+				Vec3 m_cameraPos;
+				F32 m_padding0;
+				Vec2 m_viewport;
+				Vec2 m_padding1;
+			} consts;
+			consts.m_invViewProj = (projMat * viewMat).invert().transpose();
+			consts.m_cameraPos = cameraPos.xyz();
+			consts.m_viewport = Vec2(kWidth, kHeight);
+			cmdb->setFastConstants(&consts, sizeof(consts));
+
+			cmdb->bindSrv(0, 0, tlas.get());
+
+			TexturePtr presentTex = GrManager::getSingleton().acquireNextPresentableTexture();
+			cmdb->bindUav(0, 0, TextureView(presentTex.get(), TextureSubresourceDesc::all()));
+
+			TextureBarrierInfo barr;
+			barr.m_textureView = TextureView(presentTex.get(), TextureSubresourceDesc::all());
+			barr.m_previousUsage = TextureUsageBit::kNone;
+			barr.m_nextUsage = TextureUsageBit::kUavTraceRays;
+			cmdb->setPipelineBarrier({&barr, 1}, {}, {});
+
+			cmdb->traceRays(BufferView(sbt.get()), sbtRecordSize, 2, 1, kWidth, kHeight, 1);
+
+			barr.m_previousUsage = TextureUsageBit::kUavTraceRays;
+			barr.m_nextUsage = TextureUsageBit::kPresent;
+			cmdb->setPipelineBarrier({&barr, 1}, {}, {});
+
+			cmdb->endRecording();
+			GrManager::getSingleton().submit(cmdb.get());
+
+			GrManager::getSingleton().endFrame();
+
+			timer.stop();
+			const F32 TICK = 1.0f / 30.0f;
+			if(timer.getElapsedTime() < TICK)
+			{
+				HighRezTimer::sleep(TICK - timer.getElapsedTime());
+			}
+		}
+	}
+
+	commonDestroy();
+}
+
 static void createCubeBuffers(GrManager& gr, Vec3 min, Vec3 max, BufferPtr& indexBuffer, BufferPtr& vertBuffer, Bool turnInsideOut = false)
 {
 	BufferInitInfo inf;