Browse Source

big: Add `_private_int_sqr_toom`.

Jeroen van Rijn 4 years ago
parent
commit
5f34ff9f9f
4 changed files with 125 additions and 21 deletions
  1. 3 0
      core/math/big/common.odin
  2. 0 4
      core/math/big/example.odin
  3. 1 4
      core/math/big/internal.odin
  4. 121 13
      core/math/big/private.odin

+ 3 - 0
core/math/big/common.odin

@@ -40,6 +40,9 @@ SQR_TOOM_CUTOFF      := _DEFAULT_SQR_TOOM_CUTOFF;
 	It would also be cool if we collected some data across various processor families.
 	This would let us set reasonable defaults at runtime as this library initializes
 	itself by using `cpuid` or the ARM equivalent.
+
+	IMPORTANT: The 32_BIT path has largely gone untested. It needs to be tested and
+	debugged where necessary.
 */
 
 _DEFAULT_MUL_KARATSUBA_CUTOFF :: #config(MUL_KARATSUBA_CUTOFF,  80);

File diff suppressed because it is too large
+ 0 - 4
core/math/big/example.odin


+ 1 - 4
core/math/big/internal.odin

@@ -36,8 +36,6 @@ import "core:mem"
 import "core:intrinsics"
 import rnd "core:math/rand"
 
-// import "core:fmt"
-
 /*
 	Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
 
@@ -630,8 +628,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 			/*
 				Use Toom-Cook?
 			*/
