Browse Source

big: Fix `mul`.

Jeroen van Rijn 4 years ago
parent
commit
13fab36639
4 changed files with 168 additions and 8 deletions
  1. 4 4
      core/math/big/basic.odin
  2. 5 1
      core/math/big/build.bat
  3. 54 0
      core/math/big/test.odin
  4. 105 3
      core/math/big/test.py

+ 4 - 4
core/math/big/basic.odin

@@ -950,7 +950,7 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 			Limit ourselves to `digits` DIGITs of output.
 		*/
 		pb    := min(b.used, digits - ix);
-		carry := DIGIT(0);
+		carry := _WORD(0);
 		iy    := 0;
 		/*
 			Compute the column of the output and propagate the carry.
@@ -959,12 +959,12 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 			/*
 				Compute the column as a _WORD.
 			*/
-			column := t.digit[ix + iy] + a.digit[ix] * b.digit[iy] + carry;
+			column := _WORD(t.digit[ix + iy]) + _WORD(a.digit[ix]) * _WORD(b.digit[iy]) + carry;
 
 			/*
 				The new column is the lower part of the result.
 			*/
-			t.digit[ix + iy] = column & _MASK;
+			t.digit[ix + iy] = DIGIT(column & _WORD(_MASK));
 
 			/*
 				Get the carry word from the result.
@@ -975,7 +975,7 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 			Set carry if it is placed below digits
 		*/
 		if ix + iy < digits {
-			t.digit[ix + pb] = carry;
+			t.digit[ix + pb] = DIGIT(carry);
 		}
 	}
 

+ 5 - 1
core/math/big/build.bat

@@ -1,2 +1,6 @@
 @echo off
-odin run . -vet
+:odin run   . -vet
+odin build . -build-mode:dll
+
+:dumpbin /EXPORTS big.dll
+python test.py

+ 54 - 0
core/math/big/test.odin

