Browse Source

big: Fix `sqrt`, `div`, `add` with certain inputs.

Jeroen van Rijn 4 years ago
parent
commit
149c7b88df

+ 13 - 17
core/math/big/basic.odin

@@ -359,7 +359,7 @@ int_halve :: proc(dest, src: ^Int) -> (err: Error) {
 	*/
 	*/
 	fwd_carry := DIGIT(0);
 	fwd_carry := DIGIT(0);
 
 
-	for x := dest.used; x >= 0; x -= 1 {
+	for x := dest.used - 1; x >= 0; x -= 1 {
 		/*
 		/*
 			Get the carry for the next iteration.
 			Get the carry for the next iteration.
 		*/
 		*/
@@ -761,21 +761,16 @@ sqrmod :: proc { int_sqrmod, };
 */
 */
 _int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 _int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 	dest := dest; x := a; y := b;
 	dest := dest; x := a; y := b;
-	if err = clear_if_uninitialized(x); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(y); err != .None {
-		return err;
-	}
 
 
 	old_used, min_used, max_used, i: int;
 	old_used, min_used, max_used, i: int;
 
 
 	if x.used < y.used {
 	if x.used < y.used {
 		x, y = y, x;
 		x, y = y, x;
+		assert(x.used >= y.used);
 	}
 	}
 
 
-	min_used = x.used;
-	max_used = y.used;
+	min_used = y.used;
+	max_used = x.used;
 	old_used = dest.used;
 	old_used = dest.used;
 
 
 	if err = grow(dest, max(max_used + 1, _DEFAULT_DIGIT_COUNT)); err != .None {
 	if err = grow(dest, max(max_used + 1, _DEFAULT_DIGIT_COUNT)); err != .None {
@@ -827,7 +822,6 @@ _int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 		Add remaining carry.
 		Add remaining carry.
 	*/
 	*/
 	dest.digit[i] = carry;
 	dest.digit[i] = carry;
-
 	zero_count := old_used - dest.used;
 	zero_count := old_used - dest.used;
 	/*
 	/*
 		Zero remainder.
 		Zero remainder.
@@ -1111,6 +1105,7 @@ _int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: int, err: Error) {
 _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
 _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
 
 
 	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
 	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
+	c: int;
 
 
 	goto_end: for {
 	goto_end: for {
 		if err = one(tq);									err != .None { break goto_end; }
 		if err = one(tq);									err != .None { break goto_end; }
@@ -1121,20 +1116,21 @@ _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (er
 
 
 		if err = abs(ta, numerator);						err != .None { break goto_end; }
 		if err = abs(ta, numerator);						err != .None { break goto_end; }
 		if err = abs(tb, denominator);						err != .None { break goto_end; }
 		if err = abs(tb, denominator);						err != .None { break goto_end; }
-
 		if err = shl(tb, tb, n);							err != .None { break goto_end; }
 		if err = shl(tb, tb, n);							err != .None { break goto_end; }
 		if err = shl(tq, tq, n);							err != .None { break goto_end; }
 		if err = shl(tq, tq, n);							err != .None { break goto_end; }
 
 
-		for ; n >= 0; n -= 1 {
-			c: int;
-			if c, err = cmp(tb, ta);						err != .None { break goto_end; }
-			if c != 1 {
+		for n >= 0 {
+			if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 {
+				// ta -= tb
 				if err = sub(ta, ta, tb);					err != .None { break goto_end; }
 				if err = sub(ta, ta, tb);					err != .None { break goto_end; }
-				if err = add( q, tq,  q);					err != .None { break goto_end; }
+				//  q += tq
+				if err = add( q, q,  tq);					err != .None { break goto_end; }
 			}
 			}
 			if err = shr1(tb, tb);							err != .None { break goto_end; }
 			if err = shr1(tb, tb);							err != .None { break goto_end; }
 			if err = shr1(tq, tq);							err != .None { break goto_end; }
 			if err = shr1(tq, tq);							err != .None { break goto_end; }
-		}		
+
+			n -= 1;
+		}
 
 
 		/*
 		/*
 			Now q == quotient and ta == remainder.
 			Now q == quotient and ta == remainder.

+ 0 - 1
core/math/big/build.bat

@@ -1,5 +1,4 @@
 @echo off
 @echo off
-clear
 :odin run   . -vet
 :odin run   . -vet
 :odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules
 :odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules
 odin build . -build-mode:shared -show-timings -o:size -use-separate-modules
 odin build . -build-mode:shared -show-timings -o:size -use-separate-modules

+ 18 - 9
core/math/big/example.odin

@@ -66,17 +66,26 @@ print :: proc(name: string, a: ^Int, base := i8(10)) {
 }
 }
 
 
 demo :: proc() {
 demo :: proc() {
-	// err: Error;
-	// r := &rnd.Rand{};
-	// rnd.init(r, 12345);
+	err: Error;
+	destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
+	defer destroy(destination, source, quotient, remainder, numerator, denominator);
 
 
-	// as := cstring("12341234");
-	// bs := cstring("159671292010002348397151706347412301");
+	err = atoi(source, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10);
+	print("src    ", source);	
 
 
-	// res := test_log(as, 2, 10);
-	// fmt.print(res);
-	// destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
-	// defer destroy(destination, source, quotient, remainder, numerator, denominator);
+	fmt.println("sqrt should be 843478780275248664696797599030708027195155136953848512749494");
+
+	fmt.println();
+	err = sqrt(destination, source);
+	fmt.printf("sqrt returned: %v\n", err);
+	print("sqrt   ", destination);
+
+	err = atoi(denominator, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10);
+	err = root_n(quotient, denominator, 2);
+	fmt.printf("root_n(2) returned: %v\n", err);
+	print("root_n(2)", quotient);
+
+	// fmt.println();
 }
 }
 
 
 main :: proc() {
 main :: proc() {

+ 58 - 29
core/math/big/exp_log.odin

@@ -214,42 +214,53 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) {
 	This function is less generic than `root_n`, simpler and faster.
 	This function is less generic than `root_n`, simpler and faster.
 */
 */
 int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
 int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
-	if err = clear_if_uninitialized(dest);			err != .None { return err; }
-	if err = clear_if_uninitialized(src);			err != .None { return err; }
 
 
-	/*						Must be positive. 					*/
-	if src.sign == .Negative						{ return .Invalid_Argument; }
+	when true {
+		if err = clear_if_uninitialized(dest);			err != .None { return err; }
+		if err = clear_if_uninitialized(src);			err != .None { return err; }
 
 
-	/*			Easy out. If src is zero, so is dest.			*/
-	if z, _ := is_zero(src); 						z { return zero(dest); }
+		/*						Must be positive. 					*/
+		if src.sign == .Negative						{ return .Invalid_Argument; }
 
 
-	/*						Set up temporaries.					*/
-	t1, t2 := &Int{}, &Int{};
-	defer destroy(t1, t2);
+		/*			Easy out. If src is zero, so is dest.			*/
+		if z, _ := is_zero(src); 						z { return zero(dest); }
 
 
-	if err = copy(t1, src);							err != .None { return err; }
-	if err = zero(t2);								err != .None { return err; }
+		/*						Set up temporaries.					*/
+		x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{};
+		defer destroy(x, y, t1, t2);
 
 
-	/*	First approximation. Not very bad for large arguments.	*/
-	if err = shr_digit(t1, t1.used / 2);			err != .None { return err; }
-	/*							t1 > 0 							*/
-	if err = div(t2, src, t1);						err != .None { return err; }
-	if err = add(t1, t1, t2);						err != .None { return err; }
-	if err = shr(t1, t1, 1);						err != .None { return err; }
+		count: int;
+		if count, err = count_bits(src); err != .None { return err; }
 
 
-	/*					And now t1 > sqrt(arg).					*/
-	for {
-		if err = div(t2, src, t1);						err != .None { return err; }
-		if err = add(t1, t1, t2);						err != .None { return err; }
-		if err = shr(t1, t1, 1);						err != .None { return err; }
-		/* t1 >= sqrt(arg) >= t2 at this point */
+		a, b := count >> 1, count & 1;
+		err = power_of_two(x, a+b);
 
 
-		cm, _ := cmp_mag(t1, t2);
-		if cm != 1 { break; }
-	}
+		iter := 0;
+		for {
+			iter += 1;
+			if iter > 100 {
+				swap(dest, x);
+				return .Max_Iterations_Reached;
+			}
+			/*
+				y = (x + n//x)//2
+			*/
+			div(t1, src, x);
+			add(t2, t1, x);
+			shr(y, t2, 1);
+
+			if c, _ := cmp(y, x); c == 0 || c == 1 {
+				swap(dest, x);
+				return .None;
+			}
+			swap(x, y);
+		}
 
 
-	swap(dest, t1);
-	return err;
+		swap(dest, x);
+		return err;
+	} else {
+		// return root_n(dest, src, 2);
+	}
 }
 }
 sqrt :: proc { int_sqrt, };
 sqrt :: proc { int_sqrt, };
 
 
@@ -263,7 +274,7 @@ sqrt :: proc { int_sqrt, };
 */
 */
 int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 	/*						Fast path for n == 2 						*/
 	/*						Fast path for n == 2 						*/
-	if n == 2 { return sqrt(dest, src); }
+	// if n == 2 { return sqrt(dest, src); }
 
 
 	/*					Initialize dest + src if needed. 				*/
 	/*					Initialize dest + src if needed. 				*/
 	if err = clear_if_uninitialized(dest);			err != .None { return err; }
 	if err = clear_if_uninitialized(dest);			err != .None { return err; }
@@ -321,6 +332,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 	if err = power_of_two(t2, ilog2); err != .None { return err; }
 	if err = power_of_two(t2, ilog2); err != .None { return err; }
 
 
 	c: int;
 	c: int;
+	iterations := 0;
 	for {
 	for {
 		/* t1 = t2 */
 		/* t1 = t2 */
 		if err = copy(t1, t2); err != .None { return err; }
 		if err = copy(t1, t2); err != .None { return err; }
@@ -353,12 +365,23 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 			break;
 			break;
 		}
 		}
 		if c, err = cmp(t1, t2); c == 0 { break; }
 		if c, err = cmp(t1, t2); c == 0 { break; }
+		iterations += 1;
+		if iterations == 101 {
+			return .Max_Iterations_Reached;
+		}
 	}
 	}
 
 
 	/*						Result can be off by a few so check.					*/
 	/*						Result can be off by a few so check.					*/
 	/* Loop beneath can overshoot by one if found root is smaller than actual root. */
 	/* Loop beneath can overshoot by one if found root is smaller than actual root. */
 
 
+	iterations = 0;
 	for {
 	for {
+		if iterations == 101 {
+			return .Max_Iterations_Reached;
+		}
+		//fmt.printf("root_n iteration: %v\n", iterations);
+		iterations += 1;
+
 		if err = pow(t2, t1, n); err != .None { return err; }
 		if err = pow(t2, t1, n); err != .None { return err; }
 
 
 		c, err = cmp(t2, a);
 		c, err = cmp(t2, a);
@@ -372,8 +395,14 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
 		}
 		}
 	}
 	}
 
 