-			// fmt.printf("_private_int_sqr_toom: %v\n", src.used);
-			err = #force_inline _private_int_sqr_karatsuba(dest, src);
+			err = #force_inline _private_int_sqr_toom(dest, src);
 		} else if src.used >= SQR_KARATSUBA_CUTOFF {
 			/*
 				Karatsuba?

+ 121 - 13
core/math/big/private.odin

@@ -387,37 +387,145 @@ _private_int_sqr_karatsuba :: proc(dest, src: ^Int, allocator := context.allocat
 	x0.used = B;
 	x1.used = src.used - B;
 
-	internal_copy_digits(x0, src, x0.used);
+	#force_inline internal_copy_digits(x0, src, x0.used);
 	#force_inline mem.copy_non_overlapping(&x1.digit[0], &src.digit[B], size_of(DIGIT) * x1.used);
-	internal_clamp(x0);
+	#force_inline internal_clamp(x0);
 
 	/*
 		Now calc the products x0*x0 and x1*x1.
 	*/
-	if err = internal_sqr(x0x0, x0);          err != nil { return err; }
-	if err = internal_sqr(x1x1, x1);          err != nil { return err; }
+	if err = internal_sqr(x0x0, x0);                  err != nil { return err; }
+	if err = internal_sqr(x1x1, x1);                  err != nil { return err; }
 
 	/*
 		Now calc (x1+x0)^2
 	*/
-	if err = internal_add(t1, x0, x1);        err != nil { return err; }
-	if err = internal_sqr(t1, t1);            err != nil { return err; }
+	if err = internal_add(t1, x0, x1);                err != nil { return err; }
+	if err = internal_sqr(t1, t1);                    err != nil { return err; }
 
 	/*
 		Add x0y0
 	*/
-	if err = internal_add(t2, x0x0, x1x1);    err != nil { return err; }
-	if err = internal_sub(t1, t1, t2);        err != nil { return err; }
+	if err = internal_add(t2, x0x0, x1x1);            err != nil { return err; }
+	if err = internal_sub(t1, t1, t2);                err != nil { return err; }
 
 	/*
 		Shift by B.
 	*/
-	if err = internal_shl_digit(t1, B);       err != nil { return err; }
-	if err = internal_shl_digit(x1x1, B * 2); err != nil { return err; }
-	if err = internal_add(t1, t1, x0x0);      err != nil { return err; }
-	if err = internal_add(dest, t1, x1x1);    err != nil { return err; }
+	if err = internal_shl_digit(t1, B);               err != nil { return err; }
+	if err = internal_shl_digit(x1x1, B * 2);         err != nil { return err; }
+	if err = internal_add(t1, t1, x0x0);              err != nil { return err; }
+	if err = internal_add(dest, t1, x1x1);            err != nil { return err; }
 
-	return internal_clamp(dest);
+	return #force_inline internal_clamp(dest);
+}
+
+/*
+	Squaring using Toom-Cook 3-way algorithm.
+
+	Computes `dest = src * src` by splitting `src` into three parts a0, a1 and a2 (low to high),
+	forming five intermediate values from them, and recombining those values with digit shifts.
+
+	Parameters:
+		dest:      receives the square. It also serves as the working value the step comments call `b`.
+		src:       the value to square. It is not read again after its digits are copied into a0, a1 and a2.
+		allocator: installed as `context.allocator` for the duration of this procedure.
+
+	Returns the first non-nil error reported by any of the checked internal calls; otherwise the
+	result of the final `internal_clamp(dest)`.
+
+	The temporaries S0, a0, a1 and a2 are released by the deferred `destroy` on every return path.
+
+	NOTE(review): in the step comments below, lines starting with `\\` appear to be the step as named
+	in the reference (S0 through S4, tmp), and the line after it is the in-place form actually
+	executed here. Confirm against the paper.
+
+	Setup and interpolation from algorithm SQR_3 in Chung, Jaewook, and M. Anwar Hasan. "Asymmetric squaring formulae."
+	  18th IEEE Symposium on Computer Arithmetic (ARITH'07). IEEE, 2007.
+*/
+_private_int_sqr_toom :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	S0, a0, a1, a2 := &Int{}, &Int{}, &Int{}, &Int{};
+	defer destroy(S0, a0, a1, a2);
+
+	/*
+		Init temps.
+		Only S0 is zeroed here; a0, a1 and a2 are sized with `internal_grow` further down.
+	*/
+	if err = internal_zero(S0);                     err != nil { return err; }
+
+	/*
+		B
+		Digit count of each of the two low parts (a0 and a1).
+		The top part a2 takes the remaining `src.used - 2 * B` digits.
+	*/
+	B := src.used / 3;
+
+	/*
+		a = a2 * x^2 + a1 * x + a0;
+		a0 is copied from digit 0 of `src`, a1 from digit B and a2 from digit 2 * B,
+		so here x corresponds to a shift by B digits.
+	*/
+	if err = internal_grow(a0, B);                  err != nil { return err; }
+	if err = internal_grow(a1, B);                  err != nil { return err; }
+	if err = internal_grow(a2, src.used - (2 * B)); err != nil { return err; }
+
+	a0.used = B;
+	a1.used = B;
+	a2.used = src.used - 2 * B;
+
+	#force_inline mem.copy_non_overlapping(&a0.digit[0], &src.digit[    0], size_of(DIGIT) * a0.used);
+	#force_inline mem.copy_non_overlapping(&a1.digit[0], &src.digit[    B], size_of(DIGIT) * a1.used);
+	#force_inline mem.copy_non_overlapping(&a2.digit[0], &src.digit[2 * B], size_of(DIGIT) * a2.used);
+
+	internal_clamp(a0);
+	internal_clamp(a1);
+	internal_clamp(a2);
+
+	/** S0 = a0^2;  */
+	if err = internal_sqr(S0, a0);                  err != nil { return err; }
+
+	/** \\S1 = (a2 + a1 + a0)^2 */
+	/** \\S2 = (a2 - a1 + a0)^2  */
+	/** \\S1 = a0 + a2; */
+	/** a0 = a0 + a2; */
+	if err = internal_add(a0, a0, a2);              err != nil { return err; }
+	/** \\S2 = S1 - a1; */
+	/** b = a0 - a1; */
+	if err = internal_sub(dest, a0, a1);            err != nil { return err; }
+	/** \\S1 = S1 + a1; */
+	/** a0 = a0 + a1; */
+	if err = internal_add(a0, a0, a1);              err != nil { return err; }
+	/** \\S1 = S1^2;  */
+	/** a0 = a0^2; */
+	if err = internal_sqr(a0, a0);                  err != nil { return err; }
+	/** \\S2 = S2^2;  */
+	/** b = b^2; */
+	if err = internal_sqr(dest, dest);              err != nil { return err; }
+	/** \\ S3 = 2 * a1 * a2  */
+	/** \\S3 = a1 * a2;  */
+	/** a1 = a1 * a2; */
+	if err = internal_mul(a1, a1, a2);              err != nil { return err; }
+	/** \\S3 = S3 << 1;  */
+	/** a1 = a1 << 1; */
+	if err = internal_shl(a1, a1, 1);               err != nil { return err; }
+	/** \\S4 = a2^2;  */
+	/** a2 = a2^2; */
+	if err = internal_sqr(a2, a2);                  err != nil { return err; }
+	/** \\ tmp = (S1 + S2)/2  */
+	/** \\tmp = S1 + S2; */
+	/** b = a0 + b; */
+	if err = internal_add(dest, a0, dest);          err != nil { return err; }
+	/** \\tmp = tmp >> 1; */
+	/** b = b >> 1; */
+	if err = internal_shr(dest, dest, 1);           err != nil { return err; }
+	/** \\ S1 = S1 - tmp - S3  */
+	/** \\S1 = S1 - tmp; */
+	/** a0 = a0 - b; */
+	if err = internal_sub(a0, a0, dest);            err != nil { return err; }
+	/** \\S1 = S1 - S3;  */
+	/** a0 = a0 - a1; */
+	if err = internal_sub(a0, a0, a1);              err != nil { return err; }
+	/** \\S2 = tmp - S4 -S0  */
+	/** \\S2 = tmp - S4;  */
+	/** b = b - a2; */
+	if err = internal_sub(dest, dest, a2);          err != nil { return err; }
+	/** \\S2 = S2 - S0;  */
+	/** b = b - S0; */
+	if err = internal_sub(dest, dest, S0);          err != nil { return err; }
+	/** \\P = S4*x^4 + S3*x^3 + S2*x^2 + S1*x + S0; */
+	/** P = a2*x^4 + a1*x^3 + b*x^2 + a0*x + S0; */
+	if err = internal_shl_digit(  a2, 4 * B);       err != nil { return err; }
+	if err = internal_shl_digit(  a1, 3 * B);       err != nil { return err; }
+	if err = internal_shl_digit(dest, 2 * B);       err != nil { return err; }
+	if err = internal_shl_digit(  a0, 1 * B);       err != nil { return err; }
+
+	if err = internal_add(a2, a2, a1);              err != nil { return err; }
+	if err = internal_add(dest, dest, a2);          err != nil { return err; }
+	if err = internal_add(dest, dest, a0);          err != nil { return err; }
+	if err = internal_add(dest, dest, S0);          err != nil { return err; }
+	/** a^2 - P  */
+
+	return #force_inline internal_clamp(dest);
 }
 
 /*

Some files were not shown because too many files changed in this diff