GaussianElimination.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. // SPDX-FileCopyrightText: 2021 Jorrit Rouwe
  2. // SPDX-License-Identifier: MIT
  3. #pragma once
  4. JPH_NAMESPACE_BEGIN
  5. /// This function performs Gauss-Jordan elimination to solve a matrix equation.
  6. /// A must be an NxN matrix and B must be an NxM matrix forming the equation A * x = B
  7. /// on output B will contain x and A will be destroyed.
  8. ///
  9. /// This code can be used for example to compute the inverse of a matrix.
  10. /// Set A to the matrix to invert, set B to identity and let GaussianElimination solve
  11. /// the equation, on return B will be the inverse of A. And A is destroyed.
  12. ///
  13. /// Taken and adapted from Numerical Recipies in C paragraph 2.1
  14. template <class MatrixA, class MatrixB>
  15. bool GaussianElimination(MatrixA &ioA, MatrixB &ioB, float inTolerance = 1.0e-16f)
  16. {
  17. // Get problem dimensions
  18. const uint n = ioA.GetCols();
  19. const uint m = ioB.GetCols();
  20. // Check matrix requirement
  21. JPH_ASSERT(ioA.GetRows() == n);
  22. JPH_ASSERT(ioB.GetRows() == n);
  23. // Create array for bookkeeping on pivoting
  24. int *ipiv = (int *)JPH_STACK_ALLOC(n * sizeof(int));
  25. memset(ipiv, 0, n * sizeof(int));
  26. for (uint i = 0; i < n; ++i)
  27. {
  28. // Initialize pivot element as the diagonal
  29. uint pivot_row = i, pivot_col = i;
  30. // Determine pivot element
  31. float largest_element = 0.0f;
  32. for (uint j = 0; j < n; ++j)
  33. if (ipiv[j] != 1)
  34. for (uint k = 0; k < n; ++k)
  35. {
  36. if (ipiv[k] == 0)
  37. {
  38. float element = abs(ioA(j, k));
  39. if (element >= largest_element)
  40. {
  41. largest_element = element;
  42. pivot_row = j;
  43. pivot_col = k;
  44. }
  45. }
  46. else if (ipiv[k] > 1)
  47. {
  48. return false;
  49. }
  50. }
  51. // Mark this column as used
  52. ++ipiv[pivot_col];
  53. // Exchange rows when needed so that the pivot element is at ioA(pivot_col, pivot_col) instead of at ioA(pivot_row, pivot_col)
  54. if (pivot_row != pivot_col)
  55. {
  56. for (uint j = 0; j < n; ++j)
  57. swap(ioA(pivot_row, j), ioA(pivot_col, j));
  58. for (uint j = 0; j < m; ++j)
  59. swap(ioB(pivot_row, j), ioB(pivot_col, j));
  60. }
  61. // Get diagonal element that we are about to set to 1
  62. float diagonal_element = ioA(pivot_col, pivot_col);
  63. if (abs(diagonal_element) < inTolerance)
  64. return false;
  65. // Divide the whole row by the pivot element, making ioA(pivot_col, pivot_col) = 1
  66. for (uint j = 0; j < n; ++j)
  67. ioA(pivot_col, j) /= diagonal_element;
  68. for (uint j = 0; j < m; ++j)
  69. ioB(pivot_col, j) /= diagonal_element;
  70. ioA(pivot_col, pivot_col) = 1.0f;
  71. // Next reduce the rows, except for the pivot one,
  72. // after this step the pivot_col column is zero except for the pivot element which is 1
  73. for (uint j = 0; j < n; ++j)
  74. if (j != pivot_col)
  75. {
  76. float element = ioA(j, pivot_col);
  77. for (uint k = 0; k < n; ++k)
  78. ioA(j, k) -= ioA(pivot_col, k) * element;
  79. for (uint k = 0; k < m; ++k)
  80. ioB(j, k) -= ioB(pivot_col, k) * element;
  81. ioA(j, pivot_col) = 0.0f;
  82. }
  83. }
  84. // Success
  85. return true;
  86. }
  87. JPH_NAMESPACE_END