浏览代码

big: Add `_private_mul_karatsuba`.

Jeroen van Rijn 4 年之前
父节点
当前提交
8b49bbb0fc
共有 5 个文件被更改,包括 116 次插入26 次删除
  1. 2 2
      core/math/big/build.bat
  2. 4 8
      core/math/big/example.odin
  3. 2 4
      core/math/big/helpers.odin
  4. 4 10
      core/math/big/internal.odin
  5. 104 2
      core/math/big/private.odin

+ 2 - 2
core/math/big/build.bat

@@ -1,8 +1,8 @@
 @echo off
-:odin run . -vet
+odin run . -vet
 : -o:size
 :odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
 :odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
 :odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py -fast-tests
-odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
+:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
 :odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests

+ 4 - 8
core/math/big/example.odin

@@ -206,16 +206,12 @@ demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 
-	atoi(a, "12980742146337069150589594264770969721", 10);
+	power_of_two(a, 312);
 	print("a: ", a, 10, true, true, true);
-	atoi(b, "4611686018427387904", 10);
+	power_of_two(b, 314);
 	print("b: ", b, 10, true, true, true);
-
-	if err := internal_divmod(c, d, a, b); err != nil {
-		fmt.printf("Error: %v\n", err);
-	}
-	print("c: ", c);
-	print("c: ", d);
+	_private_mul_karatsuba(c, a, b);
+	print("c: ", c, 10, true, true, true);
 }
 
 main :: proc() {

+ 2 - 4
core/math/big/helpers.odin

@@ -432,18 +432,16 @@ int_init_multi :: proc(integers: ..^Int, allocator := context.allocator) -> (err
 
 init_multi :: proc { int_init_multi, };
 
-copy_digits :: proc(dest, src: ^Int, digits: int, allocator := context.allocator) -> (err: Error) {
+copy_digits :: proc(dest, src: ^Int, digits: int, offset := int(0), allocator := context.allocator) -> (err: Error) {
 	context.allocator = allocator;
 
-	digits := digits;
 	/*
 		Check that `src` is usable and `dest` isn't immutable.
 	*/
 	assert_if_nil(dest, src);
 	#force_inline internal_clear_if_uninitialized(src) or_return;
 
-	digits = min(digits, len(src.digit), len(dest.digit));
-	return #force_inline internal_copy_digits(dest, src, digits);
+	return #force_inline internal_copy_digits(dest, src, digits, offset);
 }
 
 /*

+ 4 - 10
core/math/big/internal.odin

@@ -36,8 +36,6 @@ import "core:mem"
 import "core:intrinsics"
 import rnd "core:math/rand"
 
-import "core:fmt"
-
 /*
 	Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
 
@@ -651,7 +649,6 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 				Fast comba?
 			*/
 			err = #force_inline _private_int_sqr_comba(dest, src);
-			//err = #force_inline _private_int_sqr(dest, src);
 		} else {
 			err = #force_inline _private_int_sqr(dest, src);
 		}
@@ -679,8 +676,8 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 			// err = s_mp_mul_balance(a,b,c);
 		} else if false && min_used >= MUL_TOOM_CUTOFF {
 			// err = s_mp_mul_toom(a, b, c);
-		} else if false && min_used >= MUL_KARATSUBA_CUTOFF {
-			// err = s_mp_mul_karatsuba(a, b, c);
+		} else if min_used >= MUL_KARATSUBA_CUTOFF {
+			err = #force_inline _private_mul_karatsuba(dest, src, multiplier);
 		} else if digits < _WARRAY && min_used <= _MAX_COMBA {
 			/*
 				Can we use the fast multiplier?
@@ -1628,16 +1625,13 @@ internal_int_set_from_integer :: proc(dest: ^Int, src: $T, minimize := false, al
 
 internal_set :: proc { internal_int_set_from_integer, internal_int_copy };
 
-internal_copy_digits :: #force_inline proc(dest, src: ^Int, digits: int) -> (err: Error) {
+internal_copy_digits :: #force_inline proc(dest, src: ^Int, digits: int, offset := int(0)) -> (err: Error) {
 	#force_inline internal_error_if_immutable(dest) or_return;
 
 	/*
 		If dest == src, do nothing
 	*/
-	if (dest == src) { return nil; }
-
-	#force_inline mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits);
-	return nil;
+	return #force_inline _private_copy_digits(dest, src, digits, offset);
 }
 
 /*

+ 104 - 2
core/math/big/private.odin

@@ -89,6 +89,108 @@ _private_int_mul :: proc(dest, a, b: ^Int, digits: int, allocator := context.all
 	return internal_clamp(dest);
 }
 
+/*
+	product = |a| * |b| using Karatsuba Multiplication using three half size multiplications.
+
+	Let `B` represent the radix [e.g. 2**_DIGIT_BITS] and let `n` represent
+	half of the number of digits in the min(a,b)
+
+	`a` = `a1` * `B`**`n` + `a0`
+	`b` = `b`1 * `B`**`n` + `b0`
+
+	Then, a * b => 1b1 * B**2n + ((a1 + a0)(b1 + b0) - (a0b0 + a1b1)) * B + a0b0
+
+	Note that a1b1 and a0b0 are used twice and only need to be computed once.
+	So in total three half size (half # of digit) multiplications are performed,
+		a0b0, a1b1 and (a1+b1)(a0+b0)
+
+	Note that a multiplication of half the digits requires 1/4th the number of
+	single precision multiplications, so in total after one call 25% of the
+	single precision multiplications are saved.
+
+	Note also that the call to `internal_mul` can end up back in this function
+	if the a0, a1, b0, or b1 are above the threshold.
+
+	This is known as divide-and-conquer and leads to the famous O(N**lg(3)) or O(N**1.584)
+	work which is asymptopically lower than the standard O(N**2) that the
+	baseline/comba methods use. Generally though, the overhead of this method doesn't pay off
+	until a certain size is reached, of around 80 used DIGITs.
+*/
+_private_mul_karatsuba :: proc(dest, a, b: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	x0, x1, y0, y1, t1, x0y0, x1y1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
+	defer destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
+
+	/*
+		min # of digits, divided by two.
+	*/
+	B := min(a.used, b.used) >> 1;
+
+	/*
+		Init all the temps.
+	*/
+	internal_grow(x0, B)          or_return;
+	internal_grow(x1, a.used - B) or_return;
+	internal_grow(y0, B)          or_return;
+	internal_grow(y1, b.used - B) or_return;
+	internal_grow(t1, B * 2)      or_return;
+	internal_grow(x0y0, B * 2)    or_return;
+	internal_grow(x1y1, B * 2)    or_return;
+
+	/*
+		Now shift the digits.
+	*/
+	x0.used, y0.used = B, B;
+	x1.used = a.used - B;
+	y1.used = b.used - B;
+
+	/*
+		We copy the digits directly instead of using higher level functions
+		since we also need to shift the digits.
+	*/
+	internal_copy_digits(x0, a, x0.used);
+	internal_copy_digits(y0, b, y0.used);
+	internal_copy_digits(x1, a, x1.used, B);
+	internal_copy_digits(y1, b, y1.used, B);
+
+	/*
+		Only need to clamp the lower words since by definition the
+		upper words x1/y1 must have a known number of digits.
+	*/
+	clamp(x0);
+	clamp(y0);
+
+	/*
+		Now calc the products x0y0 and x1y1,
+		after this x0 is no longer required, free temp [x0==t2]!
+	*/
+	internal_mul(x0y0, x0, y0)      or_return; /* x0y0 = x0*y0 */
+	internal_mul(x1y1, x1, y1)      or_return; /* x1y1 = x1*y1 */
+	internal_add(t1,   x1, x0)      or_return; /* now calc x1+x0 and */
+	internal_add(x0,   y1, y0)      or_return; /* t2 = y1 + y0 */
+	internal_mul(t1,   t1, x0)      or_return; /* t1 = (x1 + x0) * (y1 + y0) */
+
+	/*
+		Add x0y0.
+	*/
+	internal_add(x0, x0y0, x1y1)    or_return; /* t2 = x0y0 + x1y1 */
+	internal_sub(t1,   t1,   x0)    or_return; /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */
+
+	/*
+		shift by B.
+	*/
+	internal_shl_digit(t1, B)       or_return; /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
+	internal_shl_digit(x1y1, B * 2) or_return; /* x1y1 = x1y1 << 2*B */
+
+	internal_add(t1, x0y0, t1)      or_return; /* t1 = x0y0 + t1 */
+	internal_add(dest, t1, x1y1)    or_return; /* t1 = x0y0 + t1 + x1y1 */
+
+	return nil;
+}
+
+
+
 /*
 	Fast (comba) multiplier
 
@@ -1629,7 +1731,7 @@ _private_log_power_of_two :: proc(a: ^Int, base: DIGIT) -> (log: int, err: Error
 	Copies DIGITs from `src` to `dest`.
 	Assumes `src` and `dest` to not be `nil` and have been initialized.
 */
-_private_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) {
+_private_copy_digits :: proc(dest, src: ^Int, digits: int, offset := int(0)) -> (err: Error) {
 	digits := digits;
 	/*
 		If dest == src, do nothing
@@ -1639,7 +1741,7 @@ _private_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) {
 	}
 
 	digits = min(digits, len(src.digit), len(dest.digit));
-	mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits);
+	mem.copy_non_overlapping(&dest.digit[0], &src.digit[offset], size_of(DIGIT) * digits);
 	return nil;
 }