pod_math.h 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. #pragma once
  2. #include <algorithm>
  3. #include <cmath>
  4. #include <cstring>
  5. namespace Render::Math {
  6. struct alignas(16) Vec3 {
  7. float x, y, z, w;
  8. Vec3() noexcept : x(0), y(0), z(0), w(0) {}
  9. Vec3(float x_, float y_, float z_) noexcept : x(x_), y(y_), z(z_), w(0) {}
  10. inline Vec3 operator+(const Vec3 &o) const noexcept {
  11. return Vec3(x + o.x, y + o.y, z + o.z);
  12. }
  13. inline Vec3 operator-(const Vec3 &o) const noexcept {
  14. return Vec3(x - o.x, y - o.y, z - o.z);
  15. }
  16. inline Vec3 operator*(float s) const noexcept {
  17. return Vec3(x * s, y * s, z * s);
  18. }
  19. inline float dot(const Vec3 &o) const noexcept {
  20. return x * o.x + y * o.y + z * o.z;
  21. }
  22. inline Vec3 cross(const Vec3 &o) const noexcept {
  23. return Vec3(y * o.z - z * o.y, z * o.x - x * o.z, x * o.y - y * o.x);
  24. }
  25. inline float lengthSquared() const noexcept { return x * x + y * y + z * z; }
  26. inline float length() const noexcept { return std::sqrt(lengthSquared()); }
  27. inline Vec3 normalized() const noexcept {
  28. float len = length();
  29. if (len < 1e-6f)
  30. return Vec3(0, 1, 0);
  31. float invLen = 1.0f / len;
  32. return Vec3(x * invLen, y * invLen, z * invLen);
  33. }
  34. inline void normalize() noexcept {
  35. float len = length();
  36. if (len > 1e-6f) {
  37. float invLen = 1.0f / len;
  38. x *= invLen;
  39. y *= invLen;
  40. z *= invLen;
  41. }
  42. }
  43. };
  44. struct alignas(16) Mat3x4 {
  45. float m[3][4];
  46. Mat3x4() noexcept {
  47. std::memset(m, 0, sizeof(m));
  48. m[0][0] = m[1][1] = m[2][2] = 1.0f;
  49. }
  50. static inline Mat3x4 TRS(const Vec3 &translation, const float rotation[3][3],
  51. float scaleX, float scaleY, float scaleZ) noexcept {
  52. Mat3x4 result;
  53. for (int row = 0; row < 3; ++row) {
  54. result.m[row][0] = rotation[row][0] * scaleX;
  55. result.m[row][1] = rotation[row][1] * scaleY;
  56. result.m[row][2] = rotation[row][2] * scaleZ;
  57. result.m[row][3] = (&translation.x)[row];
  58. }
  59. return result;
  60. }
  61. inline Vec3 transformPoint(const Vec3 &p) const noexcept {
  62. return Vec3(m[0][0] * p.x + m[0][1] * p.y + m[0][2] * p.z + m[0][3],
  63. m[1][0] * p.x + m[1][1] * p.y + m[1][2] * p.z + m[1][3],
  64. m[2][0] * p.x + m[2][1] * p.y + m[2][2] * p.z + m[2][3]);
  65. }
  66. inline Vec3 transformVector(const Vec3 &v) const noexcept {
  67. return Vec3(m[0][0] * v.x + m[0][1] * v.y + m[0][2] * v.z,
  68. m[1][0] * v.x + m[1][1] * v.y + m[1][2] * v.z,
  69. m[2][0] * v.x + m[2][1] * v.y + m[2][2] * v.z);
  70. }
  71. inline Mat3x4 operator*(const Mat3x4 &o) const noexcept {
  72. Mat3x4 result;
  73. for (int row = 0; row < 3; ++row) {
  74. for (int col = 0; col < 3; ++col) {
  75. result.m[row][col] = m[row][0] * o.m[0][col] + m[row][1] * o.m[1][col] +
  76. m[row][2] * o.m[2][col];
  77. }
  78. result.m[row][3] = m[row][0] * o.m[0][3] + m[row][1] * o.m[1][3] +
  79. m[row][2] * o.m[2][3] + m[row][3];
  80. }
  81. return result;
  82. }
  83. inline void setTranslation(const Vec3 &t) noexcept {
  84. m[0][3] = t.x;
  85. m[1][3] = t.y;
  86. m[2][3] = t.z;
  87. }
  88. inline Vec3 getTranslation() const noexcept {
  89. return Vec3(m[0][3], m[1][3], m[2][3]);
  90. }
  91. };
  92. struct CylinderTransform {
  93. Vec3 center;
  94. Vec3 axis;
  95. Vec3 tangent;
  96. Vec3 bitangent;
  97. float length;
  98. float radius;
  99. static inline CylinderTransform fromPoints(const Vec3 &start, const Vec3 &end,
  100. float radius) noexcept {
  101. CylinderTransform ct;
  102. ct.radius = radius;
  103. Vec3 diff = end - start;
  104. float lenSq = diff.lengthSquared();
  105. if (lenSq < 1e-10f) {
  106. ct.center = start;
  107. ct.axis = Vec3(0, 1, 0);
  108. ct.tangent = Vec3(1, 0, 0);
  109. ct.bitangent = Vec3(0, 0, 1);
  110. ct.length = 0.0f;
  111. return ct;
  112. }
  113. ct.length = std::sqrt(lenSq);
  114. ct.center = Vec3((start.x + end.x) * 0.5f, (start.y + end.y) * 0.5f,
  115. (start.z + end.z) * 0.5f);
  116. ct.axis = diff * (1.0f / ct.length);
  117. Vec3 up = (std::abs(ct.axis.y) < 0.999f) ? Vec3(0, 1, 0) : Vec3(1, 0, 0);
  118. ct.tangent = up.cross(ct.axis).normalized();
  119. ct.bitangent = ct.axis.cross(ct.tangent).normalized();
  120. return ct;
  121. }
  122. inline Mat3x4 toMatrix() const noexcept {
  123. Mat3x4 m;
  124. m.m[0][0] = tangent.x * radius;
  125. m.m[1][0] = tangent.y * radius;
  126. m.m[2][0] = tangent.z * radius;
  127. m.m[0][1] = axis.x * length;
  128. m.m[1][1] = axis.y * length;
  129. m.m[2][1] = axis.z * length;
  130. m.m[0][2] = bitangent.x * radius;
  131. m.m[1][2] = bitangent.y * radius;
  132. m.m[2][2] = bitangent.z * radius;
  133. m.m[0][3] = center.x;
  134. m.m[1][3] = center.y;
  135. m.m[2][3] = center.z;
  136. return m;
  137. }
  138. };
  139. inline Mat3x4 cylinderBetweenFast(const Vec3 &a, const Vec3 &b,
  140. float radius) noexcept {
  141. const float dx = b.x - a.x;
  142. const float dy = b.y - a.y;
  143. const float dz = b.z - a.z;
  144. const float lenSq = dx * dx + dy * dy + dz * dz;
  145. constexpr float kEpsilonSq = 1e-12f;
  146. constexpr float kRadToDeg = 57.2957795131f;
  147. Vec3 center((a.x + b.x) * 0.5f, (a.y + b.y) * 0.5f, (a.z + b.z) * 0.5f);
  148. if (lenSq < kEpsilonSq) {
  149. Mat3x4 m;
  150. m.m[0][0] = radius;
  151. m.m[0][1] = 0;
  152. m.m[0][2] = 0;
  153. m.m[1][0] = 0;
  154. m.m[1][1] = 1.0f;
  155. m.m[1][2] = 0;
  156. m.m[2][0] = 0;
  157. m.m[2][1] = 0;
  158. m.m[2][2] = radius;
  159. m.setTranslation(center);
  160. return m;
  161. }
  162. const float len = std::sqrt(lenSq);
  163. const float invLen = 1.0f / len;
  164. const float ndx = dx * invLen;
  165. const float ndy = dy * invLen;
  166. const float ndz = dz * invLen;
  167. const float axisX = ndz;
  168. const float axisZ = -ndx;
  169. const float axisLenSq = axisX * axisX + axisZ * axisZ;
  170. float rot[3][3];
  171. if (axisLenSq < kEpsilonSq) {
  172. if (ndy < 0.0f) {
  173. rot[0][0] = 1;
  174. rot[0][1] = 0;
  175. rot[0][2] = 0;
  176. rot[1][0] = 0;
  177. rot[1][1] = -1;
  178. rot[1][2] = 0;
  179. rot[2][0] = 0;
  180. rot[2][1] = 0;
  181. rot[2][2] = -1;
  182. } else {
  183. rot[0][0] = 1;
  184. rot[0][1] = 0;
  185. rot[0][2] = 0;
  186. rot[1][0] = 0;
  187. rot[1][1] = 1;
  188. rot[1][2] = 0;
  189. rot[2][0] = 0;
  190. rot[2][1] = 0;
  191. rot[2][2] = 1;
  192. }
  193. } else {
  194. const float axisInvLen = 1.0f / std::sqrt(axisLenSq);
  195. const float ax = axisX * axisInvLen;
  196. const float az = axisZ * axisInvLen;
  197. const float dot = std::clamp(ndy, -1.0f, 1.0f);
  198. const float angle = std::acos(dot);
  199. const float c = std::cos(angle);
  200. const float s = std::sin(angle);
  201. const float t = 1.0f - c;
  202. rot[0][0] = t * ax * ax + c;
  203. rot[0][1] = t * ax * 0;
  204. rot[0][2] = t * ax * az - s * 0;
  205. rot[1][0] = t * 0 * ax + s * az;
  206. rot[1][1] = t * 0 * 0 + c;
  207. rot[1][2] = t * 0 * az - s * ax;
  208. rot[2][0] = t * az * ax + s * 0;
  209. rot[2][1] = t * az * 0 - s * ax;
  210. rot[2][2] = t * az * az + c;
  211. }
  212. Mat3x4 result = Mat3x4::TRS(center, rot, radius, len, radius);
  213. return result;
  214. }
  215. inline Mat3x4 sphereAtFast(const Vec3 &pos, float radius) noexcept {
  216. Mat3x4 m;
  217. m.m[0][0] = radius;
  218. m.m[0][1] = 0;
  219. m.m[0][2] = 0;
  220. m.m[1][0] = 0;
  221. m.m[1][1] = radius;
  222. m.m[1][2] = 0;
  223. m.m[2][0] = 0;
  224. m.m[2][1] = 0;
  225. m.m[2][2] = radius;
  226. m.setTranslation(pos);
  227. return m;
  228. }
  229. inline Mat3x4 cylinderBetweenFast(const Mat3x4 &parent, const Vec3 &a,
  230. const Vec3 &b, float radius) noexcept {
  231. Mat3x4 local = cylinderBetweenFast(a, b, radius);
  232. return parent * local;
  233. }
  234. inline Mat3x4 sphereAtFast(const Mat3x4 &parent, const Vec3 &pos,
  235. float radius) noexcept {
  236. Mat3x4 local = sphereAtFast(pos, radius);
  237. return parent * local;
  238. }
  239. } // namespace Render::Math