2
0

Matrix.h 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. // SPDX-FileCopyrightText: 2021 Jorrit Rouwe
  2. // SPDX-License-Identifier: MIT
  3. #pragma once
  4. #include <Jolt/Math/Vector.h>
  5. #include <Jolt/Math/GaussianElimination.h>
  6. JPH_NAMESPACE_BEGIN
  7. /// Templatized matrix class
  8. template <uint Rows, uint Cols>
  9. class [[nodiscard]] Matrix
  10. {
  11. public:
  12. /// Constructor
  13. inline Matrix() = default;
  14. inline Matrix(const Matrix &inM2) { *this = inM2; }
  15. /// Dimensions
  16. inline uint GetRows() const { return Rows; }
  17. inline uint GetCols() const { return Cols; }
  18. /// Zero matrix
  19. inline void SetZero()
  20. {
  21. for (uint c = 0; c < Cols; ++c)
  22. mCol[c].SetZero();
  23. }
  24. inline static Matrix sZero() { Matrix m; m.SetZero(); return m; }
  25. /// Check if this matrix consists of all zeros
  26. inline bool IsZero() const
  27. {
  28. for (uint c = 0; c < Cols; ++c)
  29. if (!mCol[c].IsZero())
  30. return false;
  31. return true;
  32. }
  33. /// Identity matrix
  34. inline void SetIdentity()
  35. {
  36. // Clear matrix
  37. SetZero();
  38. // Set diagonal to 1
  39. for (uint rc = 0, min_rc = min(Rows, Cols); rc < min_rc; ++rc)
  40. mCol[rc].mF32[rc] = 1.0f;
  41. }
  42. inline static Matrix sIdentity() { Matrix m; m.SetIdentity(); return m; }
  43. /// Check if this matrix is identity
  44. bool IsIdentity() const { return *this == sIdentity(); }
  45. /// Diagonal matrix
  46. inline void SetDiagonal(const Vector<Rows < Cols? Rows : Cols> &inV)
  47. {
  48. // Clear matrix
  49. SetZero();
  50. // Set diagonal
  51. for (uint rc = 0, min_rc = min(Rows, Cols); rc < min_rc; ++rc)
  52. mCol[rc].mF32[rc] = inV[rc];
  53. }
  54. inline static Matrix sDiagonal(const Vector<Rows < Cols? Rows : Cols> &inV)
  55. {
  56. Matrix m;
  57. m.SetDiagonal(inV);
  58. return m;
  59. }
  60. /// Copy a (part) of another matrix into this matrix
  61. template <class OtherMatrix>
  62. void CopyPart(const OtherMatrix &inM, uint inSourceRow, uint inSourceCol, uint inNumRows, uint inNumCols, uint inDestRow, uint inDestCol)
  63. {
  64. for (uint c = 0; c < inNumCols; ++c)
  65. for (uint r = 0; r < inNumRows; ++r)
  66. mCol[inDestCol + c].mF32[inDestRow + r] = inM(inSourceRow + r, inSourceCol + c);
  67. }
  68. /// Get float component by element index
  69. inline float operator () (uint inRow, uint inColumn) const
  70. {
  71. JPH_ASSERT(inRow < Rows);
  72. JPH_ASSERT(inColumn < Cols);
  73. return mCol[inColumn].mF32[inRow];
  74. }
  75. inline float & operator () (uint inRow, uint inColumn)
  76. {
  77. JPH_ASSERT(inRow < Rows);
  78. JPH_ASSERT(inColumn < Cols);
  79. return mCol[inColumn].mF32[inRow];
  80. }
  81. /// Comparison
  82. inline bool operator == (const Matrix &inM2) const
  83. {
  84. for (uint c = 0; c < Cols; ++c)
  85. if (mCol[c] != inM2.mCol[c])
  86. return false;
  87. return true;
  88. }
  89. inline bool operator != (const Matrix &inM2) const
  90. {
  91. for (uint c = 0; c < Cols; ++c)
  92. if (mCol[c] != inM2.mCol[c])
  93. return true;
  94. return false;
  95. }
  96. /// Assignment
  97. inline Matrix & operator = (const Matrix &inM2)
  98. {
  99. for (uint c = 0; c < Cols; ++c)
  100. mCol[c] = inM2.mCol[c];
  101. return *this;
  102. }
  103. /// Multiply matrix by matrix
  104. template <uint OtherCols>
  105. inline Matrix<Rows, OtherCols> operator * (const Matrix<Cols, OtherCols> &inM) const
  106. {
  107. Matrix<Rows, OtherCols> m;
  108. for (uint c = 0; c < OtherCols; ++c)
  109. for (uint r = 0; r < Rows; ++r)
  110. {
  111. float dot = 0.0f;
  112. for (uint i = 0; i < Cols; ++i)
  113. dot += mCol[i].mF32[r] * inM.mCol[c].mF32[i];
  114. m.mCol[c].mF32[r] = dot;
  115. }
  116. return m;
  117. }
  118. /// Multiply vector by matrix
  119. inline Vector<Rows> operator * (const Vector<Cols> &inV) const
  120. {
  121. Vector<Rows> v;
  122. for (uint r = 0; r < Rows; ++r)
  123. {
  124. float dot = 0.0f;
  125. for (uint c = 0; c < Cols; ++c)
  126. dot += mCol[c].mF32[r] * inV.mF32[c];
  127. v.mF32[r] = dot;
  128. }
  129. return v;
  130. }
  131. /// Multiply matrix with float
  132. inline Matrix operator * (float inV) const
  133. {
  134. Matrix m;
  135. for (uint c = 0; c < Cols; ++c)
  136. m.mCol[c] = mCol[c] * inV;
  137. return m;
  138. }
  139. inline friend Matrix operator * (float inV, const Matrix &inM)
  140. {
  141. return inM * inV;
  142. }
  143. /// Per element addition of matrix
  144. inline Matrix operator + (const Matrix &inM) const
  145. {
  146. Matrix m;
  147. for (uint c = 0; c < Cols; ++c)
  148. m.mCol[c] = mCol[c] + inM.mCol[c];
  149. return m;
  150. }
  151. /// Per element subtraction of matrix
  152. inline Matrix operator - (const Matrix &inM) const
  153. {
  154. Matrix m;
  155. for (uint c = 0; c < Cols; ++c)
  156. m.mCol[c] = mCol[c] - inM.mCol[c];
  157. return m;
  158. }
  159. /// Transpose matrix
  160. inline Matrix<Cols, Rows> Transposed() const
  161. {
  162. Matrix<Cols, Rows> m;
  163. for (uint r = 0; r < Rows; ++r)
  164. for (uint c = 0; c < Cols; ++c)
  165. m.mCol[r].mF32[c] = mCol[c].mF32[r];
  166. return m;
  167. }
  168. /// Inverse matrix
  169. bool SetInversed(const Matrix &inM)
  170. {
  171. if constexpr (Rows != Cols) JPH_ASSERT(false);
  172. Matrix copy(inM);
  173. SetIdentity();
  174. return GaussianElimination(copy, *this);
  175. }
  176. inline Matrix Inversed() const
  177. {
  178. Matrix m;
  179. m.SetInversed(*this);
  180. return m;
  181. }
  182. /// To String
  183. friend ostream & operator << (ostream &inStream, const Matrix &inM)
  184. {
  185. for (uint i = 0; i < Cols - 1; ++i)
  186. inStream << inM.mCol[i] << ", ";
  187. inStream << inM.mCol[Cols - 1];
  188. return inStream;
  189. }
  190. /// Column access
  191. const Vector<Rows> & GetColumn(int inIdx) const { return mCol[inIdx]; }
  192. Vector<Rows> & GetColumn(int inIdx) { return mCol[inIdx]; }
  193. Vector<Rows> mCol[Cols]; ///< Column
  194. };
  195. // The template specialization doesn't sit well with Doxygen
  196. #ifndef JPH_PLATFORM_DOXYGEN
  197. /// Specialization of SetInversed for 2x2 matrix
  198. template <>
  199. inline bool Matrix<2, 2>::SetInversed(const Matrix<2, 2> &inM)
  200. {
  201. // Fetch elements
  202. float a = inM.mCol[0].mF32[0];
  203. float b = inM.mCol[1].mF32[0];
  204. float c = inM.mCol[0].mF32[1];
  205. float d = inM.mCol[1].mF32[1];
  206. // Calculate determinant
  207. float det = a * d - b * c;
  208. if (det == 0.0f)
  209. return false;
  210. // Construct inverse
  211. mCol[0].mF32[0] = d / det;
  212. mCol[1].mF32[0] = -b / det;
  213. mCol[0].mF32[1] = -c / det;
  214. mCol[1].mF32[1] = a / det;
  215. return true;
  216. }
  217. #endif // !JPH_PLATFORM_DOXYGEN
  218. JPH_NAMESPACE_END