Browse Source

big: Test `root_n`.

Jeroen van Rijn 4 years ago
parent
commit
db0196abc7
4 changed files with 64 additions and 47 deletions
  1. 2 0
      core/math/big/common.odin
  2. 13 20
      core/math/big/exp_log.odin
  3. 18 0
      core/math/big/test.odin
  4. 31 27
      core/math/big/test.py

+ 2 - 0
core/math/big/common.odin

@@ -35,6 +35,8 @@ _DEFAULT_SQR_KARATSUBA_CUTOFF :: 120;
 _DEFAULT_MUL_TOOM_CUTOFF      :: 350;
 _DEFAULT_SQR_TOOM_CUTOFF      :: 400;
 
+_MAX_ITERATIONS_ROOT_N        :: 500;
+
 Sign :: enum u8 {
 	Zero_or_Positive = 0,
 	Negative         = 1,

+ 13 - 20
core/math/big/exp_log.odin

@@ -233,15 +233,9 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
 		if count, err = count_bits(src); err != .None { return err; }
 
 		a, b := count >> 1, count & 1;
-		err = power_of_two(x, a+b);
+		if err = power_of_two(x, a+b);                  err != .None { return err; }
 
-		iter := 0;
 		for {
-			iter += 1;
-			if iter > 100 {
-				swap(dest, x);
-				return .Max_Iterations_Reached;
-			}
 			/*
 				y = (x + n//x)//2
 			*/
@@ -274,7 +268,7 @@ sqrt :: proc { int_sqrt, };
 */
 int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 	/*						Fast path for n == 2 						*/
-	// if n == 2 { return sqrt(dest, src); }
+	if n == 2 { return sqrt(dest, src); }
 
 	/*					Initialize dest + src if needed. 				*/
 	if err = clear_if_uninitialized(dest);			err != .None { return err; }
@@ -366,7 +360,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 		}
 		if c, err = cmp(t1, t2); c == 0 { break; }
 		iterations += 1;
-		if iterations == 101 {
+		if iterations == _MAX_ITERATIONS_ROOT_N {
 			return .Max_Iterations_Reached;
 		}
 	}
@@ -376,12 +370,6 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 
 	iterations = 0;
 	for {
-		if iterations == 101 {
-			return .Max_Iterations_Reached;
-		}
-		//fmt.printf("root_n iteration: %v\n", iterations);
-		iterations += 1;
-
 		if err = pow(t2, t1, n); err != .None { return err; }
 
 		c, err = cmp(t2, a);
@@ -393,16 +381,16 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 		} else {
 			break;
 		}
+
+		iterations += 1;
+		if iterations == _MAX_ITERATIONS_ROOT_N {
+			return .Max_Iterations_Reached;
+		}
 	}
 
 	iterations = 0;
 	/*					Correct overshoot from above or from recurrence.			*/
 	for {
-		if iterations == 101 {
-			return .Max_Iterations_Reached;
-		}
-		iterations += 1;
-
 		if err = pow(t2, t1, n); err != .None { return err; }
 
 		c, err = cmp(t2, a);
@@ -411,6 +399,11 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 		} else {
 			break;
 		}
+
+		iterations += 1;
+		if iterations == _MAX_ITERATIONS_ROOT_N {
+			return .Max_Iterations_Reached;
+		}
 	}
 
 	/*								Set the result.									*/

+ 18 - 0
core/math/big/test.odin

@@ -162,6 +162,24 @@ PyRes :: struct {
 	return PyRes{res = r, err = .None};
 }
 
+/*
+	dest = root_n(src, power)
+*/
+@export test_root_n :: proc "c" (source: cstring, power: int) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":root_n:atoi(src):", err=err}; }
+	if err = root_n(src, src, power);       err != .None { return PyRes{res=":root_n:root_n(src):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":root_n:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
 
 /*
 	dest = shr_digit(src, digits)

+ 31 - 27
core/math/big/test.py

@@ -121,9 +121,10 @@ mul  = load(l.test_mul, [c_char_p, c_char_p], Res)
 div  = load(l.test_div, [c_char_p, c_char_p], Res)
 
 # Powers and such
-int_log  = load(l.test_log,  [c_char_p, c_longlong], Res)
-int_pow  = load(l.test_pow,  [c_char_p, c_longlong], Res)
-int_sqrt = load(l.test_sqrt, [c_char_p], Res)
+int_log    = load(l.test_log,  [c_char_p, c_longlong], Res)
+int_pow    = load(l.test_pow,  [c_char_p, c_longlong], Res)
+int_sqrt   = load(l.test_sqrt, [c_char_p], Res)
+int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
 
 # Logical operations
 
@@ -266,6 +267,23 @@ def root_n(number, root):
 		u = t // root
 	return s
 
+def test_root_n(number = 0, root = 0, expected_error = Error.Okay):
+	args  = [str(number), root]
+	sa_c  = args[0].encode('utf-8')
+	try:
+		res   = int_root_n(sa_c, root)
+	except:
+		print("root_n:", number, root)
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		if number < 0:
+			expected_result = 0
+		else:
+			expected_result = root_n(number, root)
+
+	return test("test_root_n", res, args, expected_error, expected_result)
+
 def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
 	args  = [str(a), digits]
 	sa_c  = args[0].encode('utf-8')
@@ -384,6 +402,9 @@ TESTS = {
 		[  12345678901234567890, Error.Okay, ],
 		[  1298074214633706907132624082305024, Error.Okay, ],
 	],
+	test_root_n: [
+		[  1298074214633706907132624082305024, 2, Error.Okay, ],	
+	],
 	test_shl_digit: [
 		[ 3192,			1 ],
 		[ 1298074214633706907132624082305024, 2 ],
@@ -420,10 +441,12 @@ total_failures = 0
 #
 RANDOM_TESTS = [
 	test_add, test_sub, test_mul, test_div,
-	test_log, test_pow, test_sqrt,
+	test_log, test_pow, test_sqrt, test_root_n,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
 ]
-SKIP_LARGE   = [test_pow]
+SKIP_LARGE   = [
+	test_pow, test_root_n,
+]
 SKIP_LARGEST = []
 
 # Untimed warmup.
@@ -431,28 +454,6 @@ for test_proc in TESTS:
 	for t in TESTS[test_proc]:
 		res   = test_proc(*t)
 
-
-def isqrt(x):
-	n = int(x)
-	a, b = divmod(n.bit_length(), 2)
-	print("isqrt({}), a: {}, b: {}". format(n, a, b))
-	x = 2**(a+b)
-	print("initial: {}".format(x))
-	i = 0
-	while True:
-		# y = (x + n//x)//2
-		t1 = n // x
-		t2 = x + t1
-		t3 = t2 // 2
-		y = (x + n//x)//2
-
-		i += 1
-		print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n));
-
-		if y >= x:
-			return x
-		x = y
-
 if __name__ == '__main__':
 	print("---- math/big tests ----")
 	print()
@@ -517,6 +518,9 @@ if __name__ == '__main__':
 				elif test_proc == test_sqrt:
 					a = randint(1, 1 << BITS)
 					b = Error.Okay
+				elif test_proc == test_root_n:
+					a = randint(1, 1 << BITS)
+					b = randint(1, 10);
 				elif test_proc == test_shl_digit:
 					b = randint(0, 10);
 				elif test_proc == test_shr_digit: