ModelKdTree.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/std/numeric.h>
  9. #include <AzCore/std/limits.h>
  10. #include <Atom/RPI.Reflect/Model/ModelKdTree.h>
  11. #include <AzCore/Math/IntersectSegment.h>
  12. namespace AZ
  13. {
  14. namespace RPI
  15. {
  16. AZStd::tuple<ModelKdTree::ESplitAxis, float> ModelKdTree::SearchForBestSplitAxis(const AZ::Aabb& aabb)
  17. {
  18. const float xsize = aabb.GetXExtent();
  19. const float ysize = aabb.GetYExtent();
  20. const float zsize = aabb.GetZExtent();
  21. if (xsize >= ysize && xsize >= zsize)
  22. {
  23. return {ModelKdTree::eSA_X, aabb.GetMin().GetX() + xsize * 0.5f};
  24. }
  25. if (ysize >= zsize && ysize >= xsize)
  26. {
  27. return {ModelKdTree::eSA_Y, aabb.GetMin().GetY() + ysize * 0.5f};
  28. }
  29. return {ModelKdTree::eSA_Z, aabb.GetMin().GetZ() + zsize * 0.5f};
  30. }
  31. bool ModelKdTree::SplitNode(const AZ::Aabb& boundbox, const AZStd::vector<ObjectIdTriangleIndices>& indices, ModelKdTree::ESplitAxis splitAxis, float splitPos, SSplitInfo& outInfo)
  32. {
  33. if (splitAxis != ModelKdTree::eSA_X && splitAxis != ModelKdTree::eSA_Y && splitAxis != ModelKdTree::eSA_Z)
  34. {
  35. return false;
  36. }
  37. outInfo.m_aboveBoundbox = boundbox;
  38. outInfo.m_belowBoundbox = boundbox;
  39. {
  40. Vector3 maxBound = outInfo.m_aboveBoundbox.GetMax();
  41. maxBound.SetElement(splitAxis, splitPos);
  42. outInfo.m_aboveBoundbox.SetMax(maxBound);
  43. }
  44. {
  45. Vector3 minBound = outInfo.m_belowBoundbox.GetMin();
  46. minBound.SetElement(splitAxis, splitPos);
  47. outInfo.m_belowBoundbox.SetMin(minBound);
  48. }
  49. const AZ::u32 iIndexSize = aznumeric_cast<AZ::u32>(indices.size());
  50. outInfo.m_aboveIndices.reserve(iIndexSize);
  51. outInfo.m_belowIndices.reserve(iIndexSize);
  52. for (const auto& [nObjIndex, triangleIndices] : indices)
  53. {
  54. const auto& [first, second, third] = triangleIndices;
  55. const AZStd::span<const float>& positionBuffer = m_meshes[nObjIndex].m_vertexData;
  56. if (positionBuffer.empty())
  57. {
  58. continue;
  59. }
  60. // If the split axis is Y, this uses a Vector3 to store the Y positions of each vertex in the triangle.
  61. const AZStd::array<const float, 3> triangleVerticesValuesForThisSplitAxis {
  62. positionBuffer[first * 3 + splitAxis], positionBuffer[second * 3 + splitAxis], positionBuffer[third * 3 + splitAxis]
  63. };
  64. if (AZStd::any_of(begin(triangleVerticesValuesForThisSplitAxis), end(triangleVerticesValuesForThisSplitAxis), [splitPos](const float value) { return value < splitPos; }))
  65. {
  66. outInfo.m_aboveIndices.emplace_back(nObjIndex, triangleIndices);
  67. }
  68. if (AZStd::any_of(begin(triangleVerticesValuesForThisSplitAxis), end(triangleVerticesValuesForThisSplitAxis), [splitPos](const float value) { return value >= splitPos; }))
  69. {
  70. outInfo.m_belowIndices.emplace_back(nObjIndex, triangleIndices);
  71. }
  72. }
  73. // If either the top or bottom contain all the input indices, the triangles are too close to cut any
  74. // further and the split failed
  75. // Additionally, if too many triangles straddle the split-axis,
  76. // the triangles are too close and the split failed
  77. // [ATOM-15944] - Use a more sophisticated method to terminate KdTree generation
  78. return indices.size() != outInfo.m_aboveIndices.size() && indices.size() != outInfo.m_belowIndices.size()
  79. && aznumeric_cast<float>(outInfo.m_aboveIndices.size() + outInfo.m_belowIndices.size()) / aznumeric_cast<float>(indices.size()) < s_MaximumSplitAxisStraddlingTriangles;
  80. }
  81. bool ModelKdTree::Build(const ModelAsset* model)
  82. {
  83. if (model == nullptr)
  84. {
  85. return false;
  86. }
  87. ConstructMeshList(model, AZ::Transform::CreateIdentity());
  88. AZ::Aabb entireBoundBox = AZ::Aabb::CreateNull();
  89. // indices with object ids
  90. AZStd::vector<ObjectIdTriangleIndices> indices;
  91. const size_t totalSizeNeed = AZStd::accumulate(begin(m_meshes), end(m_meshes), size_t{0}, [](const size_t current, const MeshData& data)
  92. {
  93. return current + data.m_mesh->GetVertexCount();
  94. });
  95. indices.reserve(totalSizeNeed);
  96. for (AZ::u8 meshIndex = 0, meshCount = aznumeric_caster(m_meshes.size()); meshIndex < meshCount; ++meshIndex)
  97. {
  98. const AZStd::span<const float> positionBuffer = m_meshes[meshIndex].m_vertexData;
  99. for (size_t positionIndex = 0; positionIndex < positionBuffer.size(); positionIndex += 3)
  100. {
  101. entireBoundBox.AddPoint({positionBuffer[positionIndex], positionBuffer[positionIndex + 1], positionBuffer[positionIndex + 2]});
  102. }
  103. for (const TriangleIndices& triangleIndices : GetIndexBuffer(*m_meshes[meshIndex].m_mesh))
  104. {
  105. indices.emplace_back(meshIndex, triangleIndices);
  106. }
  107. }
  108. m_pRootNode = AZStd::make_unique<ModelKdTreeNode>();
  109. BuildRecursively(m_pRootNode.get(), entireBoundBox, indices);
  110. return true;
  111. }
  112. AZStd::span<const float> ModelKdTree::GetPositionsBuffer(const ModelLodAsset::Mesh& mesh)
  113. {
  114. AZStd::span<const float> positionBuffer = mesh.GetSemanticBufferTyped<float>(AZ::Name{"POSITION"});
  115. AZ_Warning("ModelKdTree", !positionBuffer.empty(), "Could not find position buffers in a mesh");
  116. return positionBuffer;
  117. }
  118. AZStd::span<const ModelKdTree::TriangleIndices> ModelKdTree::GetIndexBuffer(const ModelLodAsset::Mesh& mesh)
  119. {
  120. return mesh.GetIndexBufferTyped<ModelKdTree::TriangleIndices>();
  121. }
  122. void ModelKdTree::BuildRecursively(ModelKdTreeNode* pNode, const AZ::Aabb& boundbox, AZStd::vector<ObjectIdTriangleIndices>& indices)
  123. {
  124. pNode->SetBoundBox(boundbox);
  125. if (indices.size() <= s_MinimumVertexSizeInLeafNode)
  126. {
  127. pNode->SetVertexIndexBuffer(AZStd::move(indices));
  128. return;
  129. }
  130. const auto [splitAxis, splitPos] = SearchForBestSplitAxis(boundbox);
  131. pNode->SetSplitAxis(splitAxis);
  132. pNode->SetSplitPos(splitPos);
  133. SSplitInfo splitInfo;
  134. if (!SplitNode(boundbox, indices, splitAxis, splitPos, splitInfo))
  135. {
  136. pNode->SetVertexIndexBuffer(AZStd::move(indices));
  137. return;
  138. }
  139. if (splitInfo.m_aboveIndices.empty() || splitInfo.m_belowIndices.empty())
  140. {
  141. pNode->SetVertexIndexBuffer(AZStd::move(indices));
  142. return;
  143. }
  144. pNode->SetChild(0, AZStd::make_unique<ModelKdTreeNode>());
  145. pNode->SetChild(1, AZStd::make_unique<ModelKdTreeNode>());
  146. BuildRecursively(pNode->GetChild(0), splitInfo.m_aboveBoundbox, splitInfo.m_aboveIndices);
  147. BuildRecursively(pNode->GetChild(1), splitInfo.m_belowBoundbox, splitInfo.m_belowIndices);
  148. }
  149. void ModelKdTree::ConstructMeshList(const ModelAsset* model, [[maybe_unused]] const AZ::Transform& matParent)
  150. {
  151. if (model == nullptr || model->GetLodAssets().empty())
  152. {
  153. return;
  154. }
  155. if (ModelLodAsset* lodAssetPtr = model->GetLodAssets()[0].Get())
  156. {
  157. AZ_Warning("ModelKdTree", lodAssetPtr->GetMeshes().size() <= AZStd::numeric_limits<AZ::u8>::max() + 1,
  158. "KdTree generation doesn't support models with greater than 256 meshes. RayIntersection results will be incorrect "
  159. "unless the meshes are merged or broken up into multiple models");
  160. const size_t size = AZStd::min<size_t>(lodAssetPtr->GetMeshes().size(), AZStd::numeric_limits<AZ::u8>::max() + 1);
  161. m_meshes.reserve(size);
  162. AZStd::transform(
  163. lodAssetPtr->GetMeshes().begin(), AZStd::next(lodAssetPtr->GetMeshes().begin(), size),
  164. AZStd::back_inserter(m_meshes),
  165. [](const auto& mesh) { return MeshData{&mesh, GetPositionsBuffer(mesh)}; }
  166. );
  167. }
  168. }
  169. bool ModelKdTree::RayIntersection(
  170. const AZ::Vector3& raySrc, const AZ::Vector3& rayDir, float& distanceNormalized, AZ::Vector3& normal) const
  171. {
  172. float shortestDistanceNormalized = AZStd::numeric_limits<float>::max();
  173. if (RayIntersectionRecursively(m_pRootNode.get(), raySrc, rayDir, shortestDistanceNormalized, normal))
  174. {
  175. distanceNormalized = shortestDistanceNormalized;
  176. return true;
  177. }
  178. return false;
  179. }
  180. bool ModelKdTree::RayIntersectionRecursively(
  181. ModelKdTreeNode* pNode,
  182. const AZ::Vector3& raySrc,
  183. const AZ::Vector3& rayDir,
  184. float& distanceNormalized,
  185. AZ::Vector3& normal) const
  186. {
  187. using Intersect::IntersectRayAABB2;
  188. using Intersect::ISECT_RAY_AABB_NONE;
  189. if (!pNode)
  190. {
  191. return false;
  192. }
  193. float start, end;
  194. if (IntersectRayAABB2(raySrc, rayDir.GetReciprocal(), pNode->GetBoundBox(), start, end) == ISECT_RAY_AABB_NONE)
  195. {
  196. return false;
  197. }
  198. if (start > distanceNormalized)
  199. {
  200. return false;
  201. }
  202. if (pNode->IsLeaf())
  203. {
  204. if (m_meshes.empty())
  205. {
  206. return false;
  207. }
  208. const AZ::u32 nVBuffSize = pNode->GetVertexBufferSize();
  209. if (nVBuffSize == 0)
  210. {
  211. return false;
  212. }
  213. const AZ::Vector3 rayEnd = raySrc + rayDir;
  214. Intersect::SegmentTriangleHitTester hitTester(raySrc, rayEnd);
  215. float nearestDistanceNormalized = distanceNormalized;
  216. for (AZ::u32 i = 0; i < nVBuffSize; ++i)
  217. {
  218. const auto& [first, second, third] = pNode->GetVertexIndex(i);
  219. const AZ::u32 nObjIndex = pNode->GetObjIndex(i);
  220. const AZStd::span<const float> positionBuffer = m_meshes[nObjIndex].m_vertexData;
  221. if (positionBuffer.empty())
  222. {
  223. continue;
  224. }
  225. const AZStd::array trianglePoints {
  226. AZ::Vector3{positionBuffer[first * 3 + 0], positionBuffer[first * 3 + 1], positionBuffer[first * 3 + 2]},
  227. AZ::Vector3{positionBuffer[second * 3 + 0], positionBuffer[second * 3 + 1], positionBuffer[second * 3 + 2]},
  228. AZ::Vector3{positionBuffer[third * 3 + 0], positionBuffer[third * 3 + 1], positionBuffer[third * 3 + 2]},
  229. };
  230. float hitDistanceNormalized;
  231. AZ::Vector3 intersectionNormal;
  232. if (hitTester.IntersectSegmentTriangleCCW(trianglePoints[0], trianglePoints[1], trianglePoints[2],
  233. intersectionNormal, hitDistanceNormalized))
  234. {
  235. if (nearestDistanceNormalized > hitDistanceNormalized)
  236. {
  237. normal = intersectionNormal;
  238. nearestDistanceNormalized = hitDistanceNormalized;
  239. }
  240. }
  241. }
  242. if (nearestDistanceNormalized < distanceNormalized)
  243. {
  244. distanceNormalized = nearestDistanceNormalized;
  245. return true;
  246. }
  247. return false;
  248. }
  249. // running both sides to find the closest intersection
  250. const bool bFoundChild0 = RayIntersectionRecursively(pNode->GetChild(0), raySrc, rayDir, distanceNormalized, normal);
  251. const bool bFoundChild1 = RayIntersectionRecursively(pNode->GetChild(1), raySrc, rayDir, distanceNormalized, normal);
  252. return bFoundChild0 || bFoundChild1;
  253. }
  254. void ModelKdTree::GetPenetratedBoxes(const AZ::Vector3& raySrc, const AZ::Vector3& rayDir, AZStd::vector<AZ::Aabb>& outBoxes)
  255. {
  256. GetPenetratedBoxesRecursively(m_pRootNode.get(), raySrc, rayDir, outBoxes);
  257. }
  258. void ModelKdTree::GetPenetratedBoxesRecursively(ModelKdTreeNode* pNode, const AZ::Vector3& raySrc, const AZ::Vector3& rayDir, AZStd::vector<AZ::Aabb>& outBoxes)
  259. {
  260. AZ::Vector3 ignoreNormal;
  261. float ignore;
  262. if (!pNode || (!pNode->GetBoundBox().Contains(raySrc) &&
  263. (AZ::Intersect::IntersectRayAABB(raySrc, rayDir, rayDir.GetReciprocal(), pNode->GetBoundBox(),
  264. ignore, ignore, ignoreNormal)) == Intersect::ISECT_RAY_AABB_NONE))
  265. {
  266. return;
  267. }
  268. outBoxes.push_back(pNode->GetBoundBox());
  269. GetPenetratedBoxesRecursively(pNode->GetChild(0), raySrc, rayDir, outBoxes);
  270. GetPenetratedBoxesRecursively(pNode->GetChild(1), raySrc, rayDir, outBoxes);
  271. }
  272. } // namespace RPI
  273. } // namespace AZ