Browse Source

Corrections in the implementation of '%' for floats.

The multiplication (m*b) used to test whether 'm' is non-zero and
'm' and 'b' have different signs can underflow for very small numbers,
giving a wrong result. The use of explicit comparisons solves this
problem. This commit also adds several new tests for '%' (both for
floats and for integers) to exercise more corner cases, such as
very large and very small values.
Roberto Ierusalimschy 7 years ago
parent
commit
5382a22e0e
5 changed files with 87 additions and 18 deletions
  1. 8 6
      llimits.h
  2. 1 5
      lobject.c
  3. 12 6
      lvm.c
  4. 1 0
      lvm.h
  5. 65 1
      testes/math.lua

+ 8 - 6
llimits.h

@@ -293,15 +293,17 @@ typedef unsigned long Instruction;
 #endif
 #endif
 
 
 /*
 /*
-** modulo: defined as 'a - floor(a/b)*b'; this definition gives NaN when
-** 'b' is huge, but the result should be 'a'. 'fmod' gives the result of
-** 'a - trunc(a/b)*b', and therefore must be corrected when 'trunc(a/b)
-** ~= floor(a/b)'. That happens when the division has a non-integer
-** negative result, which is equivalent to the test below.
+** modulo: defined as 'a - floor(a/b)*b'; the direct computation
+** using this definition has several problems with rounding errors,
+** so it is better to use 'fmod'. 'fmod' gives the result of
+** 'a - trunc(a/b)*b', and therefore must be corrected when
+** 'trunc(a/b) ~= floor(a/b)'. That happens when the division has a
+** non-integer negative result, which is equivalent to the tests below.
 */
 */
 #if !defined(luai_nummod)
 #if !defined(luai_nummod)
 #define luai_nummod(L,a,b,m)  \
 #define luai_nummod(L,a,b,m)  \
-  { (m) = l_mathop(fmod)(a,b); if ((m)*(b) < 0) (m) += (b); }
+  { (void)L; (m) = l_mathop(fmod)(a,b); \
+    if (((m) > 0) ? (b) < 0 : ((m) < 0 && (b) > 0)) (m) += (b); }
 #endif
 #endif
 
 
 /* exponentiation */
 /* exponentiation */

+ 1 - 5
lobject.c

@@ -106,11 +106,7 @@ static lua_Number numarith (lua_State *L, int op, lua_Number v1,
     case LUA_OPPOW: return luai_numpow(L, v1, v2);
     case LUA_OPPOW: return luai_numpow(L, v1, v2);
     case LUA_OPIDIV: return luai_numidiv(L, v1, v2);
     case LUA_OPIDIV: return luai_numidiv(L, v1, v2);
     case LUA_OPUNM: return luai_numunm(L, v1);
     case LUA_OPUNM: return luai_numunm(L, v1);
-    case LUA_OPMOD: {
-      lua_Number m;
-      luai_nummod(L, v1, v2, m);
-      return m;
-    }
+    case LUA_OPMOD: return luaV_modf(L, v1, v2);
     default: lua_assert(0); return 0;
     default: lua_assert(0); return 0;
   }
   }
 }
 }

+ 12 - 6
lvm.c

@@ -655,6 +655,16 @@ lua_Integer luaV_mod (lua_State *L, lua_Integer m, lua_Integer n) {
 }
 }
 
 
 
 
+/*
+** Float modulus
+*/
+lua_Number luaV_modf (lua_State *L, lua_Number m, lua_Number n) {
+  lua_Number r;
+  luai_nummod(L, m, n, r);
+  return r;
+}
+
+
 /* number of bits in an integer */
 /* number of bits in an integer */
 #define NBITS	cast_int(sizeof(lua_Integer) * CHAR_BIT)
 #define NBITS	cast_int(sizeof(lua_Integer) * CHAR_BIT)
 
 
@@ -1142,10 +1152,8 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setivalue(s2v(ra), luaV_mod(L, ivalue(rb), ic));
           setivalue(s2v(ra), luaV_mod(L, ivalue(rb), ic));
         }
         }
         else if (tonumberns(rb, nb)) {
         else if (tonumberns(rb, nb)) {
-          lua_Number m;
           lua_Number nc = cast_num(ic);
           lua_Number nc = cast_num(ic);
-          luai_nummod(L, nb, nc, m);
-          setfltvalue(s2v(ra), m);
+          setfltvalue(s2v(ra), luaV_modf(L, nb, nc));
         }
         }
         else
         else
           Protect(luaT_trybiniTM(L, rb, ic, 0, ra, TM_MOD));
           Protect(luaT_trybiniTM(L, rb, ic, 0, ra, TM_MOD));
@@ -1370,9 +1378,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setivalue(s2v(ra), luaV_mod(L, ib, ic));
           setivalue(s2v(ra), luaV_mod(L, ib, ic));
         }
         }
         else if (tonumberns(rb, nb) && tonumberns(rc, nc)) {
         else if (tonumberns(rb, nb) && tonumberns(rc, nc)) {
-          lua_Number m;
-          luai_nummod(L, nb, nc, m);
-          setfltvalue(s2v(ra), m);
+          setfltvalue(s2v(ra), luaV_modf(L, nb, nc));
         }
         }
         else
         else
           Protect(luaT_trybinTM(L, rb, rc, ra, TM_MOD));
           Protect(luaT_trybinTM(L, rb, rc, ra, TM_MOD));

