Browse Source

Add tests for `internal_int_is_square'.

Jeroen van Rijn 4 years ago
parent
commit
852643e6ba

+ 8 - 8
core/math/big/build.bat

@@ -1,9 +1,9 @@
 @echo off
 @echo off
-odin run . -vet
-: -o:size
-:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
-:odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
-:odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py -fast-tests
-:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py
-: -fast-tests
-:odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests
+:odin run . -vet
+
+set TEST_ARGS=-fast-tests
+:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+:odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+:odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests %TEST_ARGS%

+ 4 - 4
core/math/big/example.odin

@@ -206,13 +206,13 @@ demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 	defer destroy(a, b, c, d, e, f);
 
 
-	set(a, 512);
-	sqr(b, a);
+	foo := "92232459121502451677697058974826760244863271517919321608054113675118660929276431348516553336313179167211015633639725554914519355444316239500734169769447134357534241879421978647995614218985202290368055757891124109355450669008628757662409138767505519391883751112010824030579849970582074544353971308266211776494228299586414907715854328360867232691292422194412634523666770452490676515117702116926803826546868467146319938818238521874072436856528051486567230096290549225463582766830777324099589751817442141036031904145041055454639783559905920619197290800070679733841430619962318433709503256637256772215111521321630777950145713049902839937043785039344243357384899099910837463164007565230287809026956254332260375327814271845678201";
+
+	set(a, foo);
 
 
 	print("a: ", a);
 	print("a: ", a);
-	print("b: ", b);
 
 
-	is_sqr, _ := internal_int_is_square(b);
+	is_sqr, _ := internal_int_is_square(a);
 	fmt.printf("is_square: %v\n", is_sqr);
 	fmt.printf("is_square: %v\n", is_sqr);
 }
 }
 
 

+ 24 - 24
core/math/big/internal.odin

@@ -670,7 +670,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 		digits   := src.used + multiplier.used + 1;
 		digits   := src.used + multiplier.used + 1;
 
 
 		if        false &&  min_used     >= MUL_KARATSUBA_CUTOFF &&
 		if        false &&  min_used     >= MUL_KARATSUBA_CUTOFF &&
-						    max_used / 2 >= MUL_KARATSUBA_CUTOFF &&
+							max_used / 2 >= MUL_KARATSUBA_CUTOFF &&
 			/*
 			/*
 				Not much effect was observed below a ratio of 1:2, but again: YMMV.
 				Not much effect was observed below a ratio of 1:2, but again: YMMV.
 			*/
 			*/
@@ -1111,10 +1111,10 @@ internal_compare :: proc { internal_int_compare, internal_int_compare_digit, };
 internal_cmp :: internal_compare;
 internal_cmp :: internal_compare;
 
 
 /*
 /*
-    Compare an `Int` to an unsigned number upto `DIGIT & _MASK`.
-    Returns -1 if `a` < `b`, 0 if `a` == `b` and 1 if `b` > `a`.
+	Compare an `Int` to an unsigned number upto `DIGIT & _MASK`.
+	Returns -1 if `a` < `b`, 0 if `a` == `b` and 1 if `b` > `a`.
 
 
-    Expects: `a` and `b` both to be valid `Int`s, i.e. initialized and not `nil`.
+	Expects: `a` and `b` both to be valid `Int`s, i.e. initialized and not `nil`.
 */
 */
 internal_int_compare_digit :: #force_inline proc(a: ^Int, b: DIGIT) -> (comparison: int) {
 internal_int_compare_digit :: #force_inline proc(a: ^Int, b: DIGIT) -> (comparison: int) {
 	a_is_negative := #force_inline internal_is_negative(a);
 	a_is_negative := #force_inline internal_is_negative(a);
@@ -1170,7 +1170,7 @@ internal_int_compare_magnitude :: #force_inline proc(a, b: ^Int) -> (comparison:
 		}
 		}
 	}
 	}
 
 
-   	return 0;
+	return 0;
 }
 }
 internal_compare_magnitude :: proc { internal_int_compare_magnitude, };
 internal_compare_magnitude :: proc { internal_int_compare_magnitude, };
 internal_cmp_mag :: internal_compare_magnitude;
 internal_cmp_mag :: internal_compare_magnitude;
@@ -1202,7 +1202,7 @@ internal_int_is_square :: proc(a: ^Int, allocator := context.allocator) -> (squa
 	*/
 	*/
 	c: DIGIT;
 	c: DIGIT;
 	c, err = internal_mod(a, 105);
 	c, err = internal_mod(a, 105);
-	if _private_int_rem_128[c] == 1                                  { return; }
+	if _private_int_rem_105[c] == 1                                  { return; }
 
 
 	t := &Int{};
 	t := &Int{};
 	defer destroy(t);
 	defer destroy(t);
@@ -2366,12 +2366,12 @@ internal_int_shrmod :: proc(quotient, remainder, numerator: ^Int, bits: int, all
 			/*
 			/*
 				Shift the current word and mix in the carry bits from the previous word.
 				Shift the current word and mix in the carry bits from the previous word.
 			*/
 			*/
-	        quotient.digit[x] = (quotient.digit[x] >> uint(bits)) | (carry << shift);
+			quotient.digit[x] = (quotient.digit[x] >> uint(bits)) | (carry << shift);
 
 
-	        /*
-	        	Update carry from forward carry.
-	        */
-	        carry = fwd_carry;
+			/*
+				Update carry from forward carry.
+			*/
+			carry = fwd_carry;
 		}
 		}
 
 
 	}
 	}
@@ -2397,17 +2397,17 @@ internal_int_shr_digit :: proc(quotient: ^Int, digits: int, allocator := context
 	*/
 	*/
 	if digits > quotient.used { return internal_zero(quotient); }
 	if digits > quotient.used { return internal_zero(quotient); }
 
 
-   	/*
+	/*
 		Much like `int_shl_digit`, this is implemented using a sliding window,
 		Much like `int_shl_digit`, this is implemented using a sliding window,
 		except the window goes the other way around.
 		except the window goes the other way around.
 
 
 		b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
 		b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
-		            /\                   |      ---->
-		             \-------------------/      ---->
-    */
+					/\                   |      ---->
+					 \-------------------/      ---->
+	*/
 
 
 	#no_bounds_check for x := 0; x < (quotient.used - digits); x += 1 {
 	#no_bounds_check for x := 0; x < (quotient.used - digits); x += 1 {
-    	quotient.digit[x] = quotient.digit[x + digits];
+		quotient.digit[x] = quotient.digit[x + digits];
 	}
 	}
 	quotient.used -= digits;
 	quotient.used -= digits;
 	internal_zero_unused(quotient);
 	internal_zero_unused(quotient);
@@ -2511,14 +2511,14 @@ internal_int_shl_digit :: proc(quotient: ^Int, digits: int, allocator := context
 	/*
 	/*
 		Much like `int_shr_digit`, this is implemented using a sliding window,
 		Much like `int_shr_digit`, this is implemented using a sliding window,
 		except the window goes the other way around.
 		except the window goes the other way around.
-    */
-    #no_bounds_check for x := quotient.used; x > 0; x -= 1 {
-    	quotient.digit[x+digits-1] = quotient.digit[x-1];
-    }
-
-   	quotient.used += digits;
-    mem.zero_slice(quotient.digit[:digits]);
-    return nil;
+	*/
+	#no_bounds_check for x := quotient.used; x > 0; x -= 1 {
+		quotient.digit[x+digits-1] = quotient.digit[x-1];
+	}
+
+	quotient.used += digits;
+	mem.zero_slice(quotient.digit[:digits]);
+	return nil;
 }
 }
 internal_shl_digit :: proc { internal_int_shl_digit, };
 internal_shl_digit :: proc { internal_int_shl_digit, };
 
 

+ 6 - 3
core/math/big/private.odin

@@ -2007,7 +2007,7 @@ _private_copy_digits :: proc(dest, src: ^Int, digits: int, offset := int(0)) ->
 	Tables used by `internal_*` and `_*`.
 	Tables used by `internal_*` and `_*`.
 */
 */
 
 
-_private_int_rem_128 := [128]DIGIT{
+_private_int_rem_128 := [?]DIGIT{
 	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
@@ -2017,8 +2017,9 @@ _private_int_rem_128 := [128]DIGIT{
 	1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
 };
 };
+#assert(128 * size_of(DIGIT) == size_of(_private_int_rem_128));
 
 
-_private_int_rem_105 := [105]DIGIT{
+_private_int_rem_105 := [?]DIGIT{
 	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
 	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
 	0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
 	0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
 	0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1,
 	0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1,
@@ -2027,8 +2028,9 @@ _private_int_rem_105 := [105]DIGIT{
 	1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1,
 	1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1,
 	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
 	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
 };
 };
+#assert(105 * size_of(DIGIT) == size_of(_private_int_rem_105));
 
 
-_private_prime_table := []DIGIT{
+_private_prime_table := [?]DIGIT{
 	0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
 	0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
 	0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
 	0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
 	0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
 	0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
@@ -2065,6 +2067,7 @@ _private_prime_table := []DIGIT{
 	0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
 	0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
 	0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653,
 	0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653,
 };
 };
+#assert(256 * size_of(DIGIT) == size_of(_private_prime_table));
 
 
 when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
 when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
 	_factorial_table := [35]_WORD{
 	_factorial_table := [35]_WORD{

+ 1 - 1
core/math/big/radix.odin

@@ -235,7 +235,7 @@ int_to_cstring :: int_itoa_cstring;
 /*
 /*
 	Read a string [ASCII] in a given radix.
 	Read a string [ASCII] in a given radix.
 */
 */
-int_atoi :: proc(res: ^Int, input: string, radix: i8, allocator := context.allocator) -> (err: Error) {
+int_atoi :: proc(res: ^Int, input: string, radix := i8(10), allocator := context.allocator) -> (err: Error) {
 	assert_if_nil(res);
 	assert_if_nil(res);
 	input := input;
 	input := input;
 	context.allocator = allocator;
 	context.allocator = allocator;

+ 19 - 0
core/math/big/test.odin

@@ -369,3 +369,22 @@ PyRes :: struct {
 	return PyRes{res = r, err = nil};
 	return PyRes{res = r, err = nil};
 }
 }
 
 
+/*
+	dest = lcm(a, b)
+*/
+@export test_is_square :: proc "c" (a: cstring) -> (res: PyRes) {
+	context = runtime.default_context();
+	err:    Error;
+	square: bool;
+
+	ai := &Int{};
+	defer internal_destroy(ai);
+
+	if err = atoi(ai, string(a), 16); err != nil { return PyRes{res=":is_square:atoi(a):", err=err}; }
+	if square, err = #force_inline internal_int_is_square(ai); err != nil { return PyRes{res=":is_square:is_square(a):", err=err}; }
+
+	if square {
+		return PyRes{"True", nil};
+	}
+	return PyRes{"False", nil};
+}

+ 29 - 12
core/math/big/test.py

@@ -160,11 +160,11 @@ print("initialize_constants: ", initialize_constants())
 
 
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 
 
-add        =     load(l.test_add,        [c_char_p, c_char_p],   Res)
-sub        =     load(l.test_sub,        [c_char_p, c_char_p],   Res)
-mul        =     load(l.test_mul,        [c_char_p, c_char_p],   Res)
-sqr        =     load(l.test_sqr,        [c_char_p          ],   Res)
-div        =     load(l.test_div,        [c_char_p, c_char_p],   Res)
+add        =     load(l.test_add,        [c_char_p, c_char_p  ], Res)
+sub        =     load(l.test_sub,        [c_char_p, c_char_p  ], Res)
+mul        =     load(l.test_mul,        [c_char_p, c_char_p  ], Res)
+sqr        =     load(l.test_sqr,        [c_char_p            ], Res)
+div        =     load(l.test_div,        [c_char_p, c_char_p  ], Res)
 
 
 # Powers and such
 # Powers and such
 int_log    =     load(l.test_log,        [c_char_p, c_longlong], Res)
 int_log    =     load(l.test_log,        [c_char_p, c_longlong], Res)
@@ -179,9 +179,11 @@ int_shl        = load(l.test_shl,        [c_char_p, c_longlong], Res)
 int_shr        = load(l.test_shr,        [c_char_p, c_longlong], Res)
 int_shr        = load(l.test_shr,        [c_char_p, c_longlong], Res)
 int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
 int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
 
 
-int_factorial  = load(l.test_factorial,  [c_uint64], Res)
-int_gcd        = load(l.test_gcd,        [c_char_p, c_char_p], Res)
-int_lcm        = load(l.test_lcm,        [c_char_p, c_char_p], Res)
+int_factorial  = load(l.test_factorial,  [c_uint64            ], Res)
+int_gcd        = load(l.test_gcd,        [c_char_p, c_char_p  ], Res)
+int_lcm        = load(l.test_lcm,        [c_char_p, c_char_p  ], Res)
+
+is_square      = load(l.test_is_square,  [c_char_p            ], Res)
 
 
 def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = "", radix=16):
 def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = "", radix=16):
 	passed = True
 	passed = True
@@ -428,6 +430,15 @@ def test_lcm(a = 0, b = 0, expected_error = Error.Okay):
 		
 		
 	return test("test_lcm", res, [a, b], expected_error, expected_result)
 	return test("test_lcm", res, [a, b], expected_error, expected_result)
 
 
+def test_is_square(a = 0, b = 0, expected_error = Error.Okay):
+	args  = [arg_to_odin(a)]
+	res   = is_square(*args)
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = str(math.isqrt(a) ** 2 == a) if a > 0 else "False"
+		
+	return test("test_is_square", res, [a], expected_error, expected_result)
+
 # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
 # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
 #
 #
 # The last two arguments in tests are the expected error and expected result.
 # The last two arguments in tests are the expected error and expected result.
@@ -527,6 +538,10 @@ TESTS = {
 		[   0, 0,  ],
 		[   0, 0,  ],
 		[   0, 125,],
 		[   0, 125,],
 	],
 	],
+	test_is_square: [
+		[ 12, ],
+		[ 92232459121502451677697058974826760244863271517919321608054113675118660929276431348516553336313179167211015633639725554914519355444316239500734169769447134357534241879421978647995614218985202290368055757891124109355450669008628757662409138767505519391883751112010824030579849970582074544353971308266211776494228299586414907715854328360867232691292422194412634523666770452490676515117702116926803826546868467146319938818238521874072436856528051486567230096290549225463582766830777324099589751817442141036031904145041055454639783559905920619197290800070679733841430619962318433709503256637256772215111521321630777950145713049902839937043785039344243357384899099910837463164007565230287809026956254332260375327814271845678201, ]
+	],
 }
 }
 
 
 if not args.fast_tests:
 if not args.fast_tests:
@@ -545,7 +560,7 @@ RANDOM_TESTS = [
 	test_add, test_sub, test_mul, test_sqr, test_div,
 	test_add, test_sub, test_mul, test_sqr, test_div,
 	test_log, test_pow, test_sqrt, test_root_n,
 	test_log, test_pow, test_sqrt, test_root_n,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
-	test_gcd, test_lcm,
+	test_gcd, test_lcm, test_is_square,
 ]
 ]
 SKIP_LARGE   = [
 SKIP_LARGE   = [
 	test_pow, test_root_n, # test_gcd,
 	test_pow, test_root_n, # test_gcd,
@@ -648,11 +663,13 @@ if __name__ == '__main__':
 					a = abs(a)
 					a = abs(a)
 					b = randint(0, 10);
 					b = randint(0, 10);
 				elif test_proc == test_shl:
 				elif test_proc == test_shl:
-					b = randint(0, min(BITS, 120));
+					b = randint(0, min(BITS, 120))
 				elif test_proc == test_shr_signed:
 				elif test_proc == test_shr_signed:
-					b = randint(0, min(BITS, 120));
+					b = randint(0, min(BITS, 120))
+				elif test_proc == test_is_square:
+					a = randint(0, 1 << BITS)
 				else:
 				else:
-					b = randint(0, 1 << BITS)					
+					b = randint(0, 1 << BITS)
 
 
 				res = None
 				res = None