浏览代码

big: Add `_private_int_sqr_comba`.

Jeroen van Rijn 4 年之前
父节点
当前提交
6c681b258c

+ 18 - 5
core/math/big/example.odin

@@ -52,10 +52,12 @@ FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS,
 }
 
 print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline := true, print_extra_info := false) {
-	as, err := itoa(a, base);
+	assert_if_nil(a);
 
+	as, err := itoa(a, base);
 	defer delete(as);
-	cb, _ := count_bits(a);
+
+	cb := internal_count_bits(a);
 	if print_name {
 		fmt.printf("%v", name);
 	}
@@ -64,7 +66,7 @@ print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline
 	}
 	fmt.printf("%v", as);
 	if print_extra_info {
-		fmt.printf(" (base: %v, bits used: %v, flags: %v)", base, cb, a.flags);
+		fmt.printf(" (base: %v, bits: %v (digits: %v), flags: %v)", base, cb, a.used, a.flags);
 	}
 	if newline {
 		fmt.println();
@@ -75,8 +77,19 @@ demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 
-	err := set(a, 1);
-	fmt.printf("err: %v\n", err);
+	err: Error;
+	bs: string;
+
+	if err = factorial(a, 500); err != nil { fmt.printf("factorial err: %v\n", err); return; }
+	{
+		SCOPED_TIMING(.sqr);
+		if err = sqr(b, a);     err != nil { fmt.printf("sqr err: %v\n", err); return; }
+	}
+
+	bs, err = itoa(b, 10);
+	defer delete(bs);
+
+	assert(bs[:50] == "14887338741396604108836218987068397819515734169330");
 }
 
 main :: proc() {

+ 0 - 2
core/math/big/helpers.odin

@@ -9,10 +9,8 @@ package big
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 */
 
-import "core:mem"
 import "core:intrinsics"
 import rnd "core:math/rand"
-import "core:fmt"
 
 /*
 	TODO: Int.flags and Constants like ONE, NAN, etc, are not yet properly handled everywhere.

+ 13 - 6
core/math/big/internal.odin

@@ -36,6 +36,8 @@ import "core:mem"
 import "core:intrinsics"
 import rnd "core:math/rand"
 
+//import "core:fmt"
+
 /*
 	Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
 
@@ -624,16 +626,21 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 		/*
 			Do we need to square?
 		*/
-		if        false && src.used >= SQR_TOOM_CUTOFF {
+		if src.used >= SQR_TOOM_CUTOFF {
 			/* Use Toom-Cook? */
 			// err = s_mp_sqr_toom(a, c);
-		} else if false && src.used >= SQR_KARATSUBA_CUTOFF {
+			// fmt.printf("_private_int_sqr_toom: %v\n", src.used);
+			err = #force_inline _private_int_sqr(dest, src);
+		} else if src.used >= SQR_KARATSUBA_CUTOFF {
 			/* Karatsuba? */
 			// err = s_mp_sqr_karatsuba(a, c);
-		} else if false && ((src.used * 2) + 1) < _WARRAY &&
-		                   src.used < (_MAX_COMBA / 2) {
-			/* Fast comba? */
-			// err = s_mp_sqr_comba(a, c);
+			// fmt.printf("_private_int_sqr_karatsuba: %v\n", src.used);
+			err = #force_inline _private_int_sqr(dest, src);
+		} else if ((src.used * 2) + 1) < _WARRAY && src.used < (_MAX_COMBA / 2) {
+			/*
+				Fast comba?
+			*/
+			err = #force_inline _private_int_sqr_comba(dest, src);
 		} else {
 			err = #force_inline _private_int_sqr(dest, src);
 		}

+ 0 - 2
core/math/big/logical.odin

@@ -11,8 +11,6 @@ package big
 	This file contains logical operations like `and`, `or` and `xor`.
 */
 
-import "core:mem"
-
 /*
 	The `and`, `or` and `xor` binops differ in two lines only.
 	We could handle those with a switch, but that adds overhead.

+ 99 - 0
core/math/big/private.odin

@@ -255,6 +255,105 @@ _private_int_sqr :: proc(dest, src: ^Int, allocator := context.allocator) -> (er
 	return err;
 }
 
+/*
+	Comba-style squaring: computes `src * src` into `dest`, one output column at a time.
+
+	The gist of squaring...
+	It works like the multiplication, except that the offset `tx` [the one that starts closer to zero] can't equal the offset `ty`.
+	So basically you set up `iy` like before, then you min it with half the distance (ty-tx) so that it never happens.
+	You double all the cross products you add in the inner loop. After that loop you add the square term in.
+
+	The column results are collected in the local array `W` first and only copied into `dest` afterwards.
+
+	Assumes `dest` and `src` not to be `nil` and `src` to have been initialized.
+
+	NOTE(review): `W` holds `_WARRAY` digits and `2 * src.used` of them are written without bounds checks;
+	callers are presumably expected to guarantee that this fits - confirm at the call sites.
+*/
+_private_int_sqr_comba :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	W: [_WARRAY]DIGIT = ---;
+
+	/*
+		Grow the destination as required.
+		`pa` is the number of output digits to produce.
+	*/
+	pa := uint(src.used) + uint(src.used);
+	if err = internal_grow(dest, int(pa)); err != nil { return err; }
+
+	/*
+		`W1` is the carry passed from one column to the next,
+		`_W` is the accumulator for the current column.
+	*/
+	W1 := _WORD(0);
+	_W  : _WORD = ---;
+	ix := uint(0);
+
+	#no_bounds_check for ; ix < pa; ix += 1 {
+		/*
+			Clear counter.
+		*/
+		_W = {};
+
+		/*
+			Get offsets into the two bignums.
+		*/
+		ty := min(uint(src.used) - 1, ix);
+		tx := ix - ty;
+
+		/*
+			This is the number of times the loop will iterate,
+			essentially while (tx++ < a->used && ty-- >= 0) { ... }
+		*/
+		iy := min(uint(src.used) - tx, ty + 1);
+
+		/*
+			Now for squaring, tx can never equal ty.
+			We halve the distance since they approach at a rate of 2x,
+			and we have to round because odd cases need to be executed.
+		*/
+		iy = min(iy, ((ty - tx) + 1) >> 1 );
+
+		/*
+			Execute loop.
+		*/
+		#no_bounds_check for iz := uint(0); iz < iy; iz += 1 {
+			_W += _WORD(src.digit[tx + iz]) * _WORD(src.digit[ty - iz]);
+		}
+
+		/*
+			Double the inner product and add carry.
+		*/
+		_W = _W + _W + W1;
+
+		/*
+			Even columns have the square term in them.
+		*/
+		if ix & 1 == 0 {
+			_W += _WORD(src.digit[ix >> 1]) * _WORD(src.digit[ix >> 1]);
+		}
+
+		/*
+			Store it.
+		*/
+		W[ix] = DIGIT(_W & _WORD(_MASK));
+
+		/*
+			Make next carry.
+		*/
+		W1 = _W >> _DIGIT_BITS;
+	}
+
+	/*
+		Setup dest.
+		The digits are copied out of `W` only now, after all columns have been computed from `src`.
+	*/
+	old_used := dest.used;
+	dest.used = src.used + src.used;
+
+	#no_bounds_check for ix = 0; ix < pa; ix += 1 {
+		dest.digit[ix] = W[ix] & _MASK;
+	}
+
+	/*
+		Clear unused digits [that existed in the old copy of dest].
+	*/
+	internal_zero_unused(dest, old_used);
+
+	return internal_clamp(dest);
+}
+
 /*
 	Divide by three (based on routine from MPI and the GMP manual).
 */

