Browse Source

big: Add `_private_int_sqr_comba`.

Jeroen van Rijn 4 years ago
parent
commit
6c681b258c

+ 18 - 5
core/math/big/example.odin

@@ -52,10 +52,12 @@ FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS,
 }
 }
 
 
 print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline := true, print_extra_info := false) {
 print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline := true, print_extra_info := false) {
-	as, err := itoa(a, base);
+	assert_if_nil(a);
 
 
+	as, err := itoa(a, base);
 	defer delete(as);
 	defer delete(as);
-	cb, _ := count_bits(a);
+
+	cb := internal_count_bits(a);
 	if print_name {
 	if print_name {
 		fmt.printf("%v", name);
 		fmt.printf("%v", name);
 	}
 	}
@@ -64,7 +66,7 @@ print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline
 	}
 	}
 	fmt.printf("%v", as);
 	fmt.printf("%v", as);
 	if print_extra_info {
 	if print_extra_info {
-		fmt.printf(" (base: %v, bits used: %v, flags: %v)", base, cb, a.flags);
+		fmt.printf(" (base: %v, bits: %v (digits: %v), flags: %v)", base, cb, a.used, a.flags);
 	}
 	}
 	if newline {
 	if newline {
 		fmt.println();
 		fmt.println();
@@ -75,8 +77,19 @@ demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 	defer destroy(a, b, c, d, e, f);
 
 
-	err := set(a, 1);
-	fmt.printf("err: %v\n", err);
+	err: Error;
+	bs: string;
+
+	if err = factorial(a, 500); err != nil { fmt.printf("factorial err: %v\n", err); return; }
+	{
+		SCOPED_TIMING(.sqr);
+		if err = sqr(b, a);     err != nil { fmt.printf("sqr err: %v\n", err); return; }
+	}
+
+	bs, err = itoa(b, 10);
+	defer delete(bs);
+
+	assert(bs[:50] == "14887338741396604108836218987068397819515734169330");
 }
 }
 
 
 main :: proc() {
 main :: proc() {

+ 0 - 2
core/math/big/helpers.odin

@@ -9,10 +9,8 @@ package big
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 */
 */
 
 
-import "core:mem"
 import "core:intrinsics"
 import "core:intrinsics"
 import rnd "core:math/rand"
 import rnd "core:math/rand"
-import "core:fmt"
 
 
 /*
 /*
 	TODO: Int.flags and Constants like ONE, NAN, etc, are not yet properly handled everywhere.
 	TODO: Int.flags and Constants like ONE, NAN, etc, are not yet properly handled everywhere.

+ 13 - 6
core/math/big/internal.odin

@@ -36,6 +36,8 @@ import "core:mem"
 import "core:intrinsics"
 import "core:intrinsics"
 import rnd "core:math/rand"
 import rnd "core:math/rand"
 
 
+//import "core:fmt"
+
 /*
 /*
 	Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
 	Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
 
 
@@ -624,16 +626,21 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 		/*
 		/*
 			Do we need to square?
 			Do we need to square?
 		*/
 		*/
-		if        false && src.used >= SQR_TOOM_CUTOFF {
+		if src.used >= SQR_TOOM_CUTOFF {
 			/* Use Toom-Cook? */
 			/* Use Toom-Cook? */
 			// err = s_mp_sqr_toom(a, c);
 			// err = s_mp_sqr_toom(a, c);
-		} else if false && src.used >= SQR_KARATSUBA_CUTOFF {
+			// fmt.printf("_private_int_sqr_toom: %v\n", src.used);
+			err = #force_inline _private_int_sqr(dest, src);
+		} else if src.used >= SQR_KARATSUBA_CUTOFF {
 			/* Karatsuba? */
 			/* Karatsuba? */
 			// err = s_mp_sqr_karatsuba(a, c);
 			// err = s_mp_sqr_karatsuba(a, c);
-		} else if false && ((src.used * 2) + 1) < _WARRAY &&
-		                   src.used < (_MAX_COMBA / 2) {
-			/* Fast comba? */
-			// err = s_mp_sqr_comba(a, c);
+			// fmt.printf("_private_int_sqr_karatsuba: %v\n", src.used);
+			err = #force_inline _private_int_sqr(dest, src);
+		} else if ((src.used * 2) + 1) < _WARRAY && src.used < (_MAX_COMBA / 2) {
+			/*
+				Fast comba?
+			*/
+			err = #force_inline _private_int_sqr_comba(dest, src);
 		} else {
 		} else {
 			err = #force_inline _private_int_sqr(dest, src);
 			err = #force_inline _private_int_sqr(dest, src);
 		}
 		}

+ 0 - 2
core/math/big/logical.odin

@@ -11,8 +11,6 @@ package big
 	This file contains logical operations like `and`, `or` and `xor`.
 	This file contains logical operations like `and`, `or` and `xor`.
 */
 */
 
 
-import "core:mem"
-
 /*
 /*
 	The `and`, `or` and `xor` binops differ in two lines only.
 	The `and`, `or` and `xor` binops differ in two lines only.
 	We could handle those with a switch, but that adds overhead.
 	We could handle those with a switch, but that adds overhead.

+ 99 - 0
core/math/big/private.odin

@@ -255,6 +255,105 @@ _private_int_sqr :: proc(dest, src: ^Int, allocator := context.allocator) -> (er
 	return err;
 	return err;
 }
 }
 
 
+/*
+	The jist of squaring...
+	You do like mult except the offset of the tmpx [one that starts closer to zero] can't equal the offset of tmpy.
+	So basically you set up iy like before then you min it with (ty-tx) so that it never happens.
+	You double all those you add in the inner loop. After that loop you do the squares and add them in.
+
+	Assumes `dest` and `src` not to be `nil` and `src` to have been initialized.	
+*/
+_private_int_sqr_comba :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	W: [_WARRAY]DIGIT = ---;
+
+	/*
+		Grow the destination as required.
+	*/
+	pa := uint(src.used) + uint(src.used);
+	if err = internal_grow(dest, int(pa)); err != nil { return err; }
+
+	/*
+		Number of output digits to produce.
+	*/
+	W1 := _WORD(0);
+	_W  : _WORD = ---;
+	ix := uint(0);
+
+	#no_bounds_check for ; ix < pa; ix += 1 {
+		/*
+			Clear counter.
+		*/
+		_W = {};
+
+		/*
+			Get offsets into the two bignums.
+		*/
+		ty := min(uint(src.used) - 1, ix);
+		tx := ix - ty;
+
+		/*
+			This is the number of times the loop will iterate,
+			essentially while (tx++ < a->used && ty-- >= 0) { ... }
+		*/
+		iy := min(uint(src.used) - tx, ty + 1);
+
+		/*
+			Now for squaring, tx can never equal ty.
+			We halve the distance since they approach at a rate of 2x,
+			and we have to round because odd cases need to be executed.
+		*/
+		iy = min(iy, ((ty - tx) + 1) >> 1 );
+
+		/*
+			Execute loop.
+		*/
+		#no_bounds_check for iz := uint(0); iz < iy; iz += 1 {
+			_W += _WORD(src.digit[tx + iz]) * _WORD(src.digit[ty - iz]);
+		}
+
+		/*
+			Double the inner product and add carry.
+		*/
+		_W = _W + _W + W1;
+
+		/*
+			Even columns have the square term in them.
+		*/
+		if ix & 1 == 0 {
+			_W += _WORD(src.digit[ix >> 1]) * _WORD(src.digit[ix >> 1]);
+		}
+
+		/*
+			Store it.
+		*/
+		W[ix] = DIGIT(_W & _WORD(_MASK));
+
+		/*
+			Make next carry.
+		*/
+		W1 = _W >> _DIGIT_BITS;
+	}
+
+	/*
+		Setup dest.
+	*/
+	old_used := dest.used;
+	dest.used = src.used + src.used;
+
+	#no_bounds_check for ix = 0; ix < pa; ix += 1 {
+		dest.digit[ix] = W[ix] & _MASK;
+	}
+
+	/*
+		Clear unused digits [that existed in the old copy of dest].
+	*/
+	internal_zero_unused(dest, old_used);
+
+	return internal_clamp(dest);
+}
+
 /*
 /*
 	Divide by three (based on routine from MPI and the GMP manual).
 	Divide by three (based on routine from MPI and the GMP manual).
 */
 */

+ 4 - 0
core/math/big/radix.odin

@@ -9,6 +9,10 @@ package big
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 
 
 	This file contains radix conversions, `string_to_int` (atoi) and `int_to_string` (itoa).
 	This file contains radix conversions, `string_to_int` (atoi) and `int_to_string` (itoa).
+
+	TODO:
+		- Use Barrett reduction for non-powers-of-two.
+		- Also look at extracting and splatting several digits at once.
 */
 */
 
 
 import "core:intrinsics"
 import "core:intrinsics"

+ 16 - 0
core/math/big/test.odin

@@ -94,6 +94,22 @@ PyRes :: struct {
 	return PyRes{res = r, err = nil};
 	return PyRes{res = r, err = nil};
 }
 }
 
 
+@export test_sqr :: proc "c" (a: cstring) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	aa, square := &Int{}, &Int{};
+	defer internal_destroy(aa, square);
+
+	if err = atoi(aa, string(a), 16); err != nil { return PyRes{res=":sqr:atoi(a):", err=err}; }
+	if err = #force_inline internal_sqr(square, aa);        err != nil { return PyRes{res=":sqr:sqr(square,a):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(square, 16, context.temp_allocator);
+	if err != nil { return PyRes{res=":sqr:itoa(square):", err=err}; }
+	return PyRes{res = r, err = nil};
+}
+
 /*
 /*
 	NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
 	NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
 */
 */

+ 29 - 9
core/math/big/test.py

@@ -124,15 +124,16 @@ initialize_constants()
 
 
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 
 
-add  = load(l.test_add, [c_char_p, c_char_p], Res)
-sub  = load(l.test_sub, [c_char_p, c_char_p], Res)
-mul  = load(l.test_mul, [c_char_p, c_char_p], Res)
-div  = load(l.test_div, [c_char_p, c_char_p], Res)
+add        = load(l.test_add,    [c_char_p, c_char_p],   Res)
+sub        = load(l.test_sub,    [c_char_p, c_char_p],   Res)
+mul        = load(l.test_mul,    [c_char_p, c_char_p],   Res)
+sqr        = load(l.test_sqr,    [c_char_p          ],   Res)
+div        = load(l.test_div,    [c_char_p, c_char_p],   Res)
 
 
 # Powers and such
 # Powers and such
-int_log    = load(l.test_log,  [c_char_p, c_longlong], Res)
-int_pow    = load(l.test_pow,  [c_char_p, c_longlong], Res)
-int_sqrt   = load(l.test_sqrt, [c_char_p], Res)
+int_log    = load(l.test_log,    [c_char_p, c_longlong], Res)
+int_pow    = load(l.test_pow,    [c_char_p, c_longlong], Res)
+int_sqrt   = load(l.test_sqrt,   [c_char_p            ], Res)
 int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
 int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
 
 
 # Logical operations
 # Logical operations
@@ -218,6 +219,20 @@ def test_mul(a = 0, b = 0, expected_error = Error.Okay):
 		expected_result = a * b
 		expected_result = a * b
 	return test("test_mul", res, [a, b], expected_error, expected_result)
 	return test("test_mul", res, [a, b], expected_error, expected_result)
 
 
+def test_sqr(a = 0, b = 0, expected_error = Error.Okay):
+	args = [arg_to_odin(a)]
+	try:
+		res  = sqr(*args)
+	except OSError as e:
+		print("{} while trying to square {} x {}.".format(e, a))
+		if EXIT_ON_FAIL: exit(3)
+		return False
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = a * a
+	return test("test_sqr", res, [a], expected_error, expected_result)
+
 def test_div(a = 0, b = 0, expected_error = Error.Okay):
 def test_div(a = 0, b = 0, expected_error = Error.Okay):
 	args = [arg_to_odin(a), arg_to_odin(b)]
 	args = [arg_to_odin(a), arg_to_odin(b)]
 	res  = div(*args)
 	res  = div(*args)
@@ -390,7 +405,11 @@ TESTS = {
 	],
 	],
 	test_mul: [
 	test_mul: [
 		[ 1234,   5432],
 		[ 1234,   5432],
-		[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7],
+		[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7 ],
+	],
+	test_sqr: [
+		[ 5432],
+		[ 0xd3b4e926aaba3040e1c12b5ea553b5 ],
 	],
 	],
 	test_div: [
 	test_div: [
 		[ 54321,	12345],
 		[ 54321,	12345],
@@ -482,7 +501,7 @@ total_failures = 0
 # test_shr_signed also tests shr, so we're not going to test shr randomly.
 # test_shr_signed also tests shr, so we're not going to test shr randomly.
 #
 #
 RANDOM_TESTS = [
 RANDOM_TESTS = [
-	test_add, test_sub, test_mul, test_div,
+	test_add, test_sub, test_mul, test_sqr, test_div,
 	test_log, test_pow, test_sqrt, test_root_n,
 	test_log, test_pow, test_sqrt, test_root_n,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
 	test_gcd, test_lcm,
 	test_gcd, test_lcm,
@@ -592,6 +611,7 @@ if __name__ == '__main__':
 				start = time.perf_counter()
 				start = time.perf_counter()
 				res   = test_proc(a, b)
 				res   = test_proc(a, b)
 				diff  = time.perf_counter() - start
 				diff  = time.perf_counter() - start
+
 				TOTAL_TIME += diff
 				TOTAL_TIME += diff
 
 
 				if test_proc not in TIMINGS:
 				if test_proc not in TIMINGS:

+ 1 - 0
core/math/big/tune.odin

@@ -21,6 +21,7 @@ Category :: enum {
 	choose,
 	choose,
 	lsb,
 	lsb,
 	ctz,
 	ctz,
+	sqr,
 	bitfield_extract,
 	bitfield_extract,
 };
 };