GaussianElimination.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2021 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #pragma once
  5. JPH_NAMESPACE_BEGIN
  6. /// This function performs Gauss-Jordan elimination to solve a matrix equation.
  7. /// A must be an NxN matrix and B must be an NxM matrix forming the equation A * x = B
  8. /// on output B will contain x and A will be destroyed.
  9. ///
  10. /// This code can be used for example to compute the inverse of a matrix.
  11. /// Set A to the matrix to invert, set B to identity and let GaussianElimination solve
  12. /// the equation, on return B will be the inverse of A. And A is destroyed.
  13. ///
  14. /// Taken and adapted from Numerical Recipies in C paragraph 2.1
  15. template <class MatrixA, class MatrixB>
  16. bool GaussianElimination(MatrixA &ioA, MatrixB &ioB, float inTolerance = 1.0e-16f)
  17. {
  18. // Get problem dimensions
  19. const uint n = ioA.GetCols();
  20. const uint m = ioB.GetCols();
  21. // Check matrix requirement
  22. JPH_ASSERT(ioA.GetRows() == n);
  23. JPH_ASSERT(ioB.GetRows() == n);
  24. // Create array for bookkeeping on pivoting
  25. int *ipiv = (int *)JPH_STACK_ALLOC(n * sizeof(int));
  26. memset(ipiv, 0, n * sizeof(int));
  27. for (uint i = 0; i < n; ++i)
  28. {
  29. // Initialize pivot element as the diagonal
  30. uint pivot_row = i, pivot_col = i;
  31. // Determine pivot element
  32. float largest_element = 0.0f;
  33. for (uint j = 0; j < n; ++j)
  34. if (ipiv[j] != 1)
  35. for (uint k = 0; k < n; ++k)
  36. {
  37. if (ipiv[k] == 0)
  38. {
  39. float element = abs(ioA(j, k));
  40. if (element >= largest_element)
  41. {
  42. largest_element = element;
  43. pivot_row = j;
  44. pivot_col = k;
  45. }
  46. }
  47. else if (ipiv[k] > 1)
  48. {
  49. return false;
  50. }
  51. }
  52. // Mark this column as used
  53. ++ipiv[pivot_col];
  54. // Exchange rows when needed so that the pivot element is at ioA(pivot_col, pivot_col) instead of at ioA(pivot_row, pivot_col)
  55. if (pivot_row != pivot_col)
  56. {
  57. for (uint j = 0; j < n; ++j)
  58. swap(ioA(pivot_row, j), ioA(pivot_col, j));
  59. for (uint j = 0; j < m; ++j)
  60. swap(ioB(pivot_row, j), ioB(pivot_col, j));
  61. }
  62. // Get diagonal element that we are about to set to 1
  63. float diagonal_element = ioA(pivot_col, pivot_col);
  64. if (abs(diagonal_element) < inTolerance)
  65. return false;
  66. // Divide the whole row by the pivot element, making ioA(pivot_col, pivot_col) = 1
  67. for (uint j = 0; j < n; ++j)
  68. ioA(pivot_col, j) /= diagonal_element;
  69. for (uint j = 0; j < m; ++j)
  70. ioB(pivot_col, j) /= diagonal_element;
  71. ioA(pivot_col, pivot_col) = 1.0f;
  72. // Next reduce the rows, except for the pivot one,
  73. // after this step the pivot_col column is zero except for the pivot element which is 1
  74. for (uint j = 0; j < n; ++j)
  75. if (j != pivot_col)
  76. {
  77. float element = ioA(j, pivot_col);
  78. for (uint k = 0; k < n; ++k)
  79. ioA(j, k) -= ioA(pivot_col, k) * element;
  80. for (uint k = 0; k < m; ++k)
  81. ioB(j, k) -= ioB(pivot_col, k) * element;
  82. ioA(j, pivot_col) = 0.0f;
  83. }
  84. }
  85. // Success
  86. return true;
  87. }
  88. JPH_NAMESPACE_END