3
0

RayTracingFeatureProcessor.cpp 52 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RayTracing/RayTracingFeatureProcessor.h>
  9. #include <RayTracing/RayTracingPass.h>
  10. #include <Atom/Feature/TransformService/TransformServiceFeatureProcessor.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/MultiDeviceRayTracingAccelerationStructure.h>
  13. #include <Atom/RHI/RHISystemInterface.h>
  14. #include <Atom/RPI.Public/Scene.h>
  15. #include <Atom/RPI.Public/Pass/PassFilter.h>
  16. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  17. #include <Atom/RPI.Reflect/Asset/AssetUtils.h>
  18. #include <Atom/Feature/ImageBasedLights/ImageBasedLightFeatureProcessor.h>
  19. #include <CoreLights/DirectionalLightFeatureProcessor.h>
  20. #include <CoreLights/SimplePointLightFeatureProcessor.h>
  21. #include <CoreLights/SimpleSpotLightFeatureProcessor.h>
  22. #include <CoreLights/PointLightFeatureProcessor.h>
  23. #include <CoreLights/DiskLightFeatureProcessor.h>
  24. #include <CoreLights/CapsuleLightFeatureProcessor.h>
  25. #include <CoreLights/QuadLightFeatureProcessor.h>
  26. namespace AZ
  27. {
  28. namespace Render
  29. {
  30. void RayTracingFeatureProcessor::Reflect(ReflectContext* context)
  31. {
  32. if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
  33. {
  34. serializeContext
  35. ->Class<RayTracingFeatureProcessor, FeatureProcessor>()
  36. ->Version(1);
  37. }
  38. }
  39. void RayTracingFeatureProcessor::Activate()
  40. {
  41. auto deviceMask{RHI::RHISystemInterface::Get()->GetRayTracingSupport()};
  42. m_rayTracingEnabled = (deviceMask != RHI::MultiDevice::NoDevices);
  43. if (!m_rayTracingEnabled)
  44. {
  45. return;
  46. }
  47. m_transformServiceFeatureProcessor = GetParentScene()->GetFeatureProcessor<TransformServiceFeatureProcessor>();
  48. // initialize the ray tracing buffer pools
  49. m_bufferPools = aznew RHI::MultiDeviceRayTracingBufferPools;
  50. m_bufferPools->Init(deviceMask);
  51. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  52. for (auto deviceIndex{0}; deviceIndex < deviceCount; ++deviceIndex)
  53. {
  54. if ((AZStd::to_underlying(deviceMask) >> deviceIndex) & 1)
  55. {
  56. m_meshBufferIndices[deviceIndex] = {};
  57. m_materialTextureIndices[deviceIndex] = {};
  58. m_materialInfos[deviceIndex] = {};
  59. }
  60. }
  61. // create TLAS attachmentId
  62. AZStd::string uuidString = AZ::Uuid::CreateRandom().ToString<AZStd::string>();
  63. m_tlasAttachmentId = RHI::AttachmentId(AZStd::string::format("RayTracingTlasAttachmentId_%s", uuidString.c_str()));
  64. // create the TLAS object
  65. m_tlas = aznew RHI::MultiDeviceRayTracingTlas;
  66. // load the RayTracingSrg asset asset
  67. m_rayTracingSrgAsset = RPI::AssetUtils::LoadCriticalAsset<RPI::ShaderAsset>("shaderlib/atom/features/rayTracing/raytracingsrgs.azshader");
  68. if (!m_rayTracingSrgAsset.IsReady())
  69. {
  70. AZ_Assert(false, "Failed to load RayTracingSrg asset");
  71. return;
  72. }
  73. // create the RayTracingSceneSrg
  74. m_rayTracingSceneSrg = RPI::ShaderResourceGroup::Create(m_rayTracingSrgAsset, Name("RayTracingSceneSrg"));
  75. AZ_Assert(m_rayTracingSceneSrg, "Failed to create RayTracingSceneSrg");
  76. // create the RayTracingMaterialSrg
  77. const AZ::Name rayTracingMaterialSrgName("RayTracingMaterialSrg");
  78. m_rayTracingMaterialSrg = RPI::ShaderResourceGroup::Create(m_rayTracingSrgAsset, Name("RayTracingMaterialSrg"));
  79. AZ_Assert(m_rayTracingMaterialSrg, "Failed to create RayTracingMaterialSrg");
  80. EnableSceneNotification();
  81. }
  82. void RayTracingFeatureProcessor::Deactivate()
  83. {
  84. DisableSceneNotification();
  85. }
  86. RayTracingFeatureProcessor::ProceduralGeometryTypeHandle RayTracingFeatureProcessor::RegisterProceduralGeometryType(
  87. const AZStd::string& name,
  88. const Data::Instance<RPI::Shader>& intersectionShader,
  89. const AZStd::string& intersectionShaderName,
  90. uint32_t bindlessBufferIndex)
  91. {
  92. ProceduralGeometryTypeHandle geometryTypeHandle;
  93. {
  94. ProceduralGeometryType proceduralGeometryType;
  95. proceduralGeometryType.m_name = AZ::Name(name);
  96. proceduralGeometryType.m_intersectionShader = intersectionShader;
  97. proceduralGeometryType.m_intersectionShaderName = AZ::Name(intersectionShaderName);
  98. proceduralGeometryType.m_bindlessBufferIndex = bindlessBufferIndex;
  99. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  100. geometryTypeHandle = m_proceduralGeometryTypes.insert(proceduralGeometryType);
  101. }
  102. m_proceduralGeometryTypeRevision++;
  103. return geometryTypeHandle;
  104. }
  105. void RayTracingFeatureProcessor::SetProceduralGeometryTypeBindlessBufferIndex(
  106. ProceduralGeometryTypeWeakHandle geometryTypeHandle, uint32_t bindlessBufferIndex)
  107. {
  108. if (!m_rayTracingEnabled)
  109. {
  110. return;
  111. }
  112. geometryTypeHandle->m_bindlessBufferIndex = bindlessBufferIndex;
  113. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  114. }
  115. void RayTracingFeatureProcessor::AddProceduralGeometry(
  116. ProceduralGeometryTypeWeakHandle geometryTypeHandle,
  117. const Uuid& uuid,
  118. const Aabb& aabb,
  119. const SubMeshMaterial& material,
  120. RHI::RayTracingAccelerationStructureInstanceInclusionMask instanceMask,
  121. uint32_t localInstanceIndex)
  122. {
  123. if (!m_rayTracingEnabled)
  124. {
  125. return;
  126. }
  127. RHI::Ptr<AZ::RHI::MultiDeviceRayTracingBlas> rayTracingBlas = aznew AZ::RHI::MultiDeviceRayTracingBlas;
  128. RHI::MultiDeviceRayTracingBlasDescriptor blasDescriptor;
  129. blasDescriptor.Build()
  130. ->AABB(aabb)
  131. ;
  132. rayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &blasDescriptor, *m_bufferPools);
  133. ProceduralGeometry proceduralGeometry;
  134. proceduralGeometry.m_uuid = uuid;
  135. proceduralGeometry.m_typeHandle = geometryTypeHandle;
  136. proceduralGeometry.m_aabb = aabb;
  137. proceduralGeometry.m_instanceMask = static_cast<uint32_t>(instanceMask);
  138. proceduralGeometry.m_blas = rayTracingBlas;
  139. proceduralGeometry.m_localInstanceIndex = localInstanceIndex;
  140. MeshBlasInstance meshBlasInstance;
  141. meshBlasInstance.m_count = 1;
  142. meshBlasInstance.m_subMeshes.push_back(SubMeshBlasInstance{ rayTracingBlas });
  143. MaterialInfo materialInfo;
  144. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  145. m_proceduralGeometryLookup.emplace(uuid, m_proceduralGeometry.size());
  146. m_proceduralGeometry.push_back(proceduralGeometry);
  147. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  148. for (auto deviceIndex{0}; deviceIndex < deviceCount; ++deviceIndex)
  149. {
  150. m_proceduralGeometryMaterialInfos[deviceIndex].emplace_back();
  151. ConvertMaterial(m_proceduralGeometryMaterialInfos[deviceIndex].back(), material, deviceIndex);
  152. }
  153. m_blasInstanceMap.emplace(Data::AssetId(uuid), meshBlasInstance);
  154. geometryTypeHandle->m_instanceCount++;
  155. m_revision++;
  156. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  157. m_materialInfoBufferNeedsUpdate = true;
  158. m_indexListNeedsUpdate = true;
  159. }
  160. void RayTracingFeatureProcessor::SetProceduralGeometryTransform(
  161. const Uuid& uuid, const Transform& transform, const Vector3& nonUniformScale)
  162. {
  163. if (!m_rayTracingEnabled)
  164. {
  165. return;
  166. }
  167. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  168. if (auto it = m_proceduralGeometryLookup.find(uuid); it != m_proceduralGeometryLookup.end())
  169. {
  170. m_proceduralGeometry[it->second].m_transform = transform;
  171. m_proceduralGeometry[it->second].m_nonUniformScale = nonUniformScale;
  172. }
  173. m_revision++;
  174. }
  175. void RayTracingFeatureProcessor::SetProceduralGeometryLocalInstanceIndex(const Uuid& uuid, uint32_t localInstanceIndex)
  176. {
  177. if (!m_rayTracingEnabled)
  178. {
  179. return;
  180. }
  181. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  182. if (auto it = m_proceduralGeometryLookup.find(uuid); it != m_proceduralGeometryLookup.end())
  183. {
  184. m_proceduralGeometry[it->second].m_localInstanceIndex = localInstanceIndex;
  185. }
  186. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  187. }
  188. void RayTracingFeatureProcessor::RemoveProceduralGeometry(const Uuid& uuid)
  189. {
  190. if (!m_rayTracingEnabled)
  191. {
  192. return;
  193. }
  194. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  195. size_t materialInfoIndex = m_proceduralGeometryLookup[uuid];
  196. m_proceduralGeometry[materialInfoIndex].m_typeHandle->m_instanceCount--;
  197. if (materialInfoIndex < m_proceduralGeometry.size() - 1)
  198. {
  199. m_proceduralGeometryLookup[m_proceduralGeometry.back().m_uuid] = m_proceduralGeometryLookup[uuid];
  200. m_proceduralGeometry[materialInfoIndex] = m_proceduralGeometry.back();
  201. for (auto& [deviceIndex, materialInfos] : m_proceduralGeometryMaterialInfos)
  202. {
  203. materialInfos[materialInfoIndex] = materialInfos.back();
  204. }
  205. }
  206. m_proceduralGeometry.pop_back();
  207. for (auto& [deviceIndex, materialInfos] : m_proceduralGeometryMaterialInfos)
  208. {
  209. materialInfos.pop_back();
  210. }
  211. m_blasInstanceMap.erase(uuid);
  212. m_proceduralGeometryLookup.erase(uuid);
  213. m_revision++;
  214. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  215. m_materialInfoBufferNeedsUpdate = true;
  216. m_indexListNeedsUpdate = true;
  217. }
  218. int RayTracingFeatureProcessor::GetProceduralGeometryCount(ProceduralGeometryTypeWeakHandle geometryTypeHandle) const
  219. {
  220. return geometryTypeHandle->m_instanceCount;
  221. }
  222. void RayTracingFeatureProcessor::AddMesh(const AZ::Uuid& uuid, const Mesh& rayTracingMesh, const SubMeshVector& subMeshes)
  223. {
  224. if (!m_rayTracingEnabled)
  225. {
  226. return;
  227. }
  228. // lock the mutex to protect the mesh and BLAS lists
  229. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  230. // check to see if we already have this mesh
  231. MeshMap::iterator itMesh = m_meshes.find(uuid);
  232. if (itMesh != m_meshes.end())
  233. {
  234. AZ_Assert(false, "AddMesh called on an existing Mesh, call RemoveMesh first");
  235. return;
  236. }
  237. // add the mesh
  238. m_meshes.insert(AZStd::make_pair(uuid, rayTracingMesh));
  239. Mesh& mesh = m_meshes[uuid];
  240. // add the subMeshes to the end of the global subMesh vector
  241. // Note 1: the MeshInfo and MaterialInfo vectors are parallel with the subMesh vector
  242. // Note 2: the list of indices for the subMeshes in the global vector are stored in the parent Mesh
  243. IndexVector subMeshIndices;
  244. uint32_t subMeshGlobalIndex = aznumeric_cast<uint32_t>(m_subMeshes.size());
  245. for (uint32_t subMeshIndex = 0; subMeshIndex < subMeshes.size(); ++subMeshIndex, ++subMeshGlobalIndex)
  246. {
  247. SubMesh& subMesh = m_subMeshes.emplace_back(subMeshes[subMeshIndex]);
  248. subMesh.m_mesh = &mesh;
  249. subMesh.m_subMeshIndex = subMeshIndex;
  250. subMesh.m_globalIndex = subMeshGlobalIndex;
  251. // add to the list of global subMeshIndices, which will be stored in the Mesh
  252. subMeshIndices.push_back(subMeshGlobalIndex);
  253. // add MeshInfo and MaterialInfo entries
  254. m_meshInfos.emplace_back();
  255. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  256. {
  257. materialInfos.emplace_back();
  258. }
  259. }
  260. mesh.m_subMeshIndices = subMeshIndices;
  261. // search for an existing BLAS instance entry for this mesh using the assetId
  262. BlasInstanceMap::iterator itMeshBlasInstance = m_blasInstanceMap.find(mesh.m_assetId);
  263. if (itMeshBlasInstance == m_blasInstanceMap.end())
  264. {
  265. // make a new BLAS map entry for this mesh
  266. MeshBlasInstance meshBlasInstance;
  267. meshBlasInstance.m_count = 1;
  268. meshBlasInstance.m_subMeshes.reserve(mesh.m_subMeshIndices.size());
  269. meshBlasInstance.m_isSkinnedMesh = mesh.m_isSkinnedMesh;
  270. itMeshBlasInstance = m_blasInstanceMap.insert({ mesh.m_assetId, meshBlasInstance }).first;
  271. if (mesh.m_isSkinnedMesh)
  272. {
  273. ++m_skinnedMeshCount;
  274. }
  275. }
  276. else
  277. {
  278. itMeshBlasInstance->second.m_count++;
  279. }
  280. // create the BLAS buffers for each sub-mesh, or re-use existing BLAS objects if they were already created.
  281. // Note: all sub-meshes must either create new BLAS objects or re-use existing ones, otherwise it's an error (it's the same model in both cases)
  282. // Note: the buffer is just reserved here, the BLAS is built in the RayTracingAccelerationStructurePass
  283. // Note: the build flags are set to be the same for each BLAS created for the mesh
  284. RHI::RayTracingAccelerationStructureBuildFlags buildFlags = CreateRayTracingAccelerationStructureBuildFlags(mesh.m_isSkinnedMesh);
  285. [[maybe_unused]] bool blasInstanceFound = false;
  286. for (uint32_t subMeshIndex = 0; subMeshIndex < mesh.m_subMeshIndices.size(); ++subMeshIndex)
  287. {
  288. SubMesh& subMesh = m_subMeshes[mesh.m_subMeshIndices[subMeshIndex]];
  289. RHI::MultiDeviceRayTracingBlasDescriptor blasDescriptor;
  290. blasDescriptor.Build()
  291. ->Geometry()
  292. ->VertexFormat(subMesh.m_positionFormat)
  293. ->VertexBuffer(subMesh.m_positionVertexBufferView)
  294. ->IndexBuffer(subMesh.m_indexBufferView)
  295. ->BuildFlags(buildFlags)
  296. ;
  297. // determine if we have an existing BLAS object for this subMesh
  298. if (itMeshBlasInstance->second.m_subMeshes.size() >= subMeshIndex + 1)
  299. {
  300. // re-use existing BLAS
  301. subMesh.m_blas = itMeshBlasInstance->second.m_subMeshes[subMeshIndex].m_blas;
  302. // keep track of the fact that we re-used a BLAS
  303. blasInstanceFound = true;
  304. }
  305. else
  306. {
  307. AZ_Assert(blasInstanceFound == false, "Partial set of RayTracingBlas objects found for mesh");
  308. // create the BLAS object and store it in the BLAS list
  309. RHI::Ptr<RHI::MultiDeviceRayTracingBlas> rayTracingBlas = aznew RHI::MultiDeviceRayTracingBlas;
  310. itMeshBlasInstance->second.m_subMeshes.push_back({ rayTracingBlas });
  311. // create the buffers from the BLAS descriptor
  312. rayTracingBlas->CreateBuffers(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), &blasDescriptor, *m_bufferPools);
  313. // store the BLAS in the mesh
  314. subMesh.m_blas = rayTracingBlas;
  315. }
  316. }
  317. AZ::Transform noScaleTransform = mesh.m_transform;
  318. noScaleTransform.ExtractUniformScale();
  319. AZ::Matrix3x3 rotationMatrix = Matrix3x3::CreateFromTransform(noScaleTransform);
  320. rotationMatrix = rotationMatrix.GetInverseFull().GetTranspose();
  321. Matrix3x4 worldInvTranspose3x4 = Matrix3x4::CreateFromMatrix3x3(rotationMatrix);
  322. Matrix3x4 reflectionProbeModelToWorld3x4 = Matrix3x4::CreateFromTransform(mesh.m_reflectionProbe.m_modelToWorld);
  323. // store the mesh buffers and material textures in the resource lists
  324. for (uint32_t subMeshIndex : mesh.m_subMeshIndices)
  325. {
  326. SubMesh& subMesh = m_subMeshes[subMeshIndex];
  327. MeshInfo& meshInfo = m_meshInfos[subMesh.m_globalIndex];
  328. worldInvTranspose3x4.StoreToRowMajorFloat12(meshInfo.m_worldInvTranspose.data());
  329. meshInfo.m_bufferFlags = subMesh.m_bufferFlags;
  330. AZ_Assert(subMesh.m_indexShaderBufferView.get(), "RayTracing Mesh IndexBuffer cannot be null");
  331. AZ_Assert(subMesh.m_positionShaderBufferView.get(), "RayTracing Mesh PositionBuffer cannot be null");
  332. AZ_Assert(subMesh.m_normalShaderBufferView.get(), "RayTracing Mesh NormalBuffer cannot be null");
  333. meshInfo.m_indexByteOffset = subMesh.m_indexBufferView.GetByteOffset();
  334. meshInfo.m_positionByteOffset = subMesh.m_positionVertexBufferView.GetByteOffset();
  335. meshInfo.m_normalByteOffset = subMesh.m_normalVertexBufferView.GetByteOffset();
  336. meshInfo.m_tangentByteOffset = subMesh.m_tangentShaderBufferView ? subMesh.m_tangentVertexBufferView.GetByteOffset() : 0;
  337. meshInfo.m_bitangentByteOffset = subMesh.m_bitangentShaderBufferView ? subMesh.m_bitangentVertexBufferView.GetByteOffset() : 0;
  338. meshInfo.m_uvByteOffset = subMesh.m_uvShaderBufferView ? subMesh.m_uvVertexBufferView.GetByteOffset() : 0;
  339. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  340. {
  341. MaterialInfo& materialInfo = materialInfos[subMesh.m_globalIndex];
  342. ConvertMaterial(materialInfo, subMesh.m_material, deviceIndex);
  343. auto& meshBufferIndices = m_meshBufferIndices[deviceIndex];
  344. // add mesh buffers
  345. meshInfo.m_bufferStartIndex = meshBufferIndices.AddEntry(
  346. {
  347. #if USE_BINDLESS_SRG
  348. subMesh.m_indexShaderBufferView.get() ? subMesh.m_indexShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  349. subMesh.m_positionShaderBufferView.get() ? subMesh.m_positionShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  350. subMesh.m_normalShaderBufferView.get() ? subMesh.m_normalShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  351. subMesh.m_tangentShaderBufferView.get() ? subMesh.m_tangentShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  352. subMesh.m_bitangentShaderBufferView.get() ? subMesh.m_bitangentShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  353. subMesh.m_uvShaderBufferView.get() ? subMesh.m_uvShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex
  354. #else
  355. m_meshBuffers.AddResource(subMesh.m_indexShaderBufferView.get()),
  356. m_meshBuffers.AddResource(subMesh.m_positionShaderBufferView.get()),
  357. m_meshBuffers.AddResource(subMesh.m_normalShaderBufferView.get()),
  358. m_meshBuffers.AddResource(subMesh.m_tangentShaderBufferView.get()),
  359. m_meshBuffers.AddResource(subMesh.m_bitangentShaderBufferView.get()),
  360. m_meshBuffers.AddResource(subMesh.m_uvShaderBufferView.get())
  361. #endif
  362. });
  363. // add reflection probe data
  364. if (mesh.m_reflectionProbe.m_reflectionProbeCubeMap.get())
  365. {
  366. materialInfo.m_reflectionProbeCubeMapIndex = mesh.m_reflectionProbe.m_reflectionProbeCubeMap->GetImageView()->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex();
  367. if (materialInfo.m_reflectionProbeCubeMapIndex != InvalidIndex)
  368. {
  369. reflectionProbeModelToWorld3x4.StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorld.data());
  370. reflectionProbeModelToWorld3x4.GetInverseFull().StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorldInverse.data());
  371. mesh.m_reflectionProbe.m_outerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_outerObbHalfLengths.data());
  372. mesh.m_reflectionProbe.m_innerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_innerObbHalfLengths.data());
  373. materialInfo.m_reflectionProbeData.m_useReflectionProbe = true;
  374. materialInfo.m_reflectionProbeData.m_useParallaxCorrection = mesh.m_reflectionProbe.m_useParallaxCorrection;
  375. materialInfo.m_reflectionProbeData.m_exposure = mesh.m_reflectionProbe.m_exposure;
  376. }
  377. }
  378. }
  379. }
  380. m_revision++;
  381. m_subMeshCount += aznumeric_cast<uint32_t>(subMeshes.size());
  382. m_meshInfoBufferNeedsUpdate = true;
  383. m_materialInfoBufferNeedsUpdate = true;
  384. m_indexListNeedsUpdate = true;
  385. }
  386. void RayTracingFeatureProcessor::RemoveMesh(const AZ::Uuid& uuid)
  387. {
  388. if (!m_rayTracingEnabled)
  389. {
  390. return;
  391. }
  392. // lock the mutex to protect the mesh and BLAS lists
  393. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  394. MeshMap::iterator itMesh = m_meshes.find(uuid);
  395. if (itMesh != m_meshes.end())
  396. {
  397. Mesh& mesh = itMesh->second;
  398. // decrement the count from the BLAS instances, and check to see if we can remove them
  399. BlasInstanceMap::iterator itBlas = m_blasInstanceMap.find(mesh.m_assetId);
  400. if (itBlas != m_blasInstanceMap.end())
  401. {
  402. itBlas->second.m_count--;
  403. if (itBlas->second.m_count == 0)
  404. {
  405. if (itBlas->second.m_isSkinnedMesh)
  406. {
  407. --m_skinnedMeshCount;
  408. }
  409. m_blasInstanceMap.erase(itBlas);
  410. }
  411. }
  412. // remove the SubMeshes
  413. for (auto& subMeshIndex : mesh.m_subMeshIndices)
  414. {
  415. SubMesh& subMesh = m_subMeshes[subMeshIndex];
  416. uint32_t globalIndex = subMesh.m_globalIndex;
  417. MeshInfo& meshInfo = m_meshInfos[globalIndex];
  418. for (auto& [deviceIndex, meshBufferIndices] : m_meshBufferIndices)
  419. {
  420. meshBufferIndices.RemoveEntry(meshInfo.m_bufferStartIndex);
  421. }
  422. for (auto& [deviceIndex, materialTextureIndices] : m_materialTextureIndices)
  423. {
  424. MaterialInfo& materialInfo = m_materialInfos[deviceIndex][globalIndex];
  425. materialTextureIndices.RemoveEntry(materialInfo.m_textureStartIndex);
  426. }
  427. #if !USE_BINDLESS_SRG
  428. m_meshBuffers.RemoveResource(subMesh.m_indexShaderBufferView.get());
  429. m_meshBuffers.RemoveResource(subMesh.m_positionShaderBufferView.get());
  430. m_meshBuffers.RemoveResource(subMesh.m_normalShaderBufferView.get());
  431. m_meshBuffers.RemoveResource(subMesh.m_tangentShaderBufferView.get());
  432. m_meshBuffers.RemoveResource(subMesh.m_bitangentShaderBufferView.get());
  433. m_meshBuffers.RemoveResource(subMesh.m_uvShaderBufferView.get());
  434. m_materialTextures.RemoveResource(subMesh.m_baseColorImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
  435. m_materialTextures.RemoveResource(subMesh.m_normalImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
  436. m_materialTextures.RemoveResource(subMesh.m_metallicImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
  437. m_materialTextures.RemoveResource(subMesh.m_roughnessImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
  438. m_materialTextures.RemoveResource(subMesh.m_emissiveImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
  439. #endif
  440. if (globalIndex < m_subMeshes.size() - 1)
  441. {
  442. // the subMesh we're removing is in the middle of the global lists, remove by swapping the last element to its position in the list
  443. m_subMeshes[globalIndex] = m_subMeshes.back();
  444. m_meshInfos[globalIndex] = m_meshInfos.back();
  445. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  446. {
  447. materialInfos[globalIndex] = materialInfos.back();
  448. }
  449. // update the global index for the swapped subMesh
  450. m_subMeshes[globalIndex].m_globalIndex = globalIndex;
  451. // update the global index in the parent Mesh' subMesh list
  452. Mesh* swappedSubMeshParent = m_subMeshes[globalIndex].m_mesh;
  453. uint32_t swappedSubMeshIndex = m_subMeshes[globalIndex].m_subMeshIndex;
  454. swappedSubMeshParent->m_subMeshIndices[swappedSubMeshIndex] = globalIndex;
  455. }
  456. m_subMeshes.pop_back();
  457. m_meshInfos.pop_back();
  458. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  459. {
  460. materialInfos.pop_back();
  461. }
  462. }
  463. // remove from the Mesh list
  464. m_subMeshCount -= aznumeric_cast<uint32_t>(mesh.m_subMeshIndices.size());
  465. m_meshes.erase(itMesh);
  466. m_revision++;
  467. // reset all data structures if all meshes were removed (i.e., empty scene)
  468. if (m_subMeshCount == 0)
  469. {
  470. m_meshes.clear();
  471. m_subMeshes.clear();
  472. m_meshInfos.clear();
  473. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  474. {
  475. materialInfos.clear();
  476. }
  477. for (auto& [deviceIndex, meshBufferIndices] : m_meshBufferIndices)
  478. {
  479. meshBufferIndices.Reset();
  480. }
  481. for (auto& [deviceIndex, materialTextureIndices] : m_materialTextureIndices)
  482. {
  483. materialTextureIndices.Reset();
  484. }
  485. #if !USE_BINDLESS_SRG
  486. m_meshBuffers.Reset();
  487. m_materialTextures.Reset();
  488. #endif
  489. }
  490. }
  491. m_meshInfoBufferNeedsUpdate = true;
  492. m_materialInfoBufferNeedsUpdate = true;
  493. m_indexListNeedsUpdate = true;
  494. }
  495. void RayTracingFeatureProcessor::SetMeshTransform(const AZ::Uuid& uuid, const AZ::Transform transform, const AZ::Vector3 nonUniformScale)
  496. {
  497. if (!m_rayTracingEnabled)
  498. {
  499. return;
  500. }
  501. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  502. MeshMap::iterator itMesh = m_meshes.find(uuid);
  503. if (itMesh != m_meshes.end())
  504. {
  505. Mesh& mesh = itMesh->second;
  506. mesh.m_transform = transform;
  507. mesh.m_nonUniformScale = nonUniformScale;
  508. m_revision++;
  509. // create a world inverse transpose 3x4 matrix
  510. AZ::Transform noScaleTransform = mesh.m_transform;
  511. noScaleTransform.ExtractUniformScale();
  512. AZ::Matrix3x3 rotationMatrix = Matrix3x3::CreateFromTransform(noScaleTransform);
  513. rotationMatrix = rotationMatrix.GetInverseFull().GetTranspose();
  514. Matrix3x4 worldInvTranspose3x4 = Matrix3x4::CreateFromMatrix3x3(rotationMatrix);
  515. // update all MeshInfos for this Mesh with the new transform
  516. for (const auto& subMeshIndex : mesh.m_subMeshIndices)
  517. {
  518. MeshInfo& meshInfo = m_meshInfos[subMeshIndex];
  519. worldInvTranspose3x4.StoreToRowMajorFloat12(meshInfo.m_worldInvTranspose.data());
  520. }
  521. m_meshInfoBufferNeedsUpdate = true;
  522. }
  523. }
  524. void RayTracingFeatureProcessor::SetMeshReflectionProbe(const AZ::Uuid& uuid, const Mesh::ReflectionProbe& reflectionProbe)
  525. {
  526. if (!m_rayTracingEnabled)
  527. {
  528. return;
  529. }
  530. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  531. MeshMap::iterator itMesh = m_meshes.find(uuid);
  532. if (itMesh != m_meshes.end())
  533. {
  534. Mesh& mesh = itMesh->second;
  535. // update the Mesh reflection probe data
  536. mesh.m_reflectionProbe = reflectionProbe;
  537. // update all of the subMeshes
  538. const Data::Instance<RPI::Image>& reflectionProbeCubeMap = reflectionProbe.m_reflectionProbeCubeMap;
  539. uint32_t reflectionProbeCubeMapIndex = reflectionProbeCubeMap.get() ? reflectionProbeCubeMap->GetImageView()->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex)->GetBindlessReadIndex() : InvalidIndex;
  540. Matrix3x4 reflectionProbeModelToWorld3x4 = Matrix3x4::CreateFromTransform(mesh.m_reflectionProbe.m_modelToWorld);
  541. for (auto& subMeshIndex : mesh.m_subMeshIndices)
  542. {
  543. SubMesh& subMesh = m_subMeshes[subMeshIndex];
  544. uint32_t globalIndex = subMesh.m_globalIndex;
  545. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  546. {
  547. MaterialInfo& materialInfo = materialInfos[globalIndex];
  548. materialInfo.m_reflectionProbeCubeMapIndex = reflectionProbeCubeMapIndex;
  549. if (materialInfo.m_reflectionProbeCubeMapIndex != InvalidIndex)
  550. {
  551. reflectionProbeModelToWorld3x4.StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorld.data());
  552. reflectionProbeModelToWorld3x4.GetInverseFull().StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorldInverse.data());
  553. mesh.m_reflectionProbe.m_outerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_outerObbHalfLengths.data());
  554. mesh.m_reflectionProbe.m_innerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_innerObbHalfLengths.data());
  555. materialInfo.m_reflectionProbeData.m_useReflectionProbe = true;
  556. materialInfo.m_reflectionProbeData.m_useParallaxCorrection = mesh.m_reflectionProbe.m_useParallaxCorrection;
  557. materialInfo.m_reflectionProbeData.m_exposure = mesh.m_reflectionProbe.m_exposure;
  558. }
  559. else
  560. {
  561. materialInfo.m_reflectionProbeData.m_useReflectionProbe = false;
  562. }
  563. }
  564. }
  565. m_materialInfoBufferNeedsUpdate = true;
  566. }
  567. }
  568. void RayTracingFeatureProcessor::SetMeshMaterials(const AZ::Uuid& uuid, const SubMeshMaterialVector& subMeshMaterials)
  569. {
  570. if (!m_rayTracingEnabled)
  571. {
  572. return;
  573. }
  574. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  575. MeshMap::iterator itMesh = m_meshes.find(uuid);
  576. if (itMesh != m_meshes.end())
  577. {
  578. Mesh& mesh = itMesh->second;
  579. AZ_Assert(
  580. subMeshMaterials.size() == mesh.m_subMeshIndices.size(),
  581. "The size of subMeshes in SetMeshMaterial must be the same as in AddMesh");
  582. for (auto& subMeshIndex : mesh.m_subMeshIndices)
  583. {
  584. const SubMesh& subMesh = m_subMeshes[subMeshIndex];
  585. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  586. {
  587. ConvertMaterial(materialInfos[subMesh.m_globalIndex], subMeshMaterials[subMesh.m_subMeshIndex], deviceIndex);
  588. }
  589. }
  590. m_materialInfoBufferNeedsUpdate = true;
  591. m_indexListNeedsUpdate = true;
  592. }
  593. }
  594. void RayTracingFeatureProcessor::UpdateRayTracingSrgs()
  595. {
  596. AZ_PROFILE_SCOPE(AzRender, "RayTracingFeatureProcessor::UpdateRayTracingSrgs");
  597. if (!m_tlas->GetTlasBuffer())
  598. {
  599. return;
  600. }
  601. if (m_rayTracingSceneSrg->IsQueuedForCompile() || m_rayTracingMaterialSrg->IsQueuedForCompile())
  602. {
  603. //[GFX TODO][ATOM-14792] AtomSampleViewer: Reset scene and feature processors before switching to sample
  604. return;
  605. }
  606. // lock the mutex to protect the mesh and BLAS lists
  607. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  608. if (HasMeshGeometry())
  609. {
  610. UpdateMeshInfoBuffer();
  611. }
  612. if (HasProceduralGeometry())
  613. {
  614. UpdateProceduralGeometryInfoBuffer();
  615. }
  616. if (HasGeometry())
  617. {
  618. UpdateMaterialInfoBuffer();
  619. UpdateIndexLists();
  620. }
  621. UpdateRayTracingSceneSrg();
  622. UpdateRayTracingMaterialSrg();
  623. }
  624. void RayTracingFeatureProcessor::UpdateMeshInfoBuffer()
  625. {
  626. if (m_meshInfoBufferNeedsUpdate)
  627. {
  628. m_meshInfoGpuBuffer.AdvanceCurrentBufferAndUpdateData(m_meshInfos);
  629. m_meshInfoBufferNeedsUpdate = false;
  630. }
  631. }
  632. void RayTracingFeatureProcessor::UpdateProceduralGeometryInfoBuffer()
  633. {
  634. if (!m_proceduralGeometryInfoBufferNeedsUpdate)
  635. {
  636. return;
  637. }
  638. AZStd::vector<uint32_t> proceduralGeometryInfo;
  639. proceduralGeometryInfo.reserve(m_proceduralGeometry.size() * 2);
  640. for (const auto& proceduralGeometry : m_proceduralGeometry)
  641. {
  642. proceduralGeometryInfo.push_back(proceduralGeometry.m_typeHandle->m_bindlessBufferIndex);
  643. proceduralGeometryInfo.push_back(proceduralGeometry.m_localInstanceIndex);
  644. }
  645. m_proceduralGeometryInfoGpuBuffer.AdvanceCurrentBufferAndUpdateData(proceduralGeometryInfo);
  646. m_proceduralGeometryInfoBufferNeedsUpdate = false;
  647. }
  648. void RayTracingFeatureProcessor::UpdateMaterialInfoBuffer()
  649. {
  650. if (m_materialInfoBufferNeedsUpdate)
  651. {
  652. m_materialInfoGpuBuffer.AdvanceCurrentElement();
  653. m_materialInfoGpuBuffer.CreateOrResizeCurrentBufferWithElementCount<MaterialInfo>(
  654. m_subMeshCount + m_proceduralGeometryMaterialInfos.begin()->second.size());
  655. m_materialInfoGpuBuffer.UpdateCurrentBufferData(m_materialInfos);
  656. m_materialInfoGpuBuffer.UpdateCurrentBufferData(m_proceduralGeometryMaterialInfos, m_subMeshCount);
  657. m_materialInfoBufferNeedsUpdate = false;
  658. }
  659. }
  660. void RayTracingFeatureProcessor::UpdateIndexLists()
  661. {
  662. if (m_indexListNeedsUpdate)
  663. {
  664. #if !USE_BINDLESS_SRG
  665. // resolve to the true indices using the indirection list
  666. // Note: this is done on the CPU to avoid double-indirection in the shader
  667. IndexVector resolvedMeshBufferIndices(m_meshBufferIndices.GetIndexList().size());
  668. uint32_t resolvedMeshBufferIndex = 0;
  669. for (auto& meshBufferIndex : m_meshBufferIndices.GetIndexList())
  670. {
  671. if (!m_meshBufferIndices.IsValidIndex(meshBufferIndex))
  672. {
  673. resolvedMeshBufferIndices[resolvedMeshBufferIndex++] = InvalidIndex;
  674. }
  675. else
  676. {
  677. resolvedMeshBufferIndices[resolvedMeshBufferIndex++] = m_meshBuffers.GetIndirectionList()[meshBufferIndex];
  678. }
  679. }
  680. m_meshBufferIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(resolvedMeshBufferIndices);
  681. #else
  682. AZStd::unordered_map<int, const void*> rawMeshData;
  683. for (auto& [deviceIndex, meshBufferIndices] : m_meshBufferIndices)
  684. {
  685. rawMeshData[deviceIndex] = meshBufferIndices.GetIndexList().data();
  686. }
  687. size_t newMeshBufferIndicesByteCount = m_meshBufferIndices.begin()->second.GetIndexList().size() * sizeof(uint32_t);
  688. m_meshBufferIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(rawMeshData, newMeshBufferIndicesByteCount);
  689. #endif
  690. #if !USE_BINDLESS_SRG
  691. // resolve to the true indices using the indirection list
  692. // Note: this is done on the CPU to avoid double-indirection in the shader
  693. IndexVector resolvedMaterialTextureIndices(m_materialTextureIndices.GetIndexList().size());
  694. uint32_t resolvedMaterialTextureIndex = 0;
  695. for (auto& materialTextureIndex : m_materialTextureIndices.GetIndexList())
  696. {
  697. if (!m_materialTextureIndices.IsValidIndex(materialTextureIndex))
  698. {
  699. resolvedMaterialTextureIndices[resolvedMaterialTextureIndex++] = InvalidIndex;
  700. }
  701. else
  702. {
  703. resolvedMaterialTextureIndices[resolvedMaterialTextureIndex++] = m_materialTextures.GetIndirectionList()[materialTextureIndex];
  704. }
  705. }
  706. m_materialTextureIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(resolvedMaterialTextureIndices);
  707. #else
  708. AZStd::unordered_map<int, const void*> rawMaterialData;
  709. for (auto& [deviceIndex, materialTextureIndices] : m_materialTextureIndices)
  710. {
  711. rawMaterialData[deviceIndex] = materialTextureIndices.GetIndexList().data();
  712. }
  713. size_t newMaterialTextureIndicesByteCount = m_materialTextureIndices.begin()->second.GetIndexList().size() * sizeof(uint32_t);
  714. m_materialTextureIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(rawMaterialData, newMaterialTextureIndicesByteCount);
  715. #endif
  716. m_indexListNeedsUpdate = false;
  717. }
  718. }
  719. void RayTracingFeatureProcessor::UpdateRayTracingSceneSrg()
  720. {
  721. const RHI::ShaderResourceGroupLayout* srgLayout = m_rayTracingSceneSrg->GetLayout();
  722. RHI::ShaderInputImageIndex imageIndex;
  723. RHI::ShaderInputBufferIndex bufferIndex;
  724. RHI::ShaderInputConstantIndex constantIndex;
  725. // TLAS
  726. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_tlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  727. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  728. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_scene"));
  729. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_tlas->GetTlasBuffer()->BuildBufferView(bufferViewDescriptor).get());
  730. // directional lights
  731. const auto directionalLightFP = GetParentScene()->GetFeatureProcessor<DirectionalLightFeatureProcessor>();
  732. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_directionalLights"));
  733. m_rayTracingSceneSrg->SetBufferView(
  734. bufferIndex,
  735. directionalLightFP->GetLightBuffer()->GetBufferView());
  736. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_directionalLightCount"));
  737. m_rayTracingSceneSrg->SetConstant(constantIndex, directionalLightFP->GetLightCount());
  738. // simple point lights
  739. const auto simplePointLightFP = GetParentScene()->GetFeatureProcessor<SimplePointLightFeatureProcessor>();
  740. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_simplePointLights"));
  741. m_rayTracingSceneSrg->SetBufferView(
  742. bufferIndex,
  743. simplePointLightFP->GetLightBuffer()->GetBufferView());
  744. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_simplePointLightCount"));
  745. m_rayTracingSceneSrg->SetConstant(constantIndex, simplePointLightFP->GetLightCount());
  746. // simple spot lights
  747. const auto simpleSpotLightFP = GetParentScene()->GetFeatureProcessor<SimpleSpotLightFeatureProcessor>();
  748. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_simpleSpotLights"));
  749. m_rayTracingSceneSrg->SetBufferView(
  750. bufferIndex,
  751. simpleSpotLightFP->GetLightBuffer()->GetBufferView());
  752. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_simpleSpotLightCount"));
  753. m_rayTracingSceneSrg->SetConstant(constantIndex, simpleSpotLightFP->GetLightCount());
  754. // point lights (sphere)
  755. const auto pointLightFP = GetParentScene()->GetFeatureProcessor<PointLightFeatureProcessor>();
  756. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_pointLights"));
  757. m_rayTracingSceneSrg->SetBufferView(
  758. bufferIndex,
  759. pointLightFP->GetLightBuffer()->GetBufferView());
  760. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_pointLightCount"));
  761. m_rayTracingSceneSrg->SetConstant(constantIndex, pointLightFP->GetLightCount());
  762. // disk lights
  763. const auto diskLightFP = GetParentScene()->GetFeatureProcessor<DiskLightFeatureProcessor>();
  764. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_diskLights"));
  765. m_rayTracingSceneSrg->SetBufferView(
  766. bufferIndex,
  767. diskLightFP->GetLightBuffer()->GetBufferView());
  768. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_diskLightCount"));
  769. m_rayTracingSceneSrg->SetConstant(constantIndex, diskLightFP->GetLightCount());
  770. // capsule lights
  771. const auto capsuleLightFP = GetParentScene()->GetFeatureProcessor<CapsuleLightFeatureProcessor>();
  772. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_capsuleLights"));
  773. m_rayTracingSceneSrg->SetBufferView(
  774. bufferIndex,
  775. capsuleLightFP->GetLightBuffer()->GetBufferView());
  776. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_capsuleLightCount"));
  777. m_rayTracingSceneSrg->SetConstant(constantIndex, capsuleLightFP->GetLightCount());
  778. // quad lights
  779. const auto quadLightFP = GetParentScene()->GetFeatureProcessor<QuadLightFeatureProcessor>();
  780. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_quadLights"));
  781. m_rayTracingSceneSrg->SetBufferView(
  782. bufferIndex,
  783. quadLightFP->GetLightBuffer()->GetBufferView());
  784. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_quadLightCount"));
  785. m_rayTracingSceneSrg->SetConstant(constantIndex, quadLightFP->GetLightCount());
  786. // diffuse environment map for sky hits
  787. ImageBasedLightFeatureProcessor* imageBasedLightFeatureProcessor = GetParentScene()->GetFeatureProcessor<ImageBasedLightFeatureProcessor>();
  788. if (imageBasedLightFeatureProcessor)
  789. {
  790. imageIndex = srgLayout->FindShaderInputImageIndex(AZ::Name("m_diffuseEnvMap"));
  791. m_rayTracingSceneSrg->SetImage(imageIndex, imageBasedLightFeatureProcessor->GetDiffuseImage());
  792. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblOrientation"));
  793. m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetOrientation());
  794. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblExposure"));
  795. m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetExposure());
  796. }
  797. if (m_meshInfoGpuBuffer.IsCurrentBufferValid())
  798. {
  799. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_meshInfo"));
  800. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_meshInfoGpuBuffer.GetCurrentBufferView());
  801. }
  802. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_meshInfoCount"));
  803. m_rayTracingSceneSrg->SetConstant(constantIndex, m_subMeshCount);
  804. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_meshBufferIndices"));
  805. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_meshBufferIndicesGpuBuffer.GetCurrentBufferView());
  806. if (m_proceduralGeometryInfoGpuBuffer.IsCurrentBufferValid())
  807. {
  808. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_proceduralGeometryInfo"));
  809. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_proceduralGeometryInfoGpuBuffer.GetCurrentBufferView());
  810. }
  811. #if !USE_BINDLESS_SRG
  812. RHI::ShaderInputBufferUnboundedArrayIndex bufferUnboundedArrayIndex = srgLayout->FindShaderInputBufferUnboundedArrayIndex(AZ::Name("m_meshBuffers"));
  813. m_rayTracingSceneSrg->SetBufferViewUnboundedArray(bufferUnboundedArrayIndex, m_meshBuffers.GetResourceList());
  814. #endif
  815. m_rayTracingSceneSrg->Compile();
  816. }
  817. void RayTracingFeatureProcessor::UpdateRayTracingMaterialSrg()
  818. {
  819. const RHI::ShaderResourceGroupLayout* srgLayout = m_rayTracingMaterialSrg->GetLayout();
  820. RHI::ShaderInputBufferIndex bufferIndex;
  821. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_materialInfo"));
  822. m_rayTracingMaterialSrg->SetBufferView(bufferIndex, m_materialInfoGpuBuffer.GetCurrentBufferView());
  823. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_materialTextureIndices"));
  824. m_rayTracingMaterialSrg->SetBufferView(bufferIndex, m_materialTextureIndicesGpuBuffer.GetCurrentBufferView());
  825. #if !USE_BINDLESS_SRG
  826. RHI::ShaderInputImageUnboundedArrayIndex textureUnboundedArrayIndex = srgLayout->FindShaderInputImageUnboundedArrayIndex(AZ::Name("m_materialTextures"));
  827. m_rayTracingMaterialSrg->SetImageViewUnboundedArray(textureUnboundedArrayIndex, m_materialTextures.GetResourceList());
  828. #endif
  829. m_rayTracingMaterialSrg->Compile();
  830. }
  831. void RayTracingFeatureProcessor::OnRenderPipelineChanged([[maybe_unused]] RPI::RenderPipeline* renderPipeline, RPI::SceneNotification::RenderPipelineChangeType changeType)
  832. {
  833. if (!m_rayTracingEnabled)
  834. {
  835. return;
  836. }
  837. // only enable the RayTracingAccelerationStructurePass on the first pipeline in this scene, this will avoid multiple updates to the same AS
  838. bool enabled = true;
  839. if (changeType == RPI::SceneNotification::RenderPipelineChangeType::Added
  840. || changeType == RPI::SceneNotification::RenderPipelineChangeType::Removed)
  841. {
  842. AZ::RPI::PassFilter passFilter = AZ::RPI::PassFilter::CreateWithPassName(AZ::Name("RayTracingAccelerationStructurePass"), GetParentScene());
  843. AZ::RPI::PassSystemInterface::Get()->ForEachPass(passFilter, [&enabled](AZ::RPI::Pass* pass) -> AZ::RPI::PassFilterExecutionFlow
  844. {
  845. pass->SetEnabled(enabled);
  846. enabled = false;
  847. return AZ::RPI::PassFilterExecutionFlow::ContinueVisitingPasses;
  848. });
  849. }
  850. }
  851. AZ::RHI::RayTracingAccelerationStructureBuildFlags RayTracingFeatureProcessor::CreateRayTracingAccelerationStructureBuildFlags(bool isSkinnedMesh)
  852. {
  853. AZ::RHI::RayTracingAccelerationStructureBuildFlags buildFlags;
  854. if (isSkinnedMesh)
  855. {
  856. buildFlags = AZ::RHI::RayTracingAccelerationStructureBuildFlags::ENABLE_UPDATE | AZ::RHI::RayTracingAccelerationStructureBuildFlags::FAST_BUILD;
  857. }
  858. else
  859. {
  860. buildFlags = AZ::RHI::RayTracingAccelerationStructureBuildFlags::FAST_TRACE;
  861. }
  862. return buildFlags;
  863. }
  864. void RayTracingFeatureProcessor::ConvertMaterial(MaterialInfo& materialInfo, const SubMeshMaterial& subMeshMaterial, int deviceIndex)
  865. {
  866. subMeshMaterial.m_baseColor.StoreToFloat4(materialInfo.m_baseColor.data());
  867. subMeshMaterial.m_emissiveColor.StoreToFloat4(materialInfo.m_emissiveColor.data());
  868. subMeshMaterial.m_irradianceColor.StoreToFloat4(materialInfo.m_irradianceColor.data());
  869. materialInfo.m_metallicFactor = subMeshMaterial.m_metallicFactor;
  870. materialInfo.m_roughnessFactor = subMeshMaterial.m_roughnessFactor;
  871. materialInfo.m_textureFlags = subMeshMaterial.m_textureFlags;
  872. if (materialInfo.m_textureStartIndex != InvalidIndex)
  873. {
  874. m_materialTextureIndices[deviceIndex].RemoveEntry(materialInfo.m_textureStartIndex);
  875. #if !USE_BINDLESS_SRG
  876. m_materialTextures.RemoveResource(subMeshMaterial.m_baseColorImageView.get());
  877. m_materialTextures.RemoveResource(subMeshMaterial.m_normalImageView.get());
  878. m_materialTextures.RemoveResource(subMeshMaterial.m_metallicImageView.get());
  879. m_materialTextures.RemoveResource(subMeshMaterial.m_roughnessImageView.get());
  880. m_materialTextures.RemoveResource(subMeshMaterial.m_emissiveImageView.get());
  881. #endif
  882. }
  883. materialInfo.m_textureStartIndex = m_materialTextureIndices[deviceIndex].AddEntry({
  884. #if USE_BINDLESS_SRG
  885. subMeshMaterial.m_baseColorImageView.get() ? subMeshMaterial.m_baseColorImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  886. subMeshMaterial.m_normalImageView.get() ? subMeshMaterial.m_normalImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  887. subMeshMaterial.m_metallicImageView.get() ? subMeshMaterial.m_metallicImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  888. subMeshMaterial.m_roughnessImageView.get() ? subMeshMaterial.m_roughnessImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  889. subMeshMaterial.m_emissiveImageView.get() ? subMeshMaterial.m_emissiveImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex
  890. #else
  891. m_materialTextures.AddResource(subMeshMaterial.m_baseColorImageView.get()),
  892. m_materialTextures.AddResource(subMeshMaterial.m_normalImageView.get()),
  893. m_materialTextures.AddResource(subMeshMaterial.m_metallicImageView.get()),
  894. m_materialTextures.AddResource(subMeshMaterial.m_roughnessImageView.get()),
  895. m_materialTextures.AddResource(subMeshMaterial.m_emissiveImageView.get())
  896. #endif
  897. });
  898. }
  899. }
  900. }