+ 1 - 0
lvm.h

@@ -116,6 +116,7 @@ LUAI_FUNC void luaV_execute (lua_State *L, CallInfo *ci);
 LUAI_FUNC void luaV_concat (lua_State *L, int total);
 LUAI_FUNC void luaV_concat (lua_State *L, int total);
 LUAI_FUNC lua_Integer luaV_div (lua_State *L, lua_Integer x, lua_Integer y);
 LUAI_FUNC lua_Integer luaV_div (lua_State *L, lua_Integer x, lua_Integer y);
 LUAI_FUNC lua_Integer luaV_mod (lua_State *L, lua_Integer x, lua_Integer y);
 LUAI_FUNC lua_Integer luaV_mod (lua_State *L, lua_Integer x, lua_Integer y);
+LUAI_FUNC lua_Number luaV_modf (lua_State *L, lua_Number x, lua_Number y);
 LUAI_FUNC lua_Integer luaV_shiftl (lua_Integer x, lua_Integer y);
 LUAI_FUNC lua_Integer luaV_shiftl (lua_Integer x, lua_Integer y);
 LUAI_FUNC void luaV_objlen (lua_State *L, StkId ra, const TValue *rb);
 LUAI_FUNC void luaV_objlen (lua_State *L, StkId ra, const TValue *rb);
 
 

+ 65 - 1
testes/math.lua

@@ -1,4 +1,4 @@
--- $Id: testes/math.lua $
+-- $Id: testes/math.lua 2018-07-25 15:31:04 -0300 $
 -- See Copyright Notice in file all.lua
 -- See Copyright Notice in file all.lua
 
 
 print("testing numbers and math lib")
 print("testing numbers and math lib")
@@ -541,9 +541,73 @@ assert(eqT(-4 % 3, 2))
 assert(eqT(4 % -3, -2))
 assert(eqT(4 % -3, -2))
 assert(eqT(-4.0 % 3, 2.0))
 assert(eqT(-4.0 % 3, 2.0))
 assert(eqT(4 % -3.0, -2.0))
 assert(eqT(4 % -3.0, -2.0))
+assert(eqT(4 % -5, -1))
+assert(eqT(4 % -5.0, -1.0))
+assert(eqT(4 % 5, 4))
+assert(eqT(4 % 5.0, 4.0))
+assert(eqT(-4 % -5, -4))
+assert(eqT(-4 % -5.0, -4.0))
+assert(eqT(-4 % 5, 1))
+assert(eqT(-4 % 5.0, 1.0))
+assert(eqT(4.25 % 4, 0.25))
+assert(eqT(10.0 % 2, 0.0))
+assert(eqT(-10.0 % 2, 0.0))
+assert(eqT(-10.0 % -2, 0.0))
 assert(math.pi - math.pi % 1 == 3)
 assert(math.pi - math.pi % 1 == 3)
 assert(math.pi - math.pi % 0.001 == 3.141)
 assert(math.pi - math.pi % 0.001 == 3.141)
 
 
+do   -- very small numbers
+  local i, j = 0, 20000
+  while i < j do
+    local m = (i + j) // 2
+    if 10^-m > 0 then
+      i = m + 1
+    else
+      j = m
+    end
+  end
+  -- 'i' is the smallest possible ten-exponent
+  local b = 10^-(i - (i // 10))   -- a very small number
+  assert(b > 0 and b * b == 0)
+  local delta = b / 1000
+  assert(eq((2.1 * b) % (2 * b), (0.1 * b), delta))
+  assert(eq((-2.1 * b) % (2 * b), (2 * b) - (0.1 * b), delta))
+  assert(eq((2.1 * b) % (-2 * b), (0.1 * b) - (2 * b), delta))
+  assert(eq((-2.1 * b) % (-2 * b), (-0.1 * b), delta))
+end
+
+
+-- basic consistency between integer modulo and float modulo
+for i = -10, 10 do
+  for j = -10, 10 do
+    if j ~= 0 then
+      assert((i + 0.0) % j == i % j)
+    end
+  end
+end
+
+for i = 0, 10 do
+  for j = -10, 10 do
+    if j ~= 0 then
+      assert((2^i) % j == (1 << i) % j)
+    end
+  end
+end
+
+do    -- precision of module for large numbers
+  local i = 10
+  while (1 << i) > 0 do
+    assert((1 << i) % 3 == i % 2 + 1)
+    i = i + 1
+  end
+
+  i = 10
+  while 2^i < math.huge do
+    assert(2^i % 3 == i % 2 + 1)
+    i = i + 1
+  end
+end
+
 assert(eqT(minint % minint, 0))
 assert(eqT(minint % minint, 0))
 assert(eqT(maxint % maxint, 0))
 assert(eqT(maxint % maxint, 0))
 assert((minint + 1) % minint == minint + 1)
 assert((minint + 1) % minint == minint + 1)