123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695 |
- /*
- * Copyright (c) Contributors to the Open 3D Engine Project.
- * For complete copyright and license terms please see the LICENSE at the root of this distribution.
- *
- * SPDX-License-Identifier: Apache-2.0 OR MIT
- *
- */
- #include <RHI/RayTracingExampleComponent.h>
- #include <Utils/Utils.h>
- #include <SampleComponentManager.h>
- #include <Atom/RHI/CommandList.h>
- #include <Atom/RHI/FrameGraphInterface.h>
- #include <Atom/RHI/RayTracingPipelineState.h>
- #include <Atom/RHI/RayTracingShaderTable.h>
- #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
- #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
- #include <Atom/RPI.Public/Shader/Shader.h>
- #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
- #include <AzCore/Serialization/SerializeContext.h>
// Sample name handed to the shader/SRG helper functions and used as the AZ_Error window throughout this file.
static constexpr const char* RayTracingExampleName = "RayTracingExample";
- namespace AtomSampleViewer
- {
- void RayTracingExampleComponent::Reflect(AZ::ReflectContext* context)
- {
- if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
- {
- serializeContext->Class<RayTracingExampleComponent, AZ::Component>()
- ->Version(0)
- ;
- }
- }
- RayTracingExampleComponent::RayTracingExampleComponent()
- {
- m_supportRHISamplePipeline = true;
- }
- void RayTracingExampleComponent::Activate()
- {
- CreateResourcePools();
- CreateGeometry();
- CreateFullScreenBuffer();
- CreateOutputTexture();
- CreateRasterShader();
- CreateRayTracingAccelerationStructureObjects();
- CreateRayTracingPipelineState();
- CreateRayTracingShaderTable();
- CreateRayTracingAccelerationTableScope();
- CreateRayTracingDispatchScope();
- CreateRasterScope();
- RHI::RHISystemNotificationBus::Handler::BusConnect();
- }
- void RayTracingExampleComponent::Deactivate()
- {
- RHI::RHISystemNotificationBus::Handler::BusDisconnect();
- m_windowContext = nullptr;
- m_scopeProducers.clear();
- }
- void RayTracingExampleComponent::CreateResourcePools()
- {
- // create input assembly buffer pool
- {
- m_inputAssemblyBufferPool = aznew RHI::BufferPool();
- RHI::BufferPoolDescriptor bufferPoolDesc;
- bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
- bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Host;
- [[maybe_unused]] RHI::ResultCode resultCode = m_inputAssemblyBufferPool->Init(bufferPoolDesc);
- AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize input assembly buffer pool");
- }
- // create output image pool
- {
- RHI::ImagePoolDescriptor imagePoolDesc;
- imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
- m_imagePool = aznew RHI::ImagePool();
- [[maybe_unused]] RHI::ResultCode result = m_imagePool->Init(imagePoolDesc);
- AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image pool");
- }
- // initialize ray tracing buffer pools
- m_rayTracingBufferPools = aznew RHI::RayTracingBufferPools;
- m_rayTracingBufferPools->Init(RHI::MultiDevice::AllDevices);
- }
- void RayTracingExampleComponent::CreateGeometry()
- {
- // triangle
- {
- // vertex buffer
- SetVertexPosition(m_triangleVertices.data(), 0, 0.0f, 0.5f, 1.0);
- SetVertexPosition(m_triangleVertices.data(), 1, 0.5f, -0.5f, 1.0);
- SetVertexPosition(m_triangleVertices.data(), 2, -0.5f, -0.5f, 1.0);
- m_triangleVB = aznew RHI::Buffer();
- m_triangleVB->SetName(AZ::Name("Triangle VB"));
- RHI::BufferInitRequest request;
- request.m_buffer = m_triangleVB.get();
- request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleVertices) };
- request.m_initialData = m_triangleVertices.data();
- m_inputAssemblyBufferPool->InitBuffer(request);
- // index buffer
- SetVertexIndexIncreasing(m_triangleIndices.data(), m_triangleIndices.size());
- m_triangleIB = aznew RHI::Buffer();
- m_triangleIB->SetName(AZ::Name("Triangle IB"));
- request.m_buffer = m_triangleIB.get();
- request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleIndices) };
- request.m_initialData = m_triangleIndices.data();
- m_inputAssemblyBufferPool->InitBuffer(request);
- }
- // rectangle
- {
- // vertex buffer
- SetVertexPosition(m_rectangleVertices.data(), 0, -0.5f, 0.5f, 1.0);
- SetVertexPosition(m_rectangleVertices.data(), 1, 0.5f, 0.5f, 1.0);
- SetVertexPosition(m_rectangleVertices.data(), 2, 0.5f, -0.5f, 1.0);
- SetVertexPosition(m_rectangleVertices.data(), 3, -0.5f, -0.5f, 1.0);
- m_rectangleVB = aznew RHI::Buffer();
- m_rectangleVB->SetName(AZ::Name("Rectangle VB"));
- RHI::BufferInitRequest request;
- request.m_buffer = m_rectangleVB.get();
- request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleVertices) };
- request.m_initialData = m_rectangleVertices.data();
- m_inputAssemblyBufferPool->InitBuffer(request);
- // index buffer
- m_rectangleIndices[0] = 0;
- m_rectangleIndices[1] = 1;
- m_rectangleIndices[2] = 2;
- m_rectangleIndices[3] = 0;
- m_rectangleIndices[4] = 2;
- m_rectangleIndices[5] = 3;
- m_rectangleIB = aznew RHI::Buffer();
- m_rectangleIB->SetName(AZ::Name("Rectangle IB"));
- request.m_buffer = m_rectangleIB.get();
- request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleIndices) };
- request.m_initialData = m_rectangleIndices.data();
- m_inputAssemblyBufferPool->InitBuffer(request);
- }
- }
- void RayTracingExampleComponent::CreateFullScreenBuffer()
- {
- FullScreenBufferData bufferData;
- SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
- m_fullScreenInputAssemblyBuffer = aznew RHI::Buffer();
- RHI::BufferInitRequest request;
- request.m_buffer = m_fullScreenInputAssemblyBuffer.get();
- request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
- request.m_initialData = &bufferData;
- m_inputAssemblyBufferPool->InitBuffer(request);
- m_geometryView.SetDrawArguments(RHI::DrawIndexed(0, 6, 0));
- m_geometryView.AddStreamBufferView({
- *m_fullScreenInputAssemblyBuffer,
- offsetof(FullScreenBufferData, m_positions),
- sizeof(FullScreenBufferData::m_positions),
- sizeof(VertexPosition)
- });
- m_geometryView.AddStreamBufferView({
- *m_fullScreenInputAssemblyBuffer,
- offsetof(FullScreenBufferData, m_uvs),
- sizeof(FullScreenBufferData::m_uvs),
- sizeof(VertexUV)
- });
- m_geometryView.SetIndexBufferView({
- *m_fullScreenInputAssemblyBuffer,
- offsetof(FullScreenBufferData, m_indices),
- sizeof(FullScreenBufferData::m_indices),
- RHI::IndexFormat::Uint16
- });
- RHI::InputStreamLayoutBuilder layoutBuilder;
- layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
- layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
- m_fullScreenInputStreamLayout = layoutBuilder.End();
- }
- void RayTracingExampleComponent::CreateOutputTexture()
- {
- // create output image
- m_outputImage = aznew RHI::Image();
- RHI::ImageInitRequest request;
- request.m_image = m_outputImage.get();
- request.m_descriptor = RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, m_imageWidth, m_imageHeight, RHI::Format::R8G8B8A8_UNORM);
- [[maybe_unused]] RHI::ResultCode result = m_imagePool->InitImage(request);
- AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image");
- m_outputImageViewDescriptor = RHI::ImageViewDescriptor::Create(RHI::Format::R8G8B8A8_UNORM, 0, 0);
- m_outputImageView = m_outputImage->GetImageView(m_outputImageViewDescriptor);
- AZ_Assert(m_outputImageView.get(), "Failed to create output image view");
- AZ_Assert(m_outputImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex)->IsFullView(), "Image View initialization IsFullView() failed");
- }
- void RayTracingExampleComponent::CreateRayTracingAccelerationStructureObjects()
- {
- m_triangleRayTracingBlas = aznew AZ::RHI::RayTracingBlas;
- m_rectangleRayTracingBlas = aznew AZ::RHI::RayTracingBlas;
- m_rayTracingTlas = aznew AZ::RHI::RayTracingTlas;
- }
- void RayTracingExampleComponent::CreateRasterShader()
- {
- const char* shaderFilePath = "Shaders/RHI/RayTracingDraw.azshader";
- auto drawShader = LoadShader(shaderFilePath, RayTracingExampleName);
- AZ_Assert(drawShader, "Failed to load Draw shader");
- RHI::PipelineStateDescriptorForDraw pipelineDesc;
- drawShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
- pipelineDesc.m_inputStreamLayout = m_fullScreenInputStreamLayout;
- RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
- attachmentsBuilder.AddSubpass()->RenderTargetAttachment(m_outputFormat);
- [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
- AZ_Assert(result == RHI::ResultCode::Success, "Failed to create draw render attachment layout");
- m_drawPipelineState = drawShader->AcquirePipelineState(pipelineDesc);
- AZ_Assert(m_drawPipelineState, "Failed to acquire draw pipeline state");
- m_drawSRG = CreateShaderResourceGroup(drawShader, "BufferSrg", RayTracingExampleName);
- }
- void RayTracingExampleComponent::CreateRayTracingPipelineState()
- {
- // load ray generation shader
- const char* rayGenerationShaderFilePath = "Shaders/RHI/RayTracingDispatch.azshader";
- m_rayGenerationShader = LoadShader(rayGenerationShaderFilePath, RayTracingExampleName);
- AZ_Assert(m_rayGenerationShader, "Failed to load ray generation shader");
- auto rayGenerationShaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
- RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
- rayGenerationShaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
- // load miss shader
- const char* missShaderFilePath = "Shaders/RHI/RayTracingMiss.azshader";
- m_missShader = LoadShader(missShaderFilePath, RayTracingExampleName);
- AZ_Assert(m_missShader, "Failed to load miss shader");
- auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
- RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
- missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
- // load closest hit gradient shader
- const char* closestHitGradientShaderFilePath = "Shaders/RHI/RayTracingClosestHitGradient.azshader";
- m_closestHitGradientShader = LoadShader(closestHitGradientShaderFilePath, RayTracingExampleName);
- AZ_Assert(m_closestHitGradientShader, "Failed to load closest hit gradient shader");
- auto closestHitGradientShaderVariant = m_closestHitGradientShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
- RHI::PipelineStateDescriptorForRayTracing closestHitGradientShaderDescriptor;
- closestHitGradientShaderVariant.ConfigurePipelineState(closestHitGradientShaderDescriptor);
- // load closest hit solid shader
- const char* closestHitSolidShaderFilePath = "Shaders/RHI/RayTracingClosestHitSolid.azshader";
- m_closestHitSolidShader = LoadShader(closestHitSolidShaderFilePath, RayTracingExampleName);
- AZ_Assert(m_closestHitSolidShader, "Failed to load closest hit solid shader");
- auto closestHitSolidShaderVariant = m_closestHitSolidShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
- RHI::PipelineStateDescriptorForRayTracing closestHitSolidShaderDescriptor;
- closestHitSolidShaderVariant.ConfigurePipelineState(closestHitSolidShaderDescriptor);
- // global pipeline state and srg
- m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
- AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
- m_globalSrg = CreateShaderResourceGroup(m_rayGenerationShader, "RayTracingGlobalSrg", RayTracingExampleName);
- // build the ray tracing pipeline state descriptor
- RHI::RayTracingPipelineStateDescriptor descriptor;
- descriptor.m_pipelineState = m_globalPipelineState.get();
- descriptor.AddRayGenerationShaderLibrary(rayGenerationShaderDescriptor, Name("RayGenerationShader"));
- descriptor.AddMissShaderLibrary(missShaderDescriptor, Name("MissShader"));
- descriptor.AddClosestHitShaderLibrary(closestHitGradientShaderDescriptor, Name("ClosestHitGradientShader"));
- descriptor.AddClosestHitShaderLibrary(closestHitSolidShaderDescriptor, Name("ClosestHitSolidShader"));
- descriptor.AddHitGroup(Name("HitGroupGradient"), Name("ClosestHitGradientShader"));
- descriptor.AddHitGroup(Name("HitGroupSolid"), Name("ClosestHitSolidShader"));
- // create the ray tracing pipeline state object
- m_rayTracingPipelineState = aznew RHI::RayTracingPipelineState;
- m_rayTracingPipelineState->Init(RHI::MultiDevice::AllDevices, descriptor);
- }
- void RayTracingExampleComponent::CreateRayTracingShaderTable()
- {
- m_rayTracingShaderTable = aznew RHI::RayTracingShaderTable;
- m_rayTracingShaderTable->Init(RHI::MultiDevice::AllDevices, *m_rayTracingBufferPools);
- }
- void RayTracingExampleComponent::CreateRayTracingAccelerationTableScope()
- {
- struct ScopeData
- {
- };
- const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
- {
- // create triangle BLAS buffer if necessary
- if (!m_triangleRayTracingBlas->IsValid())
- {
- RHI::StreamBufferView triangleVertexBufferView =
- {
- *m_triangleVB,
- 0,
- sizeof(m_triangleVertices),
- sizeof(VertexPosition)
- };
- RHI::IndexBufferView triangleIndexBufferView =
- {
- *m_triangleIB,
- 0,
- sizeof(m_triangleIndices),
- RHI::IndexFormat::Uint16
- };
- RHI::RayTracingBlasDescriptor triangleBlasDescriptor;
- RHI::RayTracingGeometry& triangleBlasGeometry = triangleBlasDescriptor.m_geometries.emplace_back();
- triangleBlasGeometry.m_vertexFormat = RHI::VertexFormat::R32G32B32_FLOAT;
- triangleBlasGeometry.m_vertexBuffer = triangleVertexBufferView;
- triangleBlasGeometry.m_indexBuffer = triangleIndexBufferView;
- m_triangleRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &triangleBlasDescriptor, *m_rayTracingBufferPools);
- }
- // create rectangle BLAS if necessary
- if (!m_rectangleRayTracingBlas->IsValid())
- {
- RHI::StreamBufferView rectangleVertexBufferView =
- {
- *m_rectangleVB,
- 0,
- sizeof(m_rectangleVertices),
- sizeof(VertexPosition)
- };
- RHI::IndexBufferView rectangleIndexBufferView =
- {
- *m_rectangleIB,
- 0,
- sizeof(m_rectangleIndices),
- RHI::IndexFormat::Uint16
- };
- RHI::RayTracingBlasDescriptor rectangleBlasDescriptor;
- RHI::RayTracingGeometry& rectangleBlasGeometry = rectangleBlasDescriptor.m_geometries.emplace_back();
- rectangleBlasGeometry.m_vertexFormat = RHI::VertexFormat::R32G32B32_FLOAT;
- rectangleBlasGeometry.m_vertexBuffer = rectangleVertexBufferView;
- rectangleBlasGeometry.m_indexBuffer = rectangleIndexBufferView;
- m_rectangleRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &rectangleBlasDescriptor, *m_rayTracingBufferPools);
- }
- m_time += 0.005f;
- // transforms
- AZ::Transform triangleTransform1 = AZ::Transform::CreateIdentity();
- triangleTransform1.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * -100.0f, 1.0f);
- triangleTransform1.MultiplyByUniformScale(100.0f);
- AZ::Transform triangleTransform2 = AZ::Transform::CreateIdentity();
- triangleTransform2.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * 100.0f, 2.0f);
- triangleTransform2.MultiplyByUniformScale(100.0f);
- AZ::Transform triangleTransform3 = AZ::Transform::CreateIdentity();
- triangleTransform3.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * 100.0f, 3.0f);
- triangleTransform3.MultiplyByUniformScale(100.0f);
- AZ::Transform rectangleTransform = AZ::Transform::CreateIdentity();
- rectangleTransform.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * -100.0f, 4.0f);
- rectangleTransform.MultiplyByUniformScale(100.0f);
- // create the TLAS
- auto deviceMask = RHI::MultiDevice::AllDevices;
- AZStd::unordered_map<int, RHI::DeviceRayTracingTlasDescriptor> tlasDescriptor;
- RHI::MultiDeviceObject::IterateDevices(
- deviceMask,
- [&](int deviceIndex)
- {
- {
- auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
- tlasInstance.m_instanceID = 0;
- tlasInstance.m_hitGroupIndex = 0;
- tlasInstance.m_blas = m_triangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
- tlasInstance.m_transform = triangleTransform1;
- }
- {
- auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
- tlasInstance.m_instanceID = 1;
- tlasInstance.m_hitGroupIndex = 1;
- tlasInstance.m_blas = m_triangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
- tlasInstance.m_transform = triangleTransform2;
- }
- {
- auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
- tlasInstance.m_instanceID = 2;
- tlasInstance.m_hitGroupIndex = 2;
- tlasInstance.m_blas = m_triangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
- tlasInstance.m_transform = triangleTransform3;
- }
- {
- auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
- tlasInstance.m_instanceID = 3;
- tlasInstance.m_hitGroupIndex = 3;
- tlasInstance.m_blas = m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
- tlasInstance.m_transform = rectangleTransform;
- }
- return true;
- });
- m_rayTracingTlas->CreateBuffers(deviceMask, tlasDescriptor, *m_rayTracingBufferPools);
- m_tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, (uint32_t)m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
- [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
- AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
- RHI::BufferScopeAttachmentDescriptor desc;
- desc.m_attachmentId = m_tlasBufferAttachmentId;
- desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
- desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
- frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::AnyGraphics);
- };
- RHI::EmptyCompileFunction<ScopeData> compileFunction;
- const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
- {
- RHI::CommandList* commandList = context.GetCommandList();
- commandList->BuildBottomLevelAccelerationStructure(*m_triangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
- commandList->BuildBottomLevelAccelerationStructure(*m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
- commandList->BuildTopLevelAccelerationStructure(
- *m_rayTracingTlas->GetDeviceRayTracingTlas(context.GetDeviceIndex()), { m_triangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get(), m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get() });
- };
- m_scopeProducers.emplace_back(
- aznew RHI::ScopeProducerFunction<
- ScopeData,
- decltype(prepareFunction),
- decltype(compileFunction),
- decltype(executeFunction)>(
- RHI::ScopeId{ "RayTracingBuildAccelerationStructure" },
- ScopeData{},
- prepareFunction,
- compileFunction,
- executeFunction));
- }
- void RayTracingExampleComponent::CreateRayTracingDispatchScope()
- {
- struct ScopeData
- {
- };
- const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
- {
- // attach output image
- {
- [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(m_outputImageAttachmentId, m_outputImage);
- AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import output image with error %d", result);
- RHI::ImageScopeAttachmentDescriptor desc;
- desc.m_attachmentId = m_outputImageAttachmentId;
- desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
- desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
- frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
- }
- // attach TLAS buffer
- if (m_rayTracingTlas->GetTlasBuffer())
- {
- RHI::BufferScopeAttachmentDescriptor desc;
- desc.m_attachmentId = m_tlasBufferAttachmentId;
- desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
- desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
- frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
- }
- frameGraph.SetEstimatedItemCount(1);
- };
- const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
- {
- if (m_rayTracingTlas->GetTlasBuffer())
- {
- // set the TLAS and output image in the ray tracing global Srg
- RHI::ShaderInputBufferIndex tlasConstantIndex;
- FindShaderInputIndex(&tlasConstantIndex, m_globalSrg, AZ::Name{ "m_scene" }, RayTracingExampleName);
- uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
- RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
- m_globalSrg->SetBufferView(tlasConstantIndex, m_rayTracingTlas->GetTlasBuffer()->GetBufferView(bufferViewDescriptor).get());
- RHI::ShaderInputImageIndex outputConstantIndex;
- FindShaderInputIndex(&outputConstantIndex, m_globalSrg, AZ::Name{ "m_output" }, RayTracingExampleName);
- m_globalSrg->SetImageView(outputConstantIndex, m_outputImageView.get());
- // set hit shader data, each array element corresponds to the InstanceIndex() of the geometry in the TLAS
- // Note: this method is used instead of LocalRootSignatures for compatibility with non-DX12 platforms
- // set HitGradient values
- RHI::ShaderInputConstantIndex hitGradientDataConstantIndex;
- FindShaderInputIndex(&hitGradientDataConstantIndex, m_globalSrg, AZ::Name{"m_hitGradientData"}, RayTracingExampleName);
- struct HitGradientData
- {
- AZ::Vector4 m_color;
- };
- AZStd::array<HitGradientData, 4> hitGradientData = {{
- {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f)}, // triangle1
- {AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle2
- {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
- {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
- }};
- m_globalSrg->SetConstantArray(hitGradientDataConstantIndex, hitGradientData);
- // set HitSolid values
- RHI::ShaderInputConstantIndex hitSolidDataConstantIndex;
- FindShaderInputIndex(&hitSolidDataConstantIndex, m_globalSrg, AZ::Name{"m_hitSolidData"}, RayTracingExampleName);
- struct HitSolidData
- {
- AZ::Vector4 m_color1;
- float m_lerp;
- float m_pad[3];
- AZ::Vector4 m_color2;
- };
- AZStd::array<HitSolidData, 4> hitSolidData = {{
- {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
- {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
- {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle3
- {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 1.0f, 1.0f)}, // rectangle
- }};
- m_globalSrg->SetConstantArray(hitSolidDataConstantIndex, hitSolidData);
- m_globalSrg->Compile();
- // update the ray tracing shader table
- AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
- descriptor->m_name = AZ::Name("RayTracingExampleShaderTable");
- descriptor->m_rayTracingPipelineState = m_rayTracingPipelineState;
- descriptor->m_rayGenerationRecord.emplace_back(AZ::Name("RayGenerationShader"));
- descriptor->m_missRecords.emplace_back(AZ::Name("MissShader"));
- descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupGradient")); // triangle1
- descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupGradient")); // triangle2
- descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupSolid")); // triangle3
- descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupSolid")); // rectangle
- m_rayTracingShaderTable->Build(descriptor);
- }
- };
- const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
- {
- if (!m_rayTracingTlas->GetTlasBuffer())
- {
- return;
- }
- RHI::CommandList* commandList = context.GetCommandList();
- const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
- m_globalSrg->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
- };
- RHI::DeviceDispatchRaysItem dispatchRaysItem;
- dispatchRaysItem.m_arguments.m_direct.m_width = m_imageWidth;
- dispatchRaysItem.m_arguments.m_direct.m_height = m_imageHeight;
- dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
- dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
- dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
- dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
- dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
- dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
- // submit the DispatchRays item
- commandList->Submit(dispatchRaysItem);
- };
- m_scopeProducers.emplace_back(
- aznew RHI::ScopeProducerFunction<
- ScopeData,
- decltype(prepareFunction),
- decltype(compileFunction),
- decltype(executeFunction)>(
- RHI::ScopeId{ "RayTracingDispatch" },
- ScopeData{},
- prepareFunction,
- compileFunction,
- executeFunction));
- }
- void RayTracingExampleComponent::CreateRasterScope()
- {
- struct ScopeData
- {
- };
- const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
- {
- // attach swapchain
- {
- RHI::ImageScopeAttachmentDescriptor descriptor;
- descriptor.m_attachmentId = m_outputAttachmentId;
- descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
- frameGraph.UseColorAttachment(descriptor);
- }
- // attach output buffer
- {
- RHI::ImageScopeAttachmentDescriptor desc;
- desc.m_attachmentId = m_outputImageAttachmentId;
- desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
- desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
- frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::FragmentShader);
- const Name outputImageId{ "m_output" };
- RHI::ShaderInputImageIndex outputImageIndex = m_drawSRG->FindShaderInputImageIndex(outputImageId);
- AZ_Error(RayTracingExampleName, outputImageIndex.IsValid(), "Failed to find shader input image %s.", outputImageId.GetCStr());
- m_drawSRG->SetImageView(outputImageIndex, m_outputImageView.get());
- m_drawSRG->Compile();
- }
- frameGraph.SetEstimatedItemCount(1);
- };
- RHI::EmptyCompileFunction<ScopeData> compileFunction;
- const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
- {
- RHI::CommandList* commandList = context.GetCommandList();
- commandList->SetViewports(&m_viewport, 1);
- commandList->SetScissors(&m_scissor, 1);
- const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
- m_drawSRG->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
- };
- RHI::DeviceDrawItem drawItem;
- drawItem.m_geometryView = m_geometryView.GetDeviceGeometryView(context.GetDeviceIndex());
- drawItem.m_streamIndices = m_geometryView.GetFullStreamBufferIndices();
- drawItem.m_pipelineState = m_drawPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
- drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
- drawItem.m_shaderResourceGroups = shaderResourceGroups;
- // submit the triangle draw item.
- commandList->Submit(drawItem);
- };
- m_scopeProducers.emplace_back(
- aznew RHI::ScopeProducerFunction<
- ScopeData,
- decltype(prepareFunction),
- decltype(compileFunction),
- decltype(executeFunction)>(
- RHI::ScopeId{ "Raster" },
- ScopeData{},
- prepareFunction,
- compileFunction,
- executeFunction));
- }
- } // namespace AtomSampleViewer
|