+	iterations = 0;
 	/*					Correct overshoot from above or from recurrence.			*/
 	/*					Correct overshoot from above or from recurrence.			*/
 	for {
 	for {
+		if iterations == 101 {
+			return .Max_Iterations_Reached;
+		}
+		iterations += 1;
+
 		if err = pow(t2, t1, n); err != .None { return err; }
 		if err = pow(t2, t1, n); err != .None { return err; }
 
 
 		c, err = cmp(t2, a);
 		c, err = cmp(t2, a);

+ 5 - 5
core/math/big/logical.odin

@@ -285,7 +285,7 @@ int_shrmod :: proc(quotient, remainder, numerator: ^Int, bits: int) -> (err: Err
 		shift := DIGIT(_DIGIT_BITS - bits);
 		shift := DIGIT(_DIGIT_BITS - bits);
 		carry := DIGIT(0);
 		carry := DIGIT(0);
 
 
-		for x := quotient.used; x >= 0; x -= 1 {
+		for x := quotient.used - 1; x >= 0; x -= 1 {
 			/*
 			/*
 				Get the lower bits of this word in a temp.
 				Get the lower bits of this word in a temp.
 			*/
 			*/
@@ -344,7 +344,7 @@ int_shr_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
 	}
 	}
 	quotient.used -= digits;
 	quotient.used -= digits;
 	_zero_unused(quotient);
 	_zero_unused(quotient);
-	return .None;
+	return clamp(quotient);
 }
 }
 shr_digit :: proc { int_shr_digit, };
 shr_digit :: proc { int_shr_digit, };
 
 
@@ -446,16 +446,16 @@ int_shl_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
 	/*
 	/*
 		Increment the used by the shift amount then copy upwards.
 		Increment the used by the shift amount then copy upwards.
 	*/
 	*/
-   	quotient.used += digits;
 
 
 	/*
 	/*
 		Much like `int_shr_digit`, this is implemented using a sliding window,
 		Much like `int_shr_digit`, this is implemented using a sliding window,
 		except the window goes the other way around.
 		except the window goes the other way around.
     */
     */
-    for x := quotient.used; x >= digits; x -= 1 {
-    	quotient.digit[x] = quotient.digit[x - digits];
+    for x := quotient.used; x > 0; x -= 1 {
+    	quotient.digit[x+digits-1] = quotient.digit[x-1];
     }
     }
 
 
+   	quotient.used += digits;
     mem.zero_slice(quotient.digit[:digits]);
     mem.zero_slice(quotient.digit[:digits]);
     return .None;
     return .None;
 }
 }

