3
0

RayTracingFeatureProcessor.cpp 53 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RayTracing/RayTracingFeatureProcessor.h>
  9. #include <RayTracing/RayTracingPass.h>
  10. #include <Atom/Feature/TransformService/TransformServiceFeatureProcessor.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/RayTracingAccelerationStructure.h>
  13. #include <Atom/RHI/RHISystemInterface.h>
  14. #include <Atom/RPI.Public/Scene.h>
  15. #include <Atom/RPI.Public/Pass/PassFilter.h>
  16. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  17. #include <Atom/RPI.Reflect/Asset/AssetUtils.h>
  18. #include <Atom/Feature/ImageBasedLights/ImageBasedLightFeatureProcessor.h>
  19. #include <CoreLights/DirectionalLightFeatureProcessor.h>
  20. #include <CoreLights/SimplePointLightFeatureProcessor.h>
  21. #include <CoreLights/SimpleSpotLightFeatureProcessor.h>
  22. #include <CoreLights/PointLightFeatureProcessor.h>
  23. #include <CoreLights/DiskLightFeatureProcessor.h>
  24. #include <CoreLights/CapsuleLightFeatureProcessor.h>
  25. #include <CoreLights/QuadLightFeatureProcessor.h>
  26. namespace AZ
  27. {
  28. namespace Render
  29. {
  30. void RayTracingFeatureProcessor::Reflect(ReflectContext* context)
  31. {
  32. if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
  33. {
  34. serializeContext
  35. ->Class<RayTracingFeatureProcessor, FeatureProcessor>()
  36. ->Version(1);
  37. }
  38. }
  39. void RayTracingFeatureProcessor::Activate()
  40. {
  41. auto deviceMask{RHI::RHISystemInterface::Get()->GetRayTracingSupport()};
  42. m_rayTracingEnabled = (deviceMask != RHI::MultiDevice::NoDevices);
  43. if (!m_rayTracingEnabled)
  44. {
  45. return;
  46. }
  47. m_transformServiceFeatureProcessor = GetParentScene()->GetFeatureProcessor<TransformServiceFeatureProcessor>();
  48. // initialize the ray tracing buffer pools
  49. m_bufferPools = aznew RHI::RayTracingBufferPools;
  50. m_bufferPools->Init(deviceMask);
  51. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  52. for (auto deviceIndex{0}; deviceIndex < deviceCount; ++deviceIndex)
  53. {
  54. if ((AZStd::to_underlying(deviceMask) >> deviceIndex) & 1)
  55. {
  56. m_meshBufferIndices[deviceIndex] = {};
  57. m_materialTextureIndices[deviceIndex] = {};
  58. m_materialInfos[deviceIndex] = {};
  59. m_proceduralGeometryMaterialInfos[deviceIndex] = {};
  60. }
  61. }
  62. // create TLAS attachmentId
  63. AZStd::string uuidString = AZ::Uuid::CreateRandom().ToString<AZStd::string>();
  64. m_tlasAttachmentId = RHI::AttachmentId(AZStd::string::format("RayTracingTlasAttachmentId_%s", uuidString.c_str()));
  65. // create the TLAS object
  66. m_tlas = aznew RHI::RayTracingTlas;
  67. // load the RayTracingSrg asset asset
  68. m_rayTracingSrgAsset = RPI::AssetUtils::LoadCriticalAsset<RPI::ShaderAsset>("shaderlib/atom/features/rayTracing/raytracingsrgs.azshader");
  69. if (!m_rayTracingSrgAsset.IsReady())
  70. {
  71. AZ_Assert(false, "Failed to load RayTracingSrg asset");
  72. return;
  73. }
  74. // create the RayTracingSceneSrg
  75. m_rayTracingSceneSrg = RPI::ShaderResourceGroup::Create(m_rayTracingSrgAsset, Name("RayTracingSceneSrg"));
  76. AZ_Assert(m_rayTracingSceneSrg, "Failed to create RayTracingSceneSrg");
  77. // create the RayTracingMaterialSrg
  78. const AZ::Name rayTracingMaterialSrgName("RayTracingMaterialSrg");
  79. m_rayTracingMaterialSrg = RPI::ShaderResourceGroup::Create(m_rayTracingSrgAsset, Name("RayTracingMaterialSrg"));
  80. AZ_Assert(m_rayTracingMaterialSrg, "Failed to create RayTracingMaterialSrg");
  81. EnableSceneNotification();
  82. }
  83. void RayTracingFeatureProcessor::Deactivate()
  84. {
  85. DisableSceneNotification();
  86. }
  87. RayTracingFeatureProcessor::ProceduralGeometryTypeHandle RayTracingFeatureProcessor::RegisterProceduralGeometryType(
  88. const AZStd::string& name,
  89. const Data::Instance<RPI::Shader>& intersectionShader,
  90. const AZStd::string& intersectionShaderName,
  91. uint32_t bindlessBufferIndex)
  92. {
  93. ProceduralGeometryTypeHandle geometryTypeHandle;
  94. {
  95. ProceduralGeometryType proceduralGeometryType;
  96. proceduralGeometryType.m_name = AZ::Name(name);
  97. proceduralGeometryType.m_intersectionShader = intersectionShader;
  98. proceduralGeometryType.m_intersectionShaderName = AZ::Name(intersectionShaderName);
  99. proceduralGeometryType.m_bindlessBufferIndex = bindlessBufferIndex;
  100. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  101. geometryTypeHandle = m_proceduralGeometryTypes.insert(proceduralGeometryType);
  102. }
  103. m_proceduralGeometryTypeRevision++;
  104. return geometryTypeHandle;
  105. }
  106. void RayTracingFeatureProcessor::SetProceduralGeometryTypeBindlessBufferIndex(
  107. ProceduralGeometryTypeWeakHandle geometryTypeHandle, uint32_t bindlessBufferIndex)
  108. {
  109. if (!m_rayTracingEnabled)
  110. {
  111. return;
  112. }
  113. geometryTypeHandle->m_bindlessBufferIndex = bindlessBufferIndex;
  114. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  115. }
  116. void RayTracingFeatureProcessor::AddProceduralGeometry(
  117. ProceduralGeometryTypeWeakHandle geometryTypeHandle,
  118. const Uuid& uuid,
  119. const Aabb& aabb,
  120. const SubMeshMaterial& material,
  121. RHI::RayTracingAccelerationStructureInstanceInclusionMask instanceMask,
  122. uint32_t localInstanceIndex)
  123. {
  124. if (!m_rayTracingEnabled)
  125. {
  126. return;
  127. }
  128. RHI::Ptr<AZ::RHI::RayTracingBlas> rayTracingBlas = aznew AZ::RHI::RayTracingBlas;
  129. RHI::RayTracingBlasDescriptor blasDescriptor;
  130. blasDescriptor.Build()
  131. ->AABB(aabb)
  132. ;
  133. rayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &blasDescriptor, *m_bufferPools);
  134. ProceduralGeometry proceduralGeometry;
  135. proceduralGeometry.m_uuid = uuid;
  136. proceduralGeometry.m_typeHandle = geometryTypeHandle;
  137. proceduralGeometry.m_aabb = aabb;
  138. proceduralGeometry.m_instanceMask = static_cast<uint32_t>(instanceMask);
  139. proceduralGeometry.m_blas = rayTracingBlas;
  140. proceduralGeometry.m_localInstanceIndex = localInstanceIndex;
  141. MeshBlasInstance meshBlasInstance;
  142. meshBlasInstance.m_count = 1;
  143. meshBlasInstance.m_subMeshes.push_back(SubMeshBlasInstance{ rayTracingBlas });
  144. MaterialInfo materialInfo;
  145. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  146. m_proceduralGeometryLookup.emplace(uuid, m_proceduralGeometry.size());
  147. m_proceduralGeometry.push_back(proceduralGeometry);
  148. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  149. for (auto deviceIndex{0}; deviceIndex < deviceCount; ++deviceIndex)
  150. {
  151. m_proceduralGeometryMaterialInfos[deviceIndex].emplace_back();
  152. ConvertMaterial(m_proceduralGeometryMaterialInfos[deviceIndex].back(), material, deviceIndex);
  153. }
  154. m_blasInstanceMap.emplace(Data::AssetId(uuid), meshBlasInstance);
  155. geometryTypeHandle->m_instanceCount++;
  156. m_revision++;
  157. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  158. m_materialInfoBufferNeedsUpdate = true;
  159. m_indexListNeedsUpdate = true;
  160. }
  161. void RayTracingFeatureProcessor::SetProceduralGeometryTransform(
  162. const Uuid& uuid, const Transform& transform, const Vector3& nonUniformScale)
  163. {
  164. if (!m_rayTracingEnabled)
  165. {
  166. return;
  167. }
  168. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  169. if (auto it = m_proceduralGeometryLookup.find(uuid); it != m_proceduralGeometryLookup.end())
  170. {
  171. m_proceduralGeometry[it->second].m_transform = transform;
  172. m_proceduralGeometry[it->second].m_nonUniformScale = nonUniformScale;
  173. }
  174. m_revision++;
  175. }
  176. void RayTracingFeatureProcessor::SetProceduralGeometryLocalInstanceIndex(const Uuid& uuid, uint32_t localInstanceIndex)
  177. {
  178. if (!m_rayTracingEnabled)
  179. {
  180. return;
  181. }
  182. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  183. if (auto it = m_proceduralGeometryLookup.find(uuid); it != m_proceduralGeometryLookup.end())
  184. {
  185. m_proceduralGeometry[it->second].m_localInstanceIndex = localInstanceIndex;
  186. }
  187. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  188. }
  189. void RayTracingFeatureProcessor::SetProceduralGeometryMaterial(
  190. const Uuid& uuid, const RayTracingFeatureProcessor::SubMeshMaterial& material)
  191. {
  192. if (!m_rayTracingEnabled)
  193. {
  194. return;
  195. }
  196. auto deviceCount = RHI::RHISystemInterface::Get()->GetDeviceCount();
  197. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  198. for (auto deviceIndex{0}; deviceIndex < deviceCount; ++deviceIndex)
  199. {
  200. if (auto it = m_proceduralGeometryLookup.find(uuid); it != m_proceduralGeometryLookup.end())
  201. {
  202. ConvertMaterial(m_proceduralGeometryMaterialInfos[deviceIndex][it->second], material, deviceIndex);
  203. }
  204. }
  205. m_materialInfoBufferNeedsUpdate = true;
  206. }
  207. void RayTracingFeatureProcessor::RemoveProceduralGeometry(const Uuid& uuid)
  208. {
  209. if (!m_rayTracingEnabled)
  210. {
  211. return;
  212. }
  213. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  214. size_t materialInfoIndex = m_proceduralGeometryLookup[uuid];
  215. m_proceduralGeometry[materialInfoIndex].m_typeHandle->m_instanceCount--;
  216. if (materialInfoIndex < m_proceduralGeometry.size() - 1)
  217. {
  218. m_proceduralGeometryLookup[m_proceduralGeometry.back().m_uuid] = m_proceduralGeometryLookup[uuid];
  219. m_proceduralGeometry[materialInfoIndex] = m_proceduralGeometry.back();
  220. for (auto& [deviceIndex, materialInfos] : m_proceduralGeometryMaterialInfos)
  221. {
  222. materialInfos[materialInfoIndex] = materialInfos.back();
  223. }
  224. }
  225. m_proceduralGeometry.pop_back();
  226. for (auto& [deviceIndex, materialInfos] : m_proceduralGeometryMaterialInfos)
  227. {
  228. materialInfos.pop_back();
  229. }
  230. m_blasInstanceMap.erase(uuid);
  231. m_proceduralGeometryLookup.erase(uuid);
  232. m_revision++;
  233. m_proceduralGeometryInfoBufferNeedsUpdate = true;
  234. m_materialInfoBufferNeedsUpdate = true;
  235. m_indexListNeedsUpdate = true;
  236. }
  237. int RayTracingFeatureProcessor::GetProceduralGeometryCount(ProceduralGeometryTypeWeakHandle geometryTypeHandle) const
  238. {
  239. return geometryTypeHandle->m_instanceCount;
  240. }
  241. void RayTracingFeatureProcessor::AddMesh(const AZ::Uuid& uuid, const Mesh& rayTracingMesh, const SubMeshVector& subMeshes)
  242. {
  243. if (!m_rayTracingEnabled)
  244. {
  245. return;
  246. }
  247. // lock the mutex to protect the mesh and BLAS lists
  248. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  249. // check to see if we already have this mesh
  250. MeshMap::iterator itMesh = m_meshes.find(uuid);
  251. if (itMesh != m_meshes.end())
  252. {
  253. AZ_Assert(false, "AddMesh called on an existing Mesh, call RemoveMesh first");
  254. return;
  255. }
  256. // add the mesh
  257. m_meshes.insert(AZStd::make_pair(uuid, rayTracingMesh));
  258. Mesh& mesh = m_meshes[uuid];
  259. // add the subMeshes to the end of the global subMesh vector
  260. // Note 1: the MeshInfo and MaterialInfo vectors are parallel with the subMesh vector
  261. // Note 2: the list of indices for the subMeshes in the global vector are stored in the parent Mesh
  262. IndexVector subMeshIndices;
  263. uint32_t subMeshGlobalIndex = aznumeric_cast<uint32_t>(m_subMeshes.size());
  264. for (uint32_t subMeshIndex = 0; subMeshIndex < subMeshes.size(); ++subMeshIndex, ++subMeshGlobalIndex)
  265. {
  266. SubMesh& subMesh = m_subMeshes.emplace_back(subMeshes[subMeshIndex]);
  267. subMesh.m_mesh = &mesh;
  268. subMesh.m_subMeshIndex = subMeshIndex;
  269. subMesh.m_globalIndex = subMeshGlobalIndex;
  270. // add to the list of global subMeshIndices, which will be stored in the Mesh
  271. subMeshIndices.push_back(subMeshGlobalIndex);
  272. // add MeshInfo and MaterialInfo entries
  273. m_meshInfos.emplace_back();
  274. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  275. {
  276. materialInfos.emplace_back();
  277. }
  278. }
  279. mesh.m_subMeshIndices = subMeshIndices;
  280. // search for an existing BLAS instance entry for this mesh using the assetId
  281. BlasInstanceMap::iterator itMeshBlasInstance = m_blasInstanceMap.find(mesh.m_assetId);
  282. if (itMeshBlasInstance == m_blasInstanceMap.end())
  283. {
  284. // make a new BLAS map entry for this mesh
  285. MeshBlasInstance meshBlasInstance;
  286. meshBlasInstance.m_count = 1;
  287. meshBlasInstance.m_subMeshes.reserve(mesh.m_subMeshIndices.size());
  288. meshBlasInstance.m_isSkinnedMesh = mesh.m_isSkinnedMesh;
  289. itMeshBlasInstance = m_blasInstanceMap.insert({ mesh.m_assetId, meshBlasInstance }).first;
  290. if (mesh.m_isSkinnedMesh)
  291. {
  292. ++m_skinnedMeshCount;
  293. }
  294. }
  295. else
  296. {
  297. itMeshBlasInstance->second.m_count++;
  298. }
  299. // create the BLAS buffers for each sub-mesh, or re-use existing BLAS objects if they were already created.
  300. // Note: all sub-meshes must either create new BLAS objects or re-use existing ones, otherwise it's an error (it's the same model in both cases)
  301. // Note: the buffer is just reserved here, the BLAS is built in the RayTracingAccelerationStructurePass
  302. // Note: the build flags are set to be the same for each BLAS created for the mesh
  303. RHI::RayTracingAccelerationStructureBuildFlags buildFlags = CreateRayTracingAccelerationStructureBuildFlags(mesh.m_isSkinnedMesh);
  304. [[maybe_unused]] bool blasInstanceFound = false;
  305. for (uint32_t subMeshIndex = 0; subMeshIndex < mesh.m_subMeshIndices.size(); ++subMeshIndex)
  306. {
  307. SubMesh& subMesh = m_subMeshes[mesh.m_subMeshIndices[subMeshIndex]];
  308. RHI::RayTracingBlasDescriptor blasDescriptor;
  309. blasDescriptor.Build()
  310. ->Geometry()
  311. ->VertexFormat(subMesh.m_positionFormat)
  312. ->VertexBuffer(subMesh.m_positionVertexBufferView)
  313. ->IndexBuffer(subMesh.m_indexBufferView)
  314. ->BuildFlags(buildFlags)
  315. ;
  316. // determine if we have an existing BLAS object for this subMesh
  317. if (itMeshBlasInstance->second.m_subMeshes.size() >= subMeshIndex + 1)
  318. {
  319. // re-use existing BLAS
  320. subMesh.m_blas = itMeshBlasInstance->second.m_subMeshes[subMeshIndex].m_blas;
  321. // keep track of the fact that we re-used a BLAS
  322. blasInstanceFound = true;
  323. }
  324. else
  325. {
  326. AZ_Assert(blasInstanceFound == false, "Partial set of RayTracingBlas objects found for mesh");
  327. // create the BLAS object and store it in the BLAS list
  328. RHI::Ptr<RHI::RayTracingBlas> rayTracingBlas = aznew RHI::RayTracingBlas;
  329. itMeshBlasInstance->second.m_subMeshes.push_back({ rayTracingBlas });
  330. // create the buffers from the BLAS descriptor
  331. rayTracingBlas->CreateBuffers(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), &blasDescriptor, *m_bufferPools);
  332. // store the BLAS in the mesh
  333. subMesh.m_blas = rayTracingBlas;
  334. }
  335. }
  336. AZ::Transform noScaleTransform = mesh.m_transform;
  337. noScaleTransform.ExtractUniformScale();
  338. AZ::Matrix3x3 rotationMatrix = Matrix3x3::CreateFromTransform(noScaleTransform);
  339. rotationMatrix = rotationMatrix.GetInverseFull().GetTranspose();
  340. Matrix3x4 worldInvTranspose3x4 = Matrix3x4::CreateFromMatrix3x3(rotationMatrix);
  341. Matrix3x4 reflectionProbeModelToWorld3x4 = Matrix3x4::CreateFromTransform(mesh.m_reflectionProbe.m_modelToWorld);
  342. // store the mesh buffers and material textures in the resource lists
  343. for (uint32_t subMeshIndex : mesh.m_subMeshIndices)
  344. {
  345. SubMesh& subMesh = m_subMeshes[subMeshIndex];
  346. MeshInfo& meshInfo = m_meshInfos[subMesh.m_globalIndex];
  347. worldInvTranspose3x4.StoreToRowMajorFloat12(meshInfo.m_worldInvTranspose.data());
  348. meshInfo.m_bufferFlags = subMesh.m_bufferFlags;
  349. AZ_Assert(subMesh.m_indexShaderBufferView.get(), "RayTracing Mesh IndexBuffer cannot be null");
  350. AZ_Assert(subMesh.m_positionShaderBufferView.get(), "RayTracing Mesh PositionBuffer cannot be null");
  351. AZ_Assert(subMesh.m_normalShaderBufferView.get(), "RayTracing Mesh NormalBuffer cannot be null");
  352. meshInfo.m_indexByteOffset = subMesh.m_indexBufferView.GetByteOffset();
  353. meshInfo.m_positionByteOffset = subMesh.m_positionVertexBufferView.GetByteOffset();
  354. meshInfo.m_normalByteOffset = subMesh.m_normalVertexBufferView.GetByteOffset();
  355. meshInfo.m_tangentByteOffset = subMesh.m_tangentShaderBufferView ? subMesh.m_tangentVertexBufferView.GetByteOffset() : 0;
  356. meshInfo.m_bitangentByteOffset = subMesh.m_bitangentShaderBufferView ? subMesh.m_bitangentVertexBufferView.GetByteOffset() : 0;
  357. meshInfo.m_uvByteOffset = subMesh.m_uvShaderBufferView ? subMesh.m_uvVertexBufferView.GetByteOffset() : 0;
  358. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  359. {
  360. MaterialInfo& materialInfo = materialInfos[subMesh.m_globalIndex];
  361. ConvertMaterial(materialInfo, subMesh.m_material, deviceIndex);
  362. auto& meshBufferIndices = m_meshBufferIndices[deviceIndex];
  363. // add mesh buffers
  364. meshInfo.m_bufferStartIndex = meshBufferIndices.AddEntry(
  365. {
  366. #if USE_BINDLESS_SRG
  367. subMesh.m_indexShaderBufferView.get() ? subMesh.m_indexShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  368. subMesh.m_positionShaderBufferView.get() ? subMesh.m_positionShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  369. subMesh.m_normalShaderBufferView.get() ? subMesh.m_normalShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  370. subMesh.m_tangentShaderBufferView.get() ? subMesh.m_tangentShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  371. subMesh.m_bitangentShaderBufferView.get() ? subMesh.m_bitangentShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  372. subMesh.m_uvShaderBufferView.get() ? subMesh.m_uvShaderBufferView->GetDeviceBufferView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex
  373. #else
  374. m_meshBuffers.AddResource(subMesh.m_indexShaderBufferView.get()),
  375. m_meshBuffers.AddResource(subMesh.m_positionShaderBufferView.get()),
  376. m_meshBuffers.AddResource(subMesh.m_normalShaderBufferView.get()),
  377. m_meshBuffers.AddResource(subMesh.m_tangentShaderBufferView.get()),
  378. m_meshBuffers.AddResource(subMesh.m_bitangentShaderBufferView.get()),
  379. m_meshBuffers.AddResource(subMesh.m_uvShaderBufferView.get())
  380. #endif
  381. });
  382. // add reflection probe data
  383. if (mesh.m_reflectionProbe.m_reflectionProbeCubeMap.get())
  384. {
  385. materialInfo.m_reflectionProbeCubeMapIndex = mesh.m_reflectionProbe.m_reflectionProbeCubeMap->GetImageView()->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex();
  386. if (materialInfo.m_reflectionProbeCubeMapIndex != InvalidIndex)
  387. {
  388. reflectionProbeModelToWorld3x4.StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorld.data());
  389. reflectionProbeModelToWorld3x4.GetInverseFull().StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorldInverse.data());
  390. mesh.m_reflectionProbe.m_outerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_outerObbHalfLengths.data());
  391. mesh.m_reflectionProbe.m_innerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_innerObbHalfLengths.data());
  392. materialInfo.m_reflectionProbeData.m_useReflectionProbe = true;
  393. materialInfo.m_reflectionProbeData.m_useParallaxCorrection = mesh.m_reflectionProbe.m_useParallaxCorrection;
  394. materialInfo.m_reflectionProbeData.m_exposure = mesh.m_reflectionProbe.m_exposure;
  395. }
  396. }
  397. }
  398. }
  399. m_revision++;
  400. m_subMeshCount += aznumeric_cast<uint32_t>(subMeshes.size());
  401. m_meshInfoBufferNeedsUpdate = true;
  402. m_materialInfoBufferNeedsUpdate = true;
  403. m_indexListNeedsUpdate = true;
  404. }
  // Removes the ray tracing mesh registered under uuid.
  // - Decrements the reference count of the BLAS instance entry for the mesh assetId, and erases the entry
  //   (and adjusts the skinned mesh count) when it reaches zero.
  // - Removes the mesh's buffer and texture index entries from every device.
  // - Swap-removes each of its subMeshes from the global subMesh / MeshInfo / per-device MaterialInfo vectors.
  // - Erases the mesh.
  // When no subMeshes remain, all mesh-related lists are reset.
  // The dirty flags at the end are set even when the uuid is unknown.
  void RayTracingFeatureProcessor::RemoveMesh(const AZ::Uuid& uuid)
  {
      if (!m_rayTracingEnabled)
      {
          return;
      }

      // lock the mutex to protect the mesh and BLAS lists
      AZStd::unique_lock<AZStd::mutex> lock(m_mutex);

      MeshMap::iterator itMesh = m_meshes.find(uuid);
      if (itMesh != m_meshes.end())
      {
          Mesh& mesh = itMesh->second;

          // decrement the count from the BLAS instances, and check to see if we can remove them
          // (the entry is shared by every mesh with the same assetId; it is erased when the count reaches zero)
          BlasInstanceMap::iterator itBlas = m_blasInstanceMap.find(mesh.m_assetId);
          if (itBlas != m_blasInstanceMap.end())
          {
              itBlas->second.m_count--;
              if (itBlas->second.m_count == 0)
              {
                  if (itBlas->second.m_isSkinnedMesh)
                  {
                      --m_skinnedMeshCount;
                  }

                  m_blasInstanceMap.erase(itBlas);
              }
          }

          // remove the SubMeshes
          // Note: a removal swaps the last global subMesh into the freed slot. If that subMesh belongs to this same
          // mesh, the swap rewrites a not-yet-visited element of mesh.m_subMeshIndices. The range-for reads each
          // element by reference, so it picks up the updated global index when it gets there.
          for (auto& subMeshIndex : mesh.m_subMeshIndices)
          {
              SubMesh& subMesh = m_subMeshes[subMeshIndex];
              uint32_t globalIndex = subMesh.m_globalIndex;
              MeshInfo& meshInfo = m_meshInfos[globalIndex];

              // release the buffer index entries on every device (AddMesh stores a single start index for all devices)
              for (auto& [deviceIndex, meshBufferIndices] : m_meshBufferIndices)
              {
                  meshBufferIndices.RemoveEntry(meshInfo.m_bufferStartIndex);
              }

              // release the material texture index entries recorded in this device's MaterialInfo
              for (auto& [deviceIndex, materialTextureIndices] : m_materialTextureIndices)
              {
                  MaterialInfo& materialInfo = m_materialInfos[deviceIndex][globalIndex];
                  materialTextureIndices.RemoveEntry(materialInfo.m_textureStartIndex);
              }

#if !USE_BINDLESS_SRG
              // NOTE(review): texture resources are removed using only the DefaultDeviceIndex image views here
              m_meshBuffers.RemoveResource(subMesh.m_indexShaderBufferView.get());
              m_meshBuffers.RemoveResource(subMesh.m_positionShaderBufferView.get());
              m_meshBuffers.RemoveResource(subMesh.m_normalShaderBufferView.get());
              m_meshBuffers.RemoveResource(subMesh.m_tangentShaderBufferView.get());
              m_meshBuffers.RemoveResource(subMesh.m_bitangentShaderBufferView.get());
              m_meshBuffers.RemoveResource(subMesh.m_uvShaderBufferView.get());

              m_materialTextures.RemoveResource(subMesh.m_baseColorImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
              m_materialTextures.RemoveResource(subMesh.m_normalImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
              m_materialTextures.RemoveResource(subMesh.m_metallicImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
              m_materialTextures.RemoveResource(subMesh.m_roughnessImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
              m_materialTextures.RemoveResource(subMesh.m_emissiveImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex).get());
#endif

              if (globalIndex < m_subMeshes.size() - 1)
              {
                  // the subMesh we're removing is in the middle of the global lists, remove by swapping the last element to its position in the list
                  m_subMeshes[globalIndex] = m_subMeshes.back();
                  m_meshInfos[globalIndex] = m_meshInfos.back();
                  for (auto& [deviceIndex, materialInfos] : m_materialInfos)
                  {
                      materialInfos[globalIndex] = materialInfos.back();
                  }

                  // update the global index for the swapped subMesh
                  m_subMeshes[globalIndex].m_globalIndex = globalIndex;

                  // update the global index in the parent Mesh's subMesh list
                  Mesh* swappedSubMeshParent = m_subMeshes[globalIndex].m_mesh;
                  uint32_t swappedSubMeshIndex = m_subMeshes[globalIndex].m_subMeshIndex;
                  swappedSubMeshParent->m_subMeshIndices[swappedSubMeshIndex] = globalIndex;
              }

              // drop the (now duplicated or removed) last element of each parallel list
              m_subMeshes.pop_back();
              m_meshInfos.pop_back();
              for (auto& [deviceIndex, materialInfos] : m_materialInfos)
              {
                  materialInfos.pop_back();
              }
          }

          // remove from the Mesh list (mesh is still valid here; it is only erased afterwards)
          m_subMeshCount -= aznumeric_cast<uint32_t>(mesh.m_subMeshIndices.size());
          m_meshes.erase(itMesh);
          m_revision++;

          // reset all data structures if all meshes were removed (i.e., empty scene)
          // NOTE(review): this keys only on m_subMeshCount. Procedural geometry materials also go through
          // ConvertMaterial, so m_materialTextureIndices may still hold their entries. Confirm that resetting it
          // is safe while procedural geometry remains.
          if (m_subMeshCount == 0)
          {
              m_meshes.clear();
              m_subMeshes.clear();
              m_meshInfos.clear();
              for (auto& [deviceIndex, materialInfos] : m_materialInfos)
              {
                  materialInfos.clear();
              }
              for (auto& [deviceIndex, meshBufferIndices] : m_meshBufferIndices)
              {
                  meshBufferIndices.Reset();
              }
              for (auto& [deviceIndex, materialTextureIndices] : m_materialTextureIndices)
              {
                  materialTextureIndices.Reset();
              }
#if !USE_BINDLESS_SRG
              m_meshBuffers.Reset();
              m_materialTextures.Reset();
#endif
          }
      }

      m_meshInfoBufferNeedsUpdate = true;
      m_materialInfoBufferNeedsUpdate = true;
      m_indexListNeedsUpdate = true;
  }
  514. void RayTracingFeatureProcessor::SetMeshTransform(const AZ::Uuid& uuid, const AZ::Transform transform, const AZ::Vector3 nonUniformScale)
  515. {
  516. if (!m_rayTracingEnabled)
  517. {
  518. return;
  519. }
  520. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  521. MeshMap::iterator itMesh = m_meshes.find(uuid);
  522. if (itMesh != m_meshes.end())
  523. {
  524. Mesh& mesh = itMesh->second;
  525. mesh.m_transform = transform;
  526. mesh.m_nonUniformScale = nonUniformScale;
  527. m_revision++;
  528. // create a world inverse transpose 3x4 matrix
  529. AZ::Transform noScaleTransform = mesh.m_transform;
  530. noScaleTransform.ExtractUniformScale();
  531. AZ::Matrix3x3 rotationMatrix = Matrix3x3::CreateFromTransform(noScaleTransform);
  532. rotationMatrix = rotationMatrix.GetInverseFull().GetTranspose();
  533. Matrix3x4 worldInvTranspose3x4 = Matrix3x4::CreateFromMatrix3x3(rotationMatrix);
  534. // update all MeshInfos for this Mesh with the new transform
  535. for (const auto& subMeshIndex : mesh.m_subMeshIndices)
  536. {
  537. MeshInfo& meshInfo = m_meshInfos[subMeshIndex];
  538. worldInvTranspose3x4.StoreToRowMajorFloat12(meshInfo.m_worldInvTranspose.data());
  539. }
  540. m_meshInfoBufferNeedsUpdate = true;
  541. }
  542. }
  543. void RayTracingFeatureProcessor::SetMeshReflectionProbe(const AZ::Uuid& uuid, const Mesh::ReflectionProbe& reflectionProbe)
  544. {
  545. if (!m_rayTracingEnabled)
  546. {
  547. return;
  548. }
  549. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  550. MeshMap::iterator itMesh = m_meshes.find(uuid);
  551. if (itMesh != m_meshes.end())
  552. {
  553. Mesh& mesh = itMesh->second;
  554. // update the Mesh reflection probe data
  555. mesh.m_reflectionProbe = reflectionProbe;
  556. // update all of the subMeshes
  557. const Data::Instance<RPI::Image>& reflectionProbeCubeMap = reflectionProbe.m_reflectionProbeCubeMap;
  558. uint32_t reflectionProbeCubeMapIndex = reflectionProbeCubeMap.get() ? reflectionProbeCubeMap->GetImageView()->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex)->GetBindlessReadIndex() : InvalidIndex;
  559. Matrix3x4 reflectionProbeModelToWorld3x4 = Matrix3x4::CreateFromTransform(mesh.m_reflectionProbe.m_modelToWorld);
  560. for (auto& subMeshIndex : mesh.m_subMeshIndices)
  561. {
  562. SubMesh& subMesh = m_subMeshes[subMeshIndex];
  563. uint32_t globalIndex = subMesh.m_globalIndex;
  564. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  565. {
  566. MaterialInfo& materialInfo = materialInfos[globalIndex];
  567. materialInfo.m_reflectionProbeCubeMapIndex = reflectionProbeCubeMapIndex;
  568. if (materialInfo.m_reflectionProbeCubeMapIndex != InvalidIndex)
  569. {
  570. reflectionProbeModelToWorld3x4.StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorld.data());
  571. reflectionProbeModelToWorld3x4.GetInverseFull().StoreToRowMajorFloat12(materialInfo.m_reflectionProbeData.m_modelToWorldInverse.data());
  572. mesh.m_reflectionProbe.m_outerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_outerObbHalfLengths.data());
  573. mesh.m_reflectionProbe.m_innerObbHalfLengths.StoreToFloat3(materialInfo.m_reflectionProbeData.m_innerObbHalfLengths.data());
  574. materialInfo.m_reflectionProbeData.m_useReflectionProbe = true;
  575. materialInfo.m_reflectionProbeData.m_useParallaxCorrection = mesh.m_reflectionProbe.m_useParallaxCorrection;
  576. materialInfo.m_reflectionProbeData.m_exposure = mesh.m_reflectionProbe.m_exposure;
  577. }
  578. else
  579. {
  580. materialInfo.m_reflectionProbeData.m_useReflectionProbe = false;
  581. }
  582. }
  583. }
  584. m_materialInfoBufferNeedsUpdate = true;
  585. }
  586. }
  587. void RayTracingFeatureProcessor::SetMeshMaterials(const AZ::Uuid& uuid, const SubMeshMaterialVector& subMeshMaterials)
  588. {
  589. if (!m_rayTracingEnabled)
  590. {
  591. return;
  592. }
  593. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  594. MeshMap::iterator itMesh = m_meshes.find(uuid);
  595. if (itMesh != m_meshes.end())
  596. {
  597. Mesh& mesh = itMesh->second;
  598. AZ_Assert(
  599. subMeshMaterials.size() == mesh.m_subMeshIndices.size(),
  600. "The size of subMeshes in SetMeshMaterial must be the same as in AddMesh");
  601. for (auto& subMeshIndex : mesh.m_subMeshIndices)
  602. {
  603. const SubMesh& subMesh = m_subMeshes[subMeshIndex];
  604. for (auto& [deviceIndex, materialInfos] : m_materialInfos)
  605. {
  606. ConvertMaterial(materialInfos[subMesh.m_globalIndex], subMeshMaterials[subMesh.m_subMeshIndex], deviceIndex);
  607. }
  608. }
  609. m_materialInfoBufferNeedsUpdate = true;
  610. m_indexListNeedsUpdate = true;
  611. }
  612. }
  613. void RayTracingFeatureProcessor::UpdateRayTracingSrgs()
  614. {
  615. AZ_PROFILE_SCOPE(AzRender, "RayTracingFeatureProcessor::UpdateRayTracingSrgs");
  616. if (!m_tlas->GetTlasBuffer())
  617. {
  618. return;
  619. }
  620. if (m_rayTracingSceneSrg->IsQueuedForCompile() || m_rayTracingMaterialSrg->IsQueuedForCompile())
  621. {
  622. //[GFX TODO][ATOM-14792] AtomSampleViewer: Reset scene and feature processors before switching to sample
  623. return;
  624. }
  625. // lock the mutex to protect the mesh and BLAS lists
  626. AZStd::unique_lock<AZStd::mutex> lock(m_mutex);
  627. if (HasMeshGeometry())
  628. {
  629. UpdateMeshInfoBuffer();
  630. }
  631. if (HasProceduralGeometry())
  632. {
  633. UpdateProceduralGeometryInfoBuffer();
  634. }
  635. if (HasGeometry())
  636. {
  637. UpdateMaterialInfoBuffer();
  638. UpdateIndexLists();
  639. }
  640. UpdateRayTracingSceneSrg();
  641. UpdateRayTracingMaterialSrg();
  642. }
  643. void RayTracingFeatureProcessor::UpdateMeshInfoBuffer()
  644. {
  645. if (m_meshInfoBufferNeedsUpdate)
  646. {
  647. m_meshInfoGpuBuffer.AdvanceCurrentBufferAndUpdateData(m_meshInfos);
  648. m_meshInfoBufferNeedsUpdate = false;
  649. }
  650. }
  651. void RayTracingFeatureProcessor::UpdateProceduralGeometryInfoBuffer()
  652. {
  653. if (!m_proceduralGeometryInfoBufferNeedsUpdate)
  654. {
  655. return;
  656. }
  657. AZStd::vector<uint32_t> proceduralGeometryInfo;
  658. proceduralGeometryInfo.reserve(m_proceduralGeometry.size() * 2);
  659. for (const auto& proceduralGeometry : m_proceduralGeometry)
  660. {
  661. proceduralGeometryInfo.push_back(proceduralGeometry.m_typeHandle->m_bindlessBufferIndex);
  662. proceduralGeometryInfo.push_back(proceduralGeometry.m_localInstanceIndex);
  663. }
  664. m_proceduralGeometryInfoGpuBuffer.AdvanceCurrentBufferAndUpdateData(proceduralGeometryInfo);
  665. m_proceduralGeometryInfoBufferNeedsUpdate = false;
  666. }
  667. void RayTracingFeatureProcessor::UpdateMaterialInfoBuffer()
  668. {
  669. if (m_materialInfoBufferNeedsUpdate)
  670. {
  671. m_materialInfoGpuBuffer.AdvanceCurrentElement();
  672. m_materialInfoGpuBuffer.CreateOrResizeCurrentBufferWithElementCount<MaterialInfo>(
  673. m_subMeshCount + m_proceduralGeometryMaterialInfos.begin()->second.size());
  674. m_materialInfoGpuBuffer.UpdateCurrentBufferData(m_materialInfos);
  675. m_materialInfoGpuBuffer.UpdateCurrentBufferData(m_proceduralGeometryMaterialInfos, m_subMeshCount);
  676. m_materialInfoBufferNeedsUpdate = false;
  677. }
  678. }
  679. void RayTracingFeatureProcessor::UpdateIndexLists()
  680. {
  681. if (m_indexListNeedsUpdate)
  682. {
  683. #if !USE_BINDLESS_SRG
  684. // resolve to the true indices using the indirection list
  685. // Note: this is done on the CPU to avoid double-indirection in the shader
  686. IndexVector resolvedMeshBufferIndices(m_meshBufferIndices.GetIndexList().size());
  687. uint32_t resolvedMeshBufferIndex = 0;
  688. for (auto& meshBufferIndex : m_meshBufferIndices.GetIndexList())
  689. {
  690. if (!m_meshBufferIndices.IsValidIndex(meshBufferIndex))
  691. {
  692. resolvedMeshBufferIndices[resolvedMeshBufferIndex++] = InvalidIndex;
  693. }
  694. else
  695. {
  696. resolvedMeshBufferIndices[resolvedMeshBufferIndex++] = m_meshBuffers.GetIndirectionList()[meshBufferIndex];
  697. }
  698. }
  699. m_meshBufferIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(resolvedMeshBufferIndices);
  700. #else
  701. AZStd::unordered_map<int, const void*> rawMeshData;
  702. for (auto& [deviceIndex, meshBufferIndices] : m_meshBufferIndices)
  703. {
  704. rawMeshData[deviceIndex] = meshBufferIndices.GetIndexList().data();
  705. }
  706. size_t newMeshBufferIndicesByteCount = m_meshBufferIndices.begin()->second.GetIndexList().size() * sizeof(uint32_t);
  707. m_meshBufferIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(rawMeshData, newMeshBufferIndicesByteCount);
  708. #endif
  709. #if !USE_BINDLESS_SRG
  710. // resolve to the true indices using the indirection list
  711. // Note: this is done on the CPU to avoid double-indirection in the shader
  712. IndexVector resolvedMaterialTextureIndices(m_materialTextureIndices.GetIndexList().size());
  713. uint32_t resolvedMaterialTextureIndex = 0;
  714. for (auto& materialTextureIndex : m_materialTextureIndices.GetIndexList())
  715. {
  716. if (!m_materialTextureIndices.IsValidIndex(materialTextureIndex))
  717. {
  718. resolvedMaterialTextureIndices[resolvedMaterialTextureIndex++] = InvalidIndex;
  719. }
  720. else
  721. {
  722. resolvedMaterialTextureIndices[resolvedMaterialTextureIndex++] = m_materialTextures.GetIndirectionList()[materialTextureIndex];
  723. }
  724. }
  725. m_materialTextureIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(resolvedMaterialTextureIndices);
  726. #else
  727. AZStd::unordered_map<int, const void*> rawMaterialData;
  728. for (auto& [deviceIndex, materialTextureIndices] : m_materialTextureIndices)
  729. {
  730. rawMaterialData[deviceIndex] = materialTextureIndices.GetIndexList().data();
  731. }
  732. size_t newMaterialTextureIndicesByteCount = m_materialTextureIndices.begin()->second.GetIndexList().size() * sizeof(uint32_t);
  733. m_materialTextureIndicesGpuBuffer.AdvanceCurrentBufferAndUpdateData(rawMaterialData, newMaterialTextureIndicesByteCount);
  734. #endif
  735. m_indexListNeedsUpdate = false;
  736. }
  737. }
  738. void RayTracingFeatureProcessor::UpdateRayTracingSceneSrg()
  739. {
  740. const RHI::ShaderResourceGroupLayout* srgLayout = m_rayTracingSceneSrg->GetLayout();
  741. RHI::ShaderInputImageIndex imageIndex;
  742. RHI::ShaderInputBufferIndex bufferIndex;
  743. RHI::ShaderInputConstantIndex constantIndex;
  744. // TLAS
  745. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_tlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  746. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  747. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_scene"));
  748. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_tlas->GetTlasBuffer()->BuildBufferView(bufferViewDescriptor).get());
  749. // directional lights
  750. const auto directionalLightFP = GetParentScene()->GetFeatureProcessor<DirectionalLightFeatureProcessor>();
  751. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_directionalLights"));
  752. m_rayTracingSceneSrg->SetBufferView(
  753. bufferIndex,
  754. directionalLightFP->GetLightBuffer()->GetBufferView());
  755. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_directionalLightCount"));
  756. m_rayTracingSceneSrg->SetConstant(constantIndex, directionalLightFP->GetLightCount());
  757. // simple point lights
  758. const auto simplePointLightFP = GetParentScene()->GetFeatureProcessor<SimplePointLightFeatureProcessor>();
  759. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_simplePointLights"));
  760. m_rayTracingSceneSrg->SetBufferView(
  761. bufferIndex,
  762. simplePointLightFP->GetLightBuffer()->GetBufferView());
  763. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_simplePointLightCount"));
  764. m_rayTracingSceneSrg->SetConstant(constantIndex, simplePointLightFP->GetLightCount());
  765. // simple spot lights
  766. const auto simpleSpotLightFP = GetParentScene()->GetFeatureProcessor<SimpleSpotLightFeatureProcessor>();
  767. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_simpleSpotLights"));
  768. m_rayTracingSceneSrg->SetBufferView(
  769. bufferIndex,
  770. simpleSpotLightFP->GetLightBuffer()->GetBufferView());
  771. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_simpleSpotLightCount"));
  772. m_rayTracingSceneSrg->SetConstant(constantIndex, simpleSpotLightFP->GetLightCount());
  773. // point lights (sphere)
  774. const auto pointLightFP = GetParentScene()->GetFeatureProcessor<PointLightFeatureProcessor>();
  775. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_pointLights"));
  776. m_rayTracingSceneSrg->SetBufferView(
  777. bufferIndex,
  778. pointLightFP->GetLightBuffer()->GetBufferView());
  779. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_pointLightCount"));
  780. m_rayTracingSceneSrg->SetConstant(constantIndex, pointLightFP->GetLightCount());
  781. // disk lights
  782. const auto diskLightFP = GetParentScene()->GetFeatureProcessor<DiskLightFeatureProcessor>();
  783. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_diskLights"));
  784. m_rayTracingSceneSrg->SetBufferView(
  785. bufferIndex,
  786. diskLightFP->GetLightBuffer()->GetBufferView());
  787. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_diskLightCount"));
  788. m_rayTracingSceneSrg->SetConstant(constantIndex, diskLightFP->GetLightCount());
  789. // capsule lights
  790. const auto capsuleLightFP = GetParentScene()->GetFeatureProcessor<CapsuleLightFeatureProcessor>();
  791. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_capsuleLights"));
  792. m_rayTracingSceneSrg->SetBufferView(
  793. bufferIndex,
  794. capsuleLightFP->GetLightBuffer()->GetBufferView());
  795. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_capsuleLightCount"));
  796. m_rayTracingSceneSrg->SetConstant(constantIndex, capsuleLightFP->GetLightCount());
  797. // quad lights
  798. const auto quadLightFP = GetParentScene()->GetFeatureProcessor<QuadLightFeatureProcessor>();
  799. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_quadLights"));
  800. m_rayTracingSceneSrg->SetBufferView(
  801. bufferIndex,
  802. quadLightFP->GetLightBuffer()->GetBufferView());
  803. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_quadLightCount"));
  804. m_rayTracingSceneSrg->SetConstant(constantIndex, quadLightFP->GetLightCount());
  805. // diffuse environment map for sky hits
  806. ImageBasedLightFeatureProcessor* imageBasedLightFeatureProcessor = GetParentScene()->GetFeatureProcessor<ImageBasedLightFeatureProcessor>();
  807. if (imageBasedLightFeatureProcessor)
  808. {
  809. imageIndex = srgLayout->FindShaderInputImageIndex(AZ::Name("m_diffuseEnvMap"));
  810. m_rayTracingSceneSrg->SetImage(imageIndex, imageBasedLightFeatureProcessor->GetDiffuseImage());
  811. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblOrientation"));
  812. m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetOrientation());
  813. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblExposure"));
  814. m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetExposure());
  815. }
  816. if (m_meshInfoGpuBuffer.IsCurrentBufferValid())
  817. {
  818. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_meshInfo"));
  819. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_meshInfoGpuBuffer.GetCurrentBufferView());
  820. }
  821. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_meshInfoCount"));
  822. m_rayTracingSceneSrg->SetConstant(constantIndex, m_subMeshCount);
  823. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_meshBufferIndices"));
  824. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_meshBufferIndicesGpuBuffer.GetCurrentBufferView());
  825. if (m_proceduralGeometryInfoGpuBuffer.IsCurrentBufferValid())
  826. {
  827. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_proceduralGeometryInfo"));
  828. m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_proceduralGeometryInfoGpuBuffer.GetCurrentBufferView());
  829. }
  830. #if !USE_BINDLESS_SRG
  831. RHI::ShaderInputBufferUnboundedArrayIndex bufferUnboundedArrayIndex = srgLayout->FindShaderInputBufferUnboundedArrayIndex(AZ::Name("m_meshBuffers"));
  832. m_rayTracingSceneSrg->SetBufferViewUnboundedArray(bufferUnboundedArrayIndex, m_meshBuffers.GetResourceList());
  833. #endif
  834. m_rayTracingSceneSrg->Compile();
  835. }
  836. void RayTracingFeatureProcessor::UpdateRayTracingMaterialSrg()
  837. {
  838. const RHI::ShaderResourceGroupLayout* srgLayout = m_rayTracingMaterialSrg->GetLayout();
  839. RHI::ShaderInputBufferIndex bufferIndex;
  840. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_materialInfo"));
  841. m_rayTracingMaterialSrg->SetBufferView(bufferIndex, m_materialInfoGpuBuffer.GetCurrentBufferView());
  842. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_materialTextureIndices"));
  843. m_rayTracingMaterialSrg->SetBufferView(bufferIndex, m_materialTextureIndicesGpuBuffer.GetCurrentBufferView());
  844. #if !USE_BINDLESS_SRG
  845. RHI::ShaderInputImageUnboundedArrayIndex textureUnboundedArrayIndex = srgLayout->FindShaderInputImageUnboundedArrayIndex(AZ::Name("m_materialTextures"));
  846. m_rayTracingMaterialSrg->SetImageViewUnboundedArray(textureUnboundedArrayIndex, m_materialTextures.GetResourceList());
  847. #endif
  848. m_rayTracingMaterialSrg->Compile();
  849. }
  850. void RayTracingFeatureProcessor::OnRenderPipelineChanged([[maybe_unused]] RPI::RenderPipeline* renderPipeline, RPI::SceneNotification::RenderPipelineChangeType changeType)
  851. {
  852. if (!m_rayTracingEnabled)
  853. {
  854. return;
  855. }
  856. // only enable the RayTracingAccelerationStructurePass on the first pipeline in this scene, this will avoid multiple updates to the same AS
  857. bool enabled = true;
  858. if (changeType == RPI::SceneNotification::RenderPipelineChangeType::Added
  859. || changeType == RPI::SceneNotification::RenderPipelineChangeType::Removed)
  860. {
  861. AZ::RPI::PassFilter passFilter = AZ::RPI::PassFilter::CreateWithPassName(AZ::Name("RayTracingAccelerationStructurePass"), GetParentScene());
  862. AZ::RPI::PassSystemInterface::Get()->ForEachPass(passFilter, [&enabled](AZ::RPI::Pass* pass) -> AZ::RPI::PassFilterExecutionFlow
  863. {
  864. pass->SetEnabled(enabled);
  865. enabled = false;
  866. return AZ::RPI::PassFilterExecutionFlow::ContinueVisitingPasses;
  867. });
  868. }
  869. }
  870. AZ::RHI::RayTracingAccelerationStructureBuildFlags RayTracingFeatureProcessor::CreateRayTracingAccelerationStructureBuildFlags(bool isSkinnedMesh)
  871. {
  872. AZ::RHI::RayTracingAccelerationStructureBuildFlags buildFlags;
  873. if (isSkinnedMesh)
  874. {
  875. buildFlags = AZ::RHI::RayTracingAccelerationStructureBuildFlags::ENABLE_UPDATE | AZ::RHI::RayTracingAccelerationStructureBuildFlags::FAST_BUILD;
  876. }
  877. else
  878. {
  879. buildFlags = AZ::RHI::RayTracingAccelerationStructureBuildFlags::FAST_TRACE;
  880. }
  881. return buildFlags;
  882. }
  883. void RayTracingFeatureProcessor::ConvertMaterial(MaterialInfo& materialInfo, const SubMeshMaterial& subMeshMaterial, int deviceIndex)
  884. {
  885. subMeshMaterial.m_baseColor.StoreToFloat4(materialInfo.m_baseColor.data());
  886. subMeshMaterial.m_emissiveColor.StoreToFloat4(materialInfo.m_emissiveColor.data());
  887. subMeshMaterial.m_irradianceColor.StoreToFloat4(materialInfo.m_irradianceColor.data());
  888. materialInfo.m_metallicFactor = subMeshMaterial.m_metallicFactor;
  889. materialInfo.m_roughnessFactor = subMeshMaterial.m_roughnessFactor;
  890. materialInfo.m_textureFlags = subMeshMaterial.m_textureFlags;
  891. if (materialInfo.m_textureStartIndex != InvalidIndex)
  892. {
  893. m_materialTextureIndices[deviceIndex].RemoveEntry(materialInfo.m_textureStartIndex);
  894. #if !USE_BINDLESS_SRG
  895. m_materialTextures.RemoveResource(subMeshMaterial.m_baseColorImageView.get());
  896. m_materialTextures.RemoveResource(subMeshMaterial.m_normalImageView.get());
  897. m_materialTextures.RemoveResource(subMeshMaterial.m_metallicImageView.get());
  898. m_materialTextures.RemoveResource(subMeshMaterial.m_roughnessImageView.get());
  899. m_materialTextures.RemoveResource(subMeshMaterial.m_emissiveImageView.get());
  900. #endif
  901. }
  902. materialInfo.m_textureStartIndex = m_materialTextureIndices[deviceIndex].AddEntry({
  903. #if USE_BINDLESS_SRG
  904. subMeshMaterial.m_baseColorImageView.get() ? subMeshMaterial.m_baseColorImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  905. subMeshMaterial.m_normalImageView.get() ? subMeshMaterial.m_normalImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  906. subMeshMaterial.m_metallicImageView.get() ? subMeshMaterial.m_metallicImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  907. subMeshMaterial.m_roughnessImageView.get() ? subMeshMaterial.m_roughnessImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex,
  908. subMeshMaterial.m_emissiveImageView.get() ? subMeshMaterial.m_emissiveImageView->GetDeviceImageView(deviceIndex)->GetBindlessReadIndex() : InvalidIndex
  909. #else
  910. m_materialTextures.AddResource(subMeshMaterial.m_baseColorImageView.get()),
  911. m_materialTextures.AddResource(subMeshMaterial.m_normalImageView.get()),
  912. m_materialTextures.AddResource(subMeshMaterial.m_metallicImageView.get()),
  913. m_materialTextures.AddResource(subMeshMaterial.m_roughnessImageView.get()),
  914. m_materialTextures.AddResource(subMeshMaterial.m_emissiveImageView.get())
  915. #endif
  916. });
  917. }
  918. }
  919. }