Matrix.h 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2021 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #pragma once
  5. #include <Jolt/Math/Vector.h>
  6. #include <Jolt/Math/GaussianElimination.h>
  7. JPH_NAMESPACE_BEGIN
  8. /// Templatized matrix class
  9. template <uint Rows, uint Cols>
  10. class [[nodiscard]] Matrix
  11. {
  12. public:
  13. /// Constructor
  14. inline Matrix() = default;
  15. inline Matrix(const Matrix &inM2) { *this = inM2; }
  16. /// Dimensions
  17. inline uint GetRows() const { return Rows; }
  18. inline uint GetCols() const { return Cols; }
  19. /// Zero matrix
  20. inline void SetZero()
  21. {
  22. for (uint c = 0; c < Cols; ++c)
  23. mCol[c].SetZero();
  24. }
  25. inline static Matrix sZero() { Matrix m; m.SetZero(); return m; }
  26. /// Check if this matrix consists of all zeros
  27. inline bool IsZero() const
  28. {
  29. for (uint c = 0; c < Cols; ++c)
  30. if (!mCol[c].IsZero())
  31. return false;
  32. return true;
  33. }
  34. /// Identity matrix
  35. inline void SetIdentity()
  36. {
  37. // Clear matrix
  38. SetZero();
  39. // Set diagonal to 1
  40. for (uint rc = 0, min_rc = min(Rows, Cols); rc < min_rc; ++rc)
  41. mCol[rc].mF32[rc] = 1.0f;
  42. }
  43. inline static Matrix sIdentity() { Matrix m; m.SetIdentity(); return m; }
  44. /// Check if this matrix is identity
  45. bool IsIdentity() const { return *this == sIdentity(); }
  46. /// Diagonal matrix
  47. inline void SetDiagonal(const Vector<Rows < Cols? Rows : Cols> &inV)
  48. {
  49. // Clear matrix
  50. SetZero();
  51. // Set diagonal
  52. for (uint rc = 0, min_rc = min(Rows, Cols); rc < min_rc; ++rc)
  53. mCol[rc].mF32[rc] = inV[rc];
  54. }
  55. inline static Matrix sDiagonal(const Vector<Rows < Cols? Rows : Cols> &inV)
  56. {
  57. Matrix m;
  58. m.SetDiagonal(inV);
  59. return m;
  60. }
  61. /// Copy a (part) of another matrix into this matrix
  62. template <class OtherMatrix>
  63. void CopyPart(const OtherMatrix &inM, uint inSourceRow, uint inSourceCol, uint inNumRows, uint inNumCols, uint inDestRow, uint inDestCol)
  64. {
  65. for (uint c = 0; c < inNumCols; ++c)
  66. for (uint r = 0; r < inNumRows; ++r)
  67. mCol[inDestCol + c].mF32[inDestRow + r] = inM(inSourceRow + r, inSourceCol + c);
  68. }
  69. /// Get float component by element index
  70. inline float operator () (uint inRow, uint inColumn) const
  71. {
  72. JPH_ASSERT(inRow < Rows);
  73. JPH_ASSERT(inColumn < Cols);
  74. return mCol[inColumn].mF32[inRow];
  75. }
  76. inline float & operator () (uint inRow, uint inColumn)
  77. {
  78. JPH_ASSERT(inRow < Rows);
  79. JPH_ASSERT(inColumn < Cols);
  80. return mCol[inColumn].mF32[inRow];
  81. }
  82. /// Comparison
  83. inline bool operator == (const Matrix &inM2) const
  84. {
  85. for (uint c = 0; c < Cols; ++c)
  86. if (mCol[c] != inM2.mCol[c])
  87. return false;
  88. return true;
  89. }
  90. inline bool operator != (const Matrix &inM2) const
  91. {
  92. for (uint c = 0; c < Cols; ++c)
  93. if (mCol[c] != inM2.mCol[c])
  94. return true;
  95. return false;
  96. }
  97. /// Assignment
  98. inline Matrix & operator = (const Matrix &inM2)
  99. {
  100. for (uint c = 0; c < Cols; ++c)
  101. mCol[c] = inM2.mCol[c];
  102. return *this;
  103. }
  104. /// Multiply matrix by matrix
  105. template <uint OtherCols>
  106. inline Matrix<Rows, OtherCols> operator * (const Matrix<Cols, OtherCols> &inM) const
  107. {
  108. Matrix<Rows, OtherCols> m;
  109. for (uint c = 0; c < OtherCols; ++c)
  110. for (uint r = 0; r < Rows; ++r)
  111. {
  112. float dot = 0.0f;
  113. for (uint i = 0; i < Cols; ++i)
  114. dot += mCol[i].mF32[r] * inM.mCol[c].mF32[i];
  115. m.mCol[c].mF32[r] = dot;
  116. }
  117. return m;
  118. }
  119. /// Multiply vector by matrix
  120. inline Vector<Rows> operator * (const Vector<Cols> &inV) const
  121. {
  122. Vector<Rows> v;
  123. for (uint r = 0; r < Rows; ++r)
  124. {
  125. float dot = 0.0f;
  126. for (uint c = 0; c < Cols; ++c)
  127. dot += mCol[c].mF32[r] * inV.mF32[c];
  128. v.mF32[r] = dot;
  129. }
  130. return v;
  131. }
  132. /// Multiply matrix with float
  133. inline Matrix operator * (float inV) const
  134. {
  135. Matrix m;
  136. for (uint c = 0; c < Cols; ++c)
  137. m.mCol[c] = mCol[c] * inV;
  138. return m;
  139. }
  140. inline friend Matrix operator * (float inV, const Matrix &inM)
  141. {
  142. return inM * inV;
  143. }
  144. /// Per element addition of matrix
  145. inline Matrix operator + (const Matrix &inM) const
  146. {
  147. Matrix m;
  148. for (uint c = 0; c < Cols; ++c)
  149. m.mCol[c] = mCol[c] + inM.mCol[c];
  150. return m;
  151. }
  152. /// Per element subtraction of matrix
  153. inline Matrix operator - (const Matrix &inM) const
  154. {
  155. Matrix m;
  156. for (uint c = 0; c < Cols; ++c)
  157. m.mCol[c] = mCol[c] - inM.mCol[c];
  158. return m;
  159. }
  160. /// Transpose matrix
  161. inline Matrix<Cols, Rows> Transposed() const
  162. {
  163. Matrix<Cols, Rows> m;
  164. for (uint r = 0; r < Rows; ++r)
  165. for (uint c = 0; c < Cols; ++c)
  166. m.mCol[r].mF32[c] = mCol[c].mF32[r];
  167. return m;
  168. }
  169. /// Inverse matrix
  170. bool SetInversed(const Matrix &inM)
  171. {
  172. if constexpr (Rows != Cols) JPH_ASSERT(false);
  173. Matrix copy(inM);
  174. SetIdentity();
  175. return GaussianElimination(copy, *this);
  176. }
  177. inline Matrix Inversed() const
  178. {
  179. Matrix m;
  180. m.SetInversed(*this);
  181. return m;
  182. }
  183. /// To String
  184. friend ostream & operator << (ostream &inStream, const Matrix &inM)
  185. {
  186. for (uint i = 0; i < Cols - 1; ++i)
  187. inStream << inM.mCol[i] << ", ";
  188. inStream << inM.mCol[Cols - 1];
  189. return inStream;
  190. }
  191. /// Column access
  192. const Vector<Rows> & GetColumn(int inIdx) const { return mCol[inIdx]; }
  193. Vector<Rows> & GetColumn(int inIdx) { return mCol[inIdx]; }
  194. Vector<Rows> mCol[Cols]; ///< Column
  195. };
  196. // The template specialization doesn't sit well with Doxygen
  197. #ifndef JPH_PLATFORM_DOXYGEN
  198. /// Specialization of SetInversed for 2x2 matrix
  199. template <>
  200. inline bool Matrix<2, 2>::SetInversed(const Matrix<2, 2> &inM)
  201. {
  202. // Fetch elements
  203. float a = inM.mCol[0].mF32[0];
  204. float b = inM.mCol[1].mF32[0];
  205. float c = inM.mCol[0].mF32[1];
  206. float d = inM.mCol[1].mF32[1];
  207. // Calculate determinant
  208. float det = a * d - b * c;
  209. if (det == 0.0f)
  210. return false;
  211. // Construct inverse
  212. mCol[0].mF32[0] = d / det;
  213. mCol[1].mF32[0] = -b / det;
  214. mCol[0].mF32[1] = -c / det;
  215. mCol[1].mF32[1] = a / det;
  216. return true;
  217. }
  218. #endif // !JPH_PLATFORM_DOXYGEN
  219. JPH_NAMESPACE_END