+ 137 - 22
core/math/big/test.odin

@@ -26,74 +26,74 @@ PyRes :: struct {
 	return strings.clone_to_cstring(es[err], context.temp_allocator);
 	return strings.clone_to_cstring(es[err], context.temp_allocator);
 }
 }
 
 
-@export test_add_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
+@export test_add :: proc "c" (a, b: cstring) -> (res: PyRes) {
 	context = runtime.default_context();
 	context = runtime.default_context();
 	err: Error;
 	err: Error;
 
 
 	aa, bb, sum := &Int{}, &Int{}, &Int{};
 	aa, bb, sum := &Int{}, &Int{}, &Int{};
 	defer destroy(aa, bb, sum);
 	defer destroy(aa, bb, sum);
 
 
-	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add_two:atoi(a):", err=err}; }
-	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add_two:atoi(b):", err=err}; }
-	if err = add(sum, aa, bb);        err != .None { return PyRes{res=":add_two:add(sum,a,b):", err=err}; }
+	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add:atoi(b):", err=err}; }
+	if err = add(sum, aa, bb);        err != .None { return PyRes{res=":add:add(sum,a,b):", err=err}; }
 
 
 	r: cstring;
 	r: cstring;
 	r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
 	r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
-	if err != .None { return PyRes{res=":add_two:itoa(sum):", err=err}; }
+	if err != .None { return PyRes{res=":add:itoa(sum):", err=err}; }
 	return PyRes{res = r, err = .None};
 	return PyRes{res = r, err = .None};
 }
 }
 
 
-@export test_sub_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
+@export test_sub :: proc "c" (a, b: cstring) -> (res: PyRes) {
 	context = runtime.default_context();
 	context = runtime.default_context();
 	err: Error;
 	err: Error;
 
 
 	aa, bb, sum := &Int{}, &Int{}, &Int{};
 	aa, bb, sum := &Int{}, &Int{}, &Int{};
 	defer destroy(aa, bb, sum);
 	defer destroy(aa, bb, sum);
 
 
-	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub_two:atoi(a):", err=err}; }
-	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub_two:atoi(b):", err=err}; }
-	if err = sub(sum, aa, bb);        err != .None { return PyRes{res=":sub_two:sub(sum,a,b):", err=err}; }
+	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub:atoi(b):", err=err}; }
+	if err = sub(sum, aa, bb);        err != .None { return PyRes{res=":sub:sub(sum,a,b):", err=err}; }
 
 
 	r: cstring;
 	r: cstring;
 	r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
 	r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
-	if err != .None { return PyRes{res=":sub_two:itoa(sum):", err=err}; }
+	if err != .None { return PyRes{res=":sub:itoa(sum):", err=err}; }
 	return PyRes{res = r, err = .None};
 	return PyRes{res = r, err = .None};
 }
 }
 
 
-@export test_mul_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
+@export test_mul :: proc "c" (a, b: cstring) -> (res: PyRes) {
 	context = runtime.default_context();
 	context = runtime.default_context();
 	err: Error;
 	err: Error;
 
 
 	aa, bb, product := &Int{}, &Int{}, &Int{};
 	aa, bb, product := &Int{}, &Int{}, &Int{};
 	defer destroy(aa, bb, product);
 	defer destroy(aa, bb, product);
 
 
-	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul_two:atoi(a):", err=err}; }
-	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul_two:atoi(b):", err=err}; }
-	if err = mul(product, aa, bb);    err != .None { return PyRes{res=":mul_two:mul(product,a,b):", err=err}; }
+	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul:atoi(b):", err=err}; }
+	if err = mul(product, aa, bb);    err != .None { return PyRes{res=":mul:mul(product,a,b):", err=err}; }
 
 
 	r: cstring;
 	r: cstring;
 	r, err = int_itoa_cstring(product, 10, context.temp_allocator);
 	r, err = int_itoa_cstring(product, 10, context.temp_allocator);
-	if err != .None { return PyRes{res=":mul_two:itoa(product):", err=err}; }
+	if err != .None { return PyRes{res=":mul:itoa(product):", err=err}; }
 	return PyRes{res = r, err = .None};
 	return PyRes{res = r, err = .None};
 }
 }
 
 
 /*
 /*
 	NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
 	NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
 */
 */
-@export test_div_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
+@export test_div :: proc "c" (a, b: cstring) -> (res: PyRes) {
 	context = runtime.default_context();
 	context = runtime.default_context();
 	err: Error;
 	err: Error;
 
 
 	aa, bb, quotient := &Int{}, &Int{}, &Int{};
 	aa, bb, quotient := &Int{}, &Int{}, &Int{};
 	defer destroy(aa, bb, quotient);
 	defer destroy(aa, bb, quotient);
 
 
-	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div_two:atoi(a):", err=err}; }
-	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div_two:atoi(b):", err=err}; }
-	if err = div(quotient, aa, bb);   err != .None { return PyRes{res=":div_two:div(quotient,a,b):", err=err}; }
+	if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div:atoi(a):", err=err}; }
+	if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div:atoi(b):", err=err}; }
+	if err = div(quotient, aa, bb);   err != .None { return PyRes{res=":div:div(quotient,a,b):", err=err}; }
 
 
 	r: cstring;
 	r: cstring;
 	r, err = int_itoa_cstring(quotient, 10, context.temp_allocator);
 	r, err = int_itoa_cstring(quotient, 10, context.temp_allocator);
-	if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
+	if err != .None { return PyRes{res=":div:itoa(quotient):", err=err}; }
 	return PyRes{res = r, err = .None};
 	return PyRes{res = r, err = .None};
 }
 }
 
 
@@ -130,7 +130,6 @@ PyRes :: struct {
 @export test_pow :: proc "c" (base: cstring, power := int(2)) -> (res: PyRes) {
 @export test_pow :: proc "c" (base: cstring, power := int(2)) -> (res: PyRes) {
 	context = runtime.default_context();
 	context = runtime.default_context();
 	err: Error;
 	err: Error;
-	l: int;
 
 
 	dest, bb := &Int{}, &Int{};
 	dest, bb := &Int{}, &Int{};
 	defer destroy(dest, bb);
 	defer destroy(dest, bb);
@@ -142,4 +141,120 @@ PyRes :: struct {
 	r, err = int_itoa_cstring(dest, 10, context.temp_allocator);
 	r, err = int_itoa_cstring(dest, 10, context.temp_allocator);
 	if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
 	if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
 	return PyRes{res = r, err = .None};
 	return PyRes{res = r, err = .None};
-}
+}
+
+/*
+	dest = sqrt(src)
+*/
+@export test_sqrt :: proc "c" (source: cstring) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":sqrt:atoi(src):", err=err}; }
+	if err = sqrt(src, src);                err != .None { return PyRes{res=":sqrt:sqrt(src):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+
+/*
+	dest = shr_digit(src, digits)
+*/
+@export test_shr_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_digit:atoi(src):", err=err}; }
+	if err = shr_digit(src, digits);        err != .None { return PyRes{res=":shr_digit:shr_digit(src):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":shr_digit:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+/*
+	dest = shl_digit(src, digits)
+*/
+@export test_shl_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl_digit:atoi(src):", err=err}; }
+	if err = shl_digit(src, digits);        err != .None { return PyRes{res=":shl_digit:shr_digit(src):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":shl_digit:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+/*
+	dest = shr(src, bits)
+*/
+@export test_shr :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr:atoi(src):", err=err}; }
+	if err = shr(src, src, bits);           err != .None { return PyRes{res=":shr:shr(src, bits):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":shr:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+/*
+	dest = shr_signed(src, bits)
+*/
+@export test_shr_signed :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_signed:atoi(src):", err=err}; }
+	if err = shr_signed(src, src, bits);    err != .None { return PyRes{res=":shr_signed:shr_signed(src, bits):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":shr_signed:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+
+/*
+	dest = shl(src, bits)
+*/
+@export test_shl :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	src := &Int{};
+	defer destroy(src);
+
+	if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl:atoi(src):", err=err}; }
+	if err = shl(src, src, bits);           err != .None { return PyRes{res=":shl:shl(src, bits):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(src, 10, context.temp_allocator);
+	if err != .None { return PyRes{res=":shl:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}
+

+ 224 - 39
core/math/big/test.py

@@ -1,6 +1,6 @@
-from  math import *
 from ctypes import *
 from ctypes import *
 from random import *
 from random import *
+import math
 import os
 import os
 import platform
 import platform
 import time
 import time
@@ -10,12 +10,14 @@ from enum import Enum
 # Normally, we report the number of passes and fails.
 # Normally, we report the number of passes and fails.
 # With EXIT_ON_FAIL set, we exit at the first fail.
 # With EXIT_ON_FAIL set, we exit at the first fail.
 #
 #
+EXIT_ON_FAIL = True
 EXIT_ON_FAIL = False
 EXIT_ON_FAIL = False
 
 
 #
 #
 # We skip randomized tests altogether if NO_RANDOM_TESTS is set.
 # We skip randomized tests altogether if NO_RANDOM_TESTS is set.
 #
 #
-NO_RANDOM_TESTS = False #True
+NO_RANDOM_TESTS = True
+NO_RANDOM_TESTS = False
 
 
 #
 #
 # If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
 # If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
@@ -113,13 +115,23 @@ class Res(Structure):
 
 
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 
 
-add_two = load(l.test_add_two, [c_char_p, c_char_p], Res)
-sub_two = load(l.test_sub_two, [c_char_p, c_char_p], Res)
-mul_two = load(l.test_mul_two, [c_char_p, c_char_p], Res)
-div_two = load(l.test_div_two, [c_char_p, c_char_p], Res)
+add  = load(l.test_add, [c_char_p, c_char_p], Res)
+sub  = load(l.test_sub, [c_char_p, c_char_p], Res)
+mul  = load(l.test_mul, [c_char_p, c_char_p], Res)
+div  = load(l.test_div, [c_char_p, c_char_p], Res)
 
 
-int_log = load(l.test_log, [c_char_p, c_longlong], Res)
-int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
+# Powers and such
+int_log  = load(l.test_log,  [c_char_p, c_longlong], Res)
+int_pow  = load(l.test_pow,  [c_char_p, c_longlong], Res)
+int_sqrt = load(l.test_sqrt, [c_char_p], Res)
+
+# Logical operations
+
+int_shl_digit  = load(l.test_shl_digit, [c_char_p, c_longlong], Res)
+int_shr_digit  = load(l.test_shr_digit, [c_char_p, c_longlong], Res)
+int_shl        = load(l.test_shl, [c_char_p, c_longlong], Res)
+int_shr        = load(l.test_shr, [c_char_p, c_longlong], Res)
+int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
 
 
 def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = ""):
 def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = ""):
 	passed = True
 	passed = True
@@ -156,38 +168,38 @@ def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expecte
 	return passed
 	return passed
 
 
 
 
-def test_add_two(a = 0, b = 0, expected_error = Error.Okay):
+def test_add(a = 0, b = 0, expected_error = Error.Okay):
 	args = [str(a), str(b)]
 	args = [str(a), str(b)]
 	sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
 	sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
-	res  = add_two(sa_c, sb_c)
+	res  = add(sa_c, sb_c)
 	expected_result = None
 	expected_result = None
 	if expected_error == Error.Okay:
 	if expected_error == Error.Okay:
 		expected_result = a + b
 		expected_result = a + b
-	return test("test_add_two", res, args, expected_error, expected_result)
+	return test("test_add", res, args, expected_error, expected_result)
 
 
-def test_sub_two(a = 0, b = 0, expected_error = Error.Okay):
+def test_sub(a = 0, b = 0, expected_error = Error.Okay):
 	sa,     sb = str(a), str(b)
 	sa,     sb = str(a), str(b)
 	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
-	res  = sub_two(sa_c, sb_c)
+	res  = sub(sa_c, sb_c)
 	expected_result = None
 	expected_result = None
 	if expected_error == Error.Okay:
 	if expected_error == Error.Okay:
 		expected_result = a - b
 		expected_result = a - b
-	return test("test_sub_two", res, [sa_c, sb_c], expected_error, expected_result)
+	return test("test_sub", res, [sa_c, sb_c], expected_error, expected_result)
 
 
-def test_mul_two(a = 0, b = 0, expected_error = Error.Okay):
+def test_mul(a = 0, b = 0, expected_error = Error.Okay):
 	sa,     sb = str(a), str(b)
 	sa,     sb = str(a), str(b)
 	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
-	res  = mul_two(sa_c, sb_c)
+	res  = mul(sa_c, sb_c)
 	expected_result = None
 	expected_result = None
 	if expected_error == Error.Okay:
 	if expected_error == Error.Okay:
 		expected_result = a * b
 		expected_result = a * b
-	return test("test_mul_two", res, [sa_c, sb_c], expected_error, expected_result)
+	return test("test_mul", res, [sa_c, sb_c], expected_error, expected_result)
 
 
-def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
+def test_div(a = 0, b = 0, expected_error = Error.Okay):
 	sa,     sb = str(a), str(b)
 	sa,     sb = str(a), str(b)
 	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	try:
 	try:
-		res  = div_two(sa_c, sb_c)
+		res  = div(sa_c, sb_c)
 	except:
 	except:
 		print("Exception with arguments:", a, b)
 		print("Exception with arguments:", a, b)
 		return False
 		return False
@@ -197,12 +209,12 @@ def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
 		# We don't round the division results, so if one component is negative, we're off by one.
 		# We don't round the division results, so if one component is negative, we're off by one.
 		#
 		#
 		if a < 0 and b > 0:
 		if a < 0 and b > 0:
-			expected_result = int(-(abs(a) / b))
+			expected_result = int(-(abs(a) // b))
 		elif b < 0 and a > 0:
 		elif b < 0 and a > 0:
-			expected_result = int(-(a / abs((b))))
+			expected_result = int(-(a // abs((b))))
 		else:
 		else:
 			expected_result = a // b if b != 0 else None
 			expected_result = a // b if b != 0 else None
-	return test("test_div_two", res, [sa_c, sb_c], expected_error, expected_result)
+	return test("test_div", res, [sa_c, sb_c], expected_error, expected_result)
 
 
 
 
 def test_log(a = 0, base = 0, expected_error = Error.Okay):
 def test_log(a = 0, base = 0, expected_error = Error.Okay):
@@ -212,7 +224,7 @@ def test_log(a = 0, base = 0, expected_error = Error.Okay):
 
 
 	expected_result = None
 	expected_result = None
 	if expected_error == Error.Okay:
 	if expected_error == Error.Okay:
-		expected_result = int(log(a, base))
+		expected_result = int(math.log(a, base))
 	return test("test_log", res, args, expected_error, expected_result)
 	return test("test_log", res, args, expected_error, expected_result)
 
 
 def test_pow(base = 0, power = 0, expected_error = Error.Okay):
 def test_pow(base = 0, power = 0, expected_error = Error.Okay):
@@ -225,9 +237,108 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
 		if power < 0:
 		if power < 0:
 			expected_result = 0
 			expected_result = 0
 		else:
 		else:
-			expected_result = int(base**power)
+			# NOTE(Jeroen): Don't use `math.pow`, it's a floating point approximation.
+			#               Use built-in `pow` or `a**b` instead.
+			expected_result = pow(base, power)
 	return test("test_pow", res, args, expected_error, expected_result)
 	return test("test_pow", res, args, expected_error, expected_result)
 
 
+def test_sqrt(number = 0, expected_error = Error.Okay):
+	args  = [str(number)]
+	sa_c  = args[0].encode('utf-8')
+	try:
+		res   = int_sqrt(sa_c)
+	except:
+		print("sqrt:", number)
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		if number < 0:
+			expected_result = 0
+		else:
+			expected_result = int(math.isqrt(number))
+	return test("test_sqrt", res, args, expected_error, expected_result)
+
+def root_n(number, root):
+	u, s = number, number + 1
+	while u < s:
+		s = u
+		t = (root-1) * s + number // pow(s, root - 1)
+		u = t // root
+	return s
+
+def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
+	args  = [str(a), digits]
+	sa_c  = args[0].encode('utf-8')
+	res   = int_shl_digit(sa_c, digits)
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = a << (digits * 60)
+	return test("test_shl_digit", res, args, expected_error, expected_result)
+
+def test_shr_digit(a = 0, digits = 0, expected_error = Error.Okay):
+	args  = [str(a), digits]
+	sa_c  = args[0].encode('utf-8')
+	try:
+		res   = int_shr_digit(sa_c, digits)
+	except:
+		print("int_shr_digit", a, digits)
+		exit()
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		if a < 0:
+			# Don't pass negative numbers. We have a shr_signed.
+			return False
+		else:
+			expected_result = a >> (digits * 60)
+		
+	return test("test_shr_digit", res, args, expected_error, expected_result)
+
+def test_shl(a = 0, bits = 0, expected_error = Error.Okay):
+	args  = [str(a), bits]
+	sa_c  = args[0].encode('utf-8')
+	res   = int_shl(sa_c, bits)
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = a << bits
+	return test("test_shl", res, args, expected_error, expected_result)
+
+def test_shr(a = 0, bits = 0, expected_error = Error.Okay):
+	args  = [str(a), bits]
+	sa_c  = args[0].encode('utf-8')
+	try:
+		res   = int_shr(sa_c, bits)
+	except:
+		print("int_shr", a, bits)
+		exit()
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		if a < 0:
+			# Don't pass negative numbers. We have a shr_signed.
+			return False
+		else:
+			expected_result = a >> bits
+		
+	return test("test_shr", res, args, expected_error, expected_result)
+
+def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
+	args  = [str(a), bits]
+	sa_c  = args[0].encode('utf-8')
+	try:
+		res   = int_shr_signed(sa_c, bits)
+	except:
+		print("int_shr_signed", a, bits)
+		exit()
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = a >> bits
+		
+	return test("test_shr_signed", res, args, expected_error, expected_result)
+
 # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
 # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
 #
 #
 # The last two arguments in tests are the expected error and expected result.
 # The last two arguments in tests are the expected error and expected result.
@@ -237,19 +348,20 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
 # You can override that by supplying an expected result as the last argument instead.
 # You can override that by supplying an expected result as the last argument instead.
 
 
 TESTS = {
 TESTS = {
-	test_add_two: [
+	test_add: [
 		[ 1234,   5432],
 		[ 1234,   5432],
 	],
 	],
-	test_sub_two: [
+	test_sub: [
 		[ 1234,   5432],
 		[ 1234,   5432],
 	],
 	],
-	test_mul_two: [
+	test_mul: [
 		[ 1234,   5432],
 		[ 1234,   5432],
 		[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7]
 		[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7]
 	],
 	],
-	test_div_two: [
+	test_div: [
 		[ 54321,	12345],
 		[ 54321,	12345],
 		[ 55431,		0, Error.Division_by_Zero],
 		[ 55431,		0, Error.Division_by_Zero],
+		[ 12980742146337069150589594264770969721, 4611686018427387904 ],
 	],
 	],
 	test_log: [
 	test_log: [
 		[ 3192,			1, Error.Invalid_Argument],
 		[ 3192,			1, Error.Invalid_Argument],
@@ -260,29 +372,87 @@ TESTS = {
 	test_pow: [
 	test_pow: [
 		[ 0,  -1, Error.Math_Domain_Error ], # Math
 		[ 0,  -1, Error.Math_Domain_Error ], # Math
 		[ 0,   0 ], # 1
 		[ 0,   0 ], # 1
-	 	[ 0,   2 ], # 0
-	 	[ 42, -1,], # 0
-	 	[ 42,  1 ], # 1
-	 	[ 42,  0 ], # 42
-	 	[ 42,  2 ], # 42*42
-
-
+		[ 0,   2 ], # 0
+		[ 42, -1,], # 0
+		[ 42,  1 ], # 1
+		[ 42,  0 ], # 42
+		[ 42,  2 ], # 42*42
+	],
+	test_sqrt: [
+		[  -1, Error.Invalid_Argument, ],
+		[  42, Error.Okay, ],
+		[  12345678901234567890, Error.Okay, ],
+		[  1298074214633706907132624082305024, Error.Okay, ],
+	],
+	test_shl_digit: [
+		[ 3192,			1 ],
+		[ 1298074214633706907132624082305024, 2 ],
+		[ 1024,			3 ],
+	],
+	test_shr_digit: [
+		[ 3680125442705055547392, 1 ],
+		[ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
+		[ 219504133884436710204395031992179571, 2 ],
+	],
+	test_shl: [
+		[ 3192,			1 ],
+		[ 1298074214633706907132624082305024, 2 ],
+		[ 1024,			3 ],
+	],
+	test_shr: [
+		[ 3680125442705055547392, 1 ],
+		[ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
+		[ 219504133884436710204395031992179571, 2 ],
 	],
 	],
+	test_shr_signed: [
+		[ -611105530635358368578155082258244262, 12 ],
+		[ -149195686190273039203651143129455, 12 ],
+		[ 611105530635358368578155082258244262, 12 ],
+		[ 149195686190273039203651143129455, 12 ],
+	]
 }
 }
 
 
 total_passes   = 0
 total_passes   = 0
 total_failures = 0
 total_failures = 0
 
 
+#
+# test_shr_signed also tests shr, so we're not going to test shr randomly.
+#
 RANDOM_TESTS = [
 RANDOM_TESTS = [
-	test_add_two, test_sub_two, test_mul_two, test_div_two,
-	test_log, test_pow,
+	test_add, test_sub, test_mul, test_div,
+	test_log, test_pow, test_sqrt,
+	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
 ]
 ]
+SKIP_LARGE   = [test_pow]
+SKIP_LARGEST = []
 
 
 # Untimed warmup.
 # Untimed warmup.
 for test_proc in TESTS:
 for test_proc in TESTS:
 	for t in TESTS[test_proc]:
 	for t in TESTS[test_proc]:
 		res   = test_proc(*t)
 		res   = test_proc(*t)
 
 
+
+def isqrt(x):
+	n = int(x)
+	a, b = divmod(n.bit_length(), 2)
+	print("isqrt({}), a: {}, b: {}". format(n, a, b))
+	x = 2**(a+b)
+	print("initial: {}".format(x))
+	i = 0
+	while True:
+		# y = (x + n//x)//2
+		t1 = n // x
+		t2 = x + t1
+		t3 = t2 // 2
+		y = (x + n//x)//2
+
+		i += 1
+		print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n));
+
+		if y >= x:
+			return x
+		x = y
+
 if __name__ == '__main__':
 if __name__ == '__main__':
 	print("---- math/big tests ----")
 	print("---- math/big tests ----")
 	print()
 	print()
@@ -317,7 +487,8 @@ if __name__ == '__main__':
 		print()
 		print()
 
 
 		for test_proc in RANDOM_TESTS:
 		for test_proc in RANDOM_TESTS:
-			if test_proc == test_pow and BITS > 1_200: continue
+			if BITS >  1_200 and test_proc in SKIP_LARGE: continue
+			if BITS >  4_096 and test_proc in SKIP_LARGEST: continue
 
 
 			count_pass = 0
 			count_pass = 0
 			count_fail = 0
 			count_fail = 0
@@ -331,8 +502,10 @@ if __name__ == '__main__':
 				a = randint(-(1 << BITS), 1 << BITS)
 				a = randint(-(1 << BITS), 1 << BITS)
 				b = randint(-(1 << BITS), 1 << BITS)
 				b = randint(-(1 << BITS), 1 << BITS)
 
 
-				if test_proc == test_div_two:
+				if test_proc == test_div:
 					# We've already tested division by zero above.
 					# We've already tested division by zero above.
+					bits = int(BITS * 0.6)
+					b = randint(-(1 << bits), 1 << bits)
 					if b == 0:
 					if b == 0:
 						b == 42
 						b == 42
 				elif test_proc == test_log:
 				elif test_proc == test_log:
@@ -341,6 +514,18 @@ if __name__ == '__main__':
 					b = randint(2, 1 << 60)
 					b = randint(2, 1 << 60)
 				elif test_proc == test_pow:
 				elif test_proc == test_pow:
 					b = randint(1, 10)
 					b = randint(1, 10)
+				elif test_proc == test_sqrt:
+					a = randint(1, 1 << BITS)
+					b = Error.Okay
+				elif test_proc == test_shl_digit:
+					b = randint(0, 10);
+				elif test_proc == test_shr_digit:
+					a = abs(a)
+					b = randint(0, 10);
+				elif test_proc == test_shl:
+					b = randint(0, min(BITS, 120));
+				elif test_proc == test_shr_signed:
+					b = randint(0, min(BITS, 120));
 				else:
 				else:
 					b = randint(0, 1 << BITS)					
 					b = randint(0, 1 << BITS)