BsRect3.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #include "BsRect3.h"
  2. #include "BsRay.h"
  3. #include "BsLineSegment3.h"
  4. namespace BansheeEngine
  5. {
  6. Rect3::Rect3(const Vector3& center, const std::array<Vector3, 2>& axes,
  7. const std::array<float, 2>& extents)
  8. :mCenter(center), mAxes(axes), mExtents(extents)
  9. {
  10. }
  11. std::pair<std::array<Vector3, 2>, float> Rect3::getNearestPoint(const Ray& ray) const
  12. {
  13. const Vector3& org = ray.getOrigin();
  14. const Vector3& dir = ray.getDirection();
  15. bool foundNearest = false;
  16. float t = 0.0f;
  17. std::array<Vector3, 2> nearestPoints;
  18. float distance = 0.0f;
  19. // Check if Ray intersects the rectangle
  20. auto intersectResult = intersects(ray);
  21. if (intersectResult.first)
  22. {
  23. t = intersectResult.second;
  24. nearestPoints[0] = org + dir * t;
  25. nearestPoints[1] = nearestPoints[0]; // Just one point of intersection
  26. foundNearest = true;
  27. }
  28. // Ray is either passing next to the rectangle or parallel to it,
  29. // compare ray to 4 edges of the rectangle
  30. if (!foundNearest)
  31. {
  32. Vector3 scaledAxes[2];
  33. scaledAxes[0] = mExtents[0] * mAxes[0];
  34. scaledAxes[1] = mExtents[1] * mAxes[1];
  35. distance = std::numeric_limits<float>::max();
  36. for (UINT32 i = 0; i < 2; i++)
  37. {
  38. for (UINT32 j = 0; j < 2; j++)
  39. {
  40. float sign = (float)(2 * j - 1);
  41. Vector3 segCenter = mCenter + sign * scaledAxes[i];
  42. Vector3 segStart = segCenter - scaledAxes[1 - i];
  43. Vector3 segEnd = segCenter + scaledAxes[1 - i];
  44. LineSegment3 segment(segStart, segEnd);
  45. auto segResult = segment.getNearestPoint(ray);
  46. if (segResult.second < distance)
  47. {
  48. nearestPoints = segResult.first;
  49. distance = segResult.second;
  50. }
  51. }
  52. }
  53. }
  54. // Front of the ray is nearest, use found points
  55. if (t >= 0.0f)
  56. {
  57. // Do nothing, we already have the points
  58. }
  59. else // Rectangle is behind the ray origin, find nearest point to origin
  60. {
  61. auto nearestPointToOrg = getNearestPoint(org);
  62. nearestPoints[0] = org;
  63. nearestPoints[1] = nearestPointToOrg.first;
  64. distance = nearestPointToOrg.second;
  65. }
  66. return std::make_pair(nearestPoints, distance);
  67. }
  68. std::pair<Vector3, float> Rect3::getNearestPoint(const Vector3& point) const
  69. {
  70. Vector3 diff = mCenter - point;
  71. float b0 = diff.dot(mAxes[0]);
  72. float b1 = diff.dot(mAxes[1]);
  73. float s0 = -b0, s1 = -b1;
  74. float sqrDistance = diff.dot(diff);
  75. if (s0 < -mExtents[0])
  76. s0 = -mExtents[0];
  77. else if (s0 > mExtents[0])
  78. s0 = mExtents[0];
  79. sqrDistance += s0*(s0 + 2.0f*b0);
  80. if (s1 < -mExtents[1])
  81. s1 = -mExtents[1];
  82. else if (s1 > mExtents[1])
  83. s1 = mExtents[1];
  84. sqrDistance += s1*(s1 + 2.0f*b1);
  85. if (sqrDistance < 0.0f)
  86. sqrDistance = 0.0f;
  87. float dist = std::sqrt(sqrDistance);
  88. Vector3 nearestPoint = mCenter + s0 * mAxes[0] + s1 * mAxes[1];
  89. return std::make_pair(nearestPoint, dist);
  90. }
  91. std::pair<bool, float> Rect3::intersects(const Ray& ray) const
  92. {
  93. const Vector3& org = ray.getOrigin();
  94. const Vector3& dir = ray.getDirection();
  95. Vector3 normal = mAxes[0].cross(mAxes[1]);
  96. float NdotD = normal.dot(ray.getDirection());
  97. if (fabs(NdotD) > 0.0f)
  98. {
  99. Vector3 diff = ray.getOrigin() - mCenter;
  100. Vector3 basis[3];
  101. basis[0] = ray.getDirection();
  102. basis[0].orthogonalComplement(basis[1], basis[2]);
  103. float UdD0 = basis[1].dot(mAxes[0]);
  104. float UdD1 = basis[1].dot(mAxes[1]);
  105. float UdPmC = basis[1].dot(diff);
  106. float VdD0 = basis[2].dot(mAxes[0]);
  107. float VdD1 = basis[2].dot(mAxes[1]);
  108. float VdPmC = basis[2].dot(diff);
  109. float invDet = 1.0f / (UdD0*VdD1 - UdD1*VdD0);
  110. float s0 = (VdD1*UdPmC - UdD1*VdPmC)*invDet;
  111. float s1 = (UdD0*VdPmC - VdD0*UdPmC)*invDet;
  112. if (fabs(s0) <= mExtents[0] && fabs(s1) <= mExtents[1])
  113. {
  114. float DdD0 = dir.dot(mAxes[0]);
  115. float DdD1 = dir.dot(mAxes[1]);
  116. float DdDiff = dir.dot(diff);
  117. float t = s0 * DdD0 + s1 * DdD1 - DdDiff;
  118. return std::make_pair(true, t);
  119. }
  120. }
  121. return std::make_pair(false, 0.0f);
  122. }
  123. }