@@ -41,4 +41,58 @@ PyRes :: struct {
 	r, err = int_itoa_cstring(sum, i8(radix), context.temp_allocator);
 	if err != .None { return PyRes{res=":add_two:itoa(sum):", err=err}; }
 	return PyRes{res = r, err = .None};
+}
+
+@export test_sub_two :: proc "c" (a, b: cstring, radix := int(10)) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	aa, bb, sum := &Int{}, &Int{}, &Int{};
+	defer destroy(aa, bb, sum);
+
+	if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":sub_two:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), i8(radix)); err != .None { return PyRes{res=":sub_two:atoi(b):", err=err}; }
+	if err = sub(sum, aa, bb);               err != .None { return PyRes{res=":sub_two:sub(sum,a,b):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(sum, i8(radix), context.temp_allocator);
+	if err != .None { return PyRes{res=":sub_two:itoa(sum):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+@export test_mul_two :: proc "c" (a, b: cstring, radix := int(10)) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	aa, bb, product := &Int{}, &Int{}, &Int{};
+	defer destroy(aa, bb, product);
+
+	if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":mul_two:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), i8(radix)); err != .None { return PyRes{res=":mul_two:atoi(b):", err=err}; }
+	if err = mul(product, aa, bb);           err != .None { return PyRes{res=":mul_two:mul(product,a,b):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(product, i8(radix), context.temp_allocator);
+	if err != .None { return PyRes{res=":mul_two:itoa(product):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+/*
+	NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
+*/
+@export test_div_two :: proc "c" (a, b: cstring, radix := int(10)) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	aa, bb, quotient := &Int{}, &Int{}, &Int{};
+	defer destroy(aa, bb, quotient);
+
+	if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":div_two:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), i8(radix)); err != .None { return PyRes{res=":div_two:atoi(b):", err=err}; }
+	if err = div(quotient, aa, bb);          err != .None { return PyRes{res=":div_two:div(quotient,a,b):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(quotient, i8(radix), context.temp_allocator);
+	if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
+	return PyRes{res = r, err = .None};
 }

+ 105 - 3
core/math/big/test.py

@@ -1,5 +1,6 @@
 from  math import *
 from ctypes import *
+from random import *
 import os
 
 #
@@ -38,6 +39,9 @@ except:
 	print("Couldn't find or load " + LIB_PATH + ".")
 	exit(1)
 
+#
+# res = a + b, err
+#
 try:
 	l.test_add_two.argtypes = [c_char_p, c_char_p, c_longlong]
 	l.test_add_two.restype  = Res
@@ -47,6 +51,44 @@ except:
 
 add_two = l.test_add_two
 
+#
+# res = a - b, err
+#
+try:
+	l.test_sub_two.argtypes = [c_char_p, c_char_p, c_longlong]
+	l.test_sub_two.restype  = Res
+except:
+	print("Couldn't find exported function 'test_sub_two'")
+	exit(2)
+
+sub_two = l.test_sub_two
+
+#
+# res = a * b, err
+#
+try:
+	l.test_mul_two.argtypes = [c_char_p, c_char_p, c_longlong]
+	l.test_mul_two.restype  = Res
+except:
+	print("Couldn't find exported function 'test_add_two'")
+	exit(2)
+
+mul_two = l.test_mul_two
+
+#
+# res = a / b, err
+#
+try:
+	l.test_div_two.argtypes = [c_char_p, c_char_p, c_longlong]
+	l.test_div_two.restype  = Res
+except:
+	print("Couldn't find exported function 'test_div_two'")
+	exit(2)
+
+div_two = l.test_div_two
+
+
+
 try:
 	l.test_error_string.argtypes = [c_byte]
 	l.test_error_string.restype  = c_char_p
@@ -91,15 +133,52 @@ def test_add_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_res
 		expected_result = a + b
 	return test("test_add_two", res, [str(a), str(b), radix], expected_error, expected_result)
 
+def test_sub_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
+	res = sub_two(str(a).encode('utf-8'), str(b).encode('utf-8'), radix)
+	if expected_result == None:
+		expected_result = a - b
+	return test("test_sub_two", res, [str(a), str(b), radix], expected_error, expected_result)
+
+def test_mul_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
+	res = mul_two(str(a).encode('utf-8'), str(b).encode('utf-8'), radix)
+	if expected_result == None:
+		expected_result = a * b
+	return test("test_mul_two", res, [str(a), str(b), radix], expected_error, expected_result)
+
+def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
+	res = div_two(str(a).encode('utf-8'), str(b).encode('utf-8'), radix)
+	if expected_result == None:
+		expected_result = a // b if b != 0 else None
+	return test("test_add_two", res, [str(a), str(b), radix], expected_error, expected_result)
+
+# TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
+#
+# The last two arguments in tests are the expected error and expected result.
+#
+# The expected error defaults to None.
+# By default the Odin implementation will be tested against the Python one.
+# You can override that by supplying an expected result as the last argument instead.
 
 TESTS = {
 	test_add_two: [
-		[ 1234, 5432,  10,                     ],
-		[ 1234, 5432, 110, E_Invalid_Argument, ],
+		[ 1234,   5432,    10, ],
+		[ 1234,   5432,   110, E_Invalid_Argument, ],
+	],
+	test_sub_two: [
+		[ 1234,   5432,    10, ],
+	],
+	test_mul_two: [
+		[ 1234,   5432,    10, ],
+		[ 1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10, ]
+	],
+	test_div_two: [
+		[ 54321, 12345,    10, ],
+		[ 55431,     0,    10, E_Division_by_Zero, ],
 	],
 }
 
 if __name__ == '__main__':
+	print()
 	print("---- core:math/big tests ----")
 	print()
 
@@ -112,4 +191,27 @@ if __name__ == '__main__':
 			else:
 				count_fail += 1
 
-		print("{}: {} passes, {} failures.".format(test_proc.__name__, count_pass, count_fail))
+		print("{}: {} passes, {} failures.".format(test_proc.__name__, count_pass, count_fail))
+
+	print()		
+	print("---- core:math/big random tests ----")
+	print()
+
+	for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two]:
+		count_pass = 0
+		count_fail = 0
+
+		a = randint(0, 1 << 120)
+		b = randint(0, 1 << 120)
+		res = None
+
+		# We've already tested division by zero above.
+		if b == 0 and test_proc == test_div_two:
+			b = b + 1
+
+		if test_proc(a, b):
+			count_pass += 1
+		else:
+			count_fail += 1
+
+		print("{} random: {} passes, {} failures.".format(test_proc.__name__, count_pass, count_fail))