瀏覽代碼

Improved detection of numerical errors in cubic equation solver

Chlumsky 3 年之前
父節點
當前提交
0b633e75f7
共有 1 個文件被更改,包括 21 次插入26 次删除
  1. 21 26
      core/equation-solver.cpp

+ 21 - 26
core/equation-solver.cpp

@@ -4,17 +4,15 @@
 #define _USE_MATH_DEFINES
 #include <cmath>
 
-#define TOO_LARGE_RATIO 1e12
-
 namespace msdfgen {
 
 int solveQuadratic(double x[2], double a, double b, double c) {
-    // a = 0 -> linear equation
-    if (a == 0 || fabs(b)+fabs(c) > TOO_LARGE_RATIO*fabs(a)) {
-        // a, b = 0 -> no solution
-        if (b == 0 || fabs(c) > TOO_LARGE_RATIO*fabs(b)) {
+    // a == 0 -> linear equation
+    if (a == 0 || fabs(b) > 1e12*fabs(a)) {
+        // a == 0, b == 0 -> no solution
+        if (b == 0) {
             if (c == 0)
-                return -1; // 0 = 0
+                return -1; // 0 == 0
             return 0;
         }
         x[0] = -c/b;
@@ -35,41 +33,38 @@ int solveQuadratic(double x[2], double a, double b, double c) {
 
 static int solveCubicNormed(double x[3], double a, double b, double c) {
     double a2 = a*a;
-    double q  = (a2 - 3*b)/9; 
-    double r  = (a*(2*a2-9*b) + 27*c)/54;
+    double q = 1/9.*(a2-3*b);
+    double r = 1/54.*(a*(2*a2-9*b)+27*c);
     double r2 = r*r;
     double q3 = q*q*q;
-    double A, B;
+    a *= 1/3.;
     if (r2 < q3) {
         double t = r/sqrt(q3);
         if (t < -1) t = -1;
         if (t > 1) t = 1;
         t = acos(t);
-        a /= 3; q = -2*sqrt(q);
-        x[0] = q*cos(t/3)-a;
-        x[1] = q*cos((t+2*M_PI)/3)-a;
-        x[2] = q*cos((t-2*M_PI)/3)-a;
+        q = -2*sqrt(q);
+        x[0] = q*cos(1/3.*t)-a;
+        x[1] = q*cos(1/3.*(t+2*M_PI))-a;
+        x[2] = q*cos(1/3.*(t-2*M_PI))-a;
         return 3;
     } else {
-        A = -pow(fabs(r)+sqrt(r2-q3), 1/3.); 
-        if (r < 0) A = -A;
-        B = A == 0 ? 0 : q/A;
-        a /= 3;
-        x[0] = (A+B)-a;
-        x[1] = -0.5*(A+B)-a;
-        x[2] = 0.5*sqrt(3.)*(A-B);
-        if (fabs(x[2]) < 1e-14)
+        double u = (r < 0 ? 1 : -1)*pow(fabs(r)+sqrt(r2-q3), 1/3.); 
+        double v = u == 0 ? 0 : q/u;
+        x[0] = (u+v)-a;
+        if (u == v || fabs(u-v) < 1e-12*fabs(u+v)) {
+            x[1] = -.5*(u+v)-a;
             return 2;
+        }
         return 1;
     }
 }
 
 int solveCubic(double x[3], double a, double b, double c, double d) {
     if (a != 0) {
-        double bn = b/a, cn = c/a, dn = d/a;
-        // Check that a isn't "almost zero"
-        if (fabs(bn) < TOO_LARGE_RATIO && fabs(cn) < TOO_LARGE_RATIO && fabs(dn) < TOO_LARGE_RATIO)
-            return solveCubicNormed(x, bn, cn, dn);
+        double bn = b/a;
+        if (fabs(bn) < 1e6) // Above this ratio, the numerical error gets larger than if we treated a as zero
+            return solveCubicNormed(x, bn, c/a, d/a);
     }
     return solveQuadratic(x, b, c, d);
 }