+ 4 - 0
core/math/big/radix.odin

@@ -9,6 +9,10 @@ package big
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 
 	This file contains radix conversions, `string_to_int` (atoi) and `int_to_string` (itoa).
+
+	TODO:
+		- Use Barrett reduction for non-powers-of-two.
+		- Also look at extracting and splatting several digits at once.
 */
 
 import "core:intrinsics"

+ 16 - 0
core/math/big/test.odin

@@ -94,6 +94,22 @@ PyRes :: struct {
 	return PyRes{res = r, err = nil};
 }
 
+/*
+	Exported with the "c" calling convention for the test harness.
+	Parses `a` as a base-16 number, squares it with `internal_sqr` and returns the result as a base-16 string.
+	On failure, `res` names the step that failed and `err` carries that step's error.
+	The result string is requested from `context.temp_allocator`.
+*/
+@export test_sqr :: proc "c" (a: cstring) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+
+	aa, square := &Int{}, &Int{};
+	defer internal_destroy(aa, square);
+
+	if err = atoi(aa, string(a), 16); err != nil { return PyRes{res=":sqr:atoi(a):", err=err}; }
+	if err = #force_inline internal_sqr(square, aa);        err != nil { return PyRes{res=":sqr:sqr(square,a):", err=err}; }
+
+	r: cstring;
+	r, err = int_itoa_cstring(square, 16, context.temp_allocator);
+	if err != nil { return PyRes{res=":sqr:itoa(square):", err=err}; }
+	return PyRes{res = r, err = nil};
+}
+
 /*
 	NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
 */

+ 29 - 9
core/math/big/test.py

@@ -124,15 +124,16 @@ initialize_constants()
 
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 
-add  = load(l.test_add, [c_char_p, c_char_p], Res)
-sub  = load(l.test_sub, [c_char_p, c_char_p], Res)
-mul  = load(l.test_mul, [c_char_p, c_char_p], Res)
-div  = load(l.test_div, [c_char_p, c_char_p], Res)
+add        = load(l.test_add,    [c_char_p, c_char_p],   Res)
+sub        = load(l.test_sub,    [c_char_p, c_char_p],   Res)
+mul        = load(l.test_mul,    [c_char_p, c_char_p],   Res)
+sqr        = load(l.test_sqr,    [c_char_p          ],   Res)
+div        = load(l.test_div,    [c_char_p, c_char_p],   Res)
 
 # Powers and such
-int_log    = load(l.test_log,  [c_char_p, c_longlong], Res)
-int_pow    = load(l.test_pow,  [c_char_p, c_longlong], Res)
-int_sqrt   = load(l.test_sqrt, [c_char_p], Res)
+int_log    = load(l.test_log,    [c_char_p, c_longlong], Res)
+int_pow    = load(l.test_pow,    [c_char_p, c_longlong], Res)
+int_sqrt   = load(l.test_sqrt,   [c_char_p            ], Res)
 int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
 
 # Logical operations
@@ -218,6 +219,20 @@ def test_mul(a = 0, b = 0, expected_error = Error.Okay):
 		expected_result = a * b
 	return test("test_mul", res, [a, b], expected_error, expected_result)
 
+def test_sqr(a = 0, b = 0, expected_error = Error.Okay):
+	# Squares `a` through the library's `test_sqr` export and checks the result against `a * a`.
+	#
+	# `b` is accepted but ignored: squaring is unary, and the parameter only exists so this
+	# function can be called with the same `(a, b)` shape as the binary tests.
+	# Returns whatever `test(...)` returns, or False if the foreign call raised an OSError.
+	args = [arg_to_odin(a)]
+	try:
+		res  = sqr(*args)
+	except OSError as e:
+		# One placeholder per argument: the previous message had three `{}` for two
+		# arguments, so reporting the failure itself raised an IndexError.
+		print("{} while trying to square {}.".format(e, a))
+		if EXIT_ON_FAIL: exit(3)
+		return False
+
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = a * a
+	return test("test_sqr", res, [a], expected_error, expected_result)
+
 def test_div(a = 0, b = 0, expected_error = Error.Okay):
 	args = [arg_to_odin(a), arg_to_odin(b)]
 	res  = div(*args)
@@ -390,7 +405,11 @@ TESTS = {
 	],
 	test_mul: [
 		[ 1234,   5432],
-		[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7],
+		[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7 ],
+	],
+	test_sqr: [
+		[ 5432],
+		[ 0xd3b4e926aaba3040e1c12b5ea553b5 ],
 	],
 	test_div: [
 		[ 54321,	12345],
@@ -482,7 +501,7 @@ total_failures = 0
 # test_shr_signed also tests shr, so we're not going to test shr randomly.
 #
 RANDOM_TESTS = [
-	test_add, test_sub, test_mul, test_div,
+	test_add, test_sub, test_mul, test_sqr, test_div,
 	test_log, test_pow, test_sqrt, test_root_n,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
 	test_gcd, test_lcm,
@@ -592,6 +611,7 @@ if __name__ == '__main__':
 				start = time.perf_counter()
 				res   = test_proc(a, b)
 				diff  = time.perf_counter() - start
+
 				TOTAL_TIME += diff
 
 				if test_proc not in TIMINGS:

+ 1 - 0
core/math/big/tune.odin

@@ -21,6 +21,7 @@ Category :: enum {
 	choose,
 	lsb,
 	ctz,
+	sqr,
 	bitfield_extract,
 };