Browse Source

big: Add `_private_int_sqr_karatsuba`.

Jeroen van Rijn 4 years ago
parent
commit
2b274fefbb
3 changed files with 91 additions and 12 deletions
  1. 15 4
      core/math/big/example.odin
  2. 10 8
      core/math/big/internal.odin
  3. 66 0
      core/math/big/private.odin

+ 15 - 4
core/math/big/example.odin

@@ -80,16 +80,27 @@ demo :: proc() {
 	err: Error;
 	bs: string;
 
-	if err = factorial(a, 500); err != nil { fmt.printf("factorial err: %v\n", err); return; }
+	// if err = factorial(a, 850); err != nil { fmt.printf("factorial err: %v\n", err); return; }
+
+	foo := "615037959146039477924633848896619112832171971562900618409305032006863881436080";
+	if err = atoi(a, foo, 10); err != nil { return; }
+	print("a: ", a, 10, true, true, true);
+	fmt.println();
+
 	{
 		SCOPED_TIMING(.sqr);
-		if err = sqr(b, a);     err != nil { fmt.printf("sqr err: %v\n", err); return; }
+		if err = sqr(b, a); err != nil { fmt.printf("sqr err: %v\n", err); return; }
 	}
+	fmt.println();
+	print("b _sqr_karatsuba: ", b);
+	fmt.println();
 
-	bs, err = itoa(b, 10);
+	bs, err = itoa(b, 16);
 	defer delete(bs);
 
-	assert(bs[:50] == "14887338741396604108836218987068397819515734169330");
+	if bs[:50] != "1C367982F3050A8A3C62A8A7906D165438B54B287AF3F15D36" {
+		fmt.println("sqr failed");
+	}
 }
 
 main :: proc() {

+ 10 - 8
core/math/big/internal.odin

@@ -36,7 +36,7 @@ import "core:mem"
 import "core:intrinsics"
 import rnd "core:math/rand"
 
-//import "core:fmt"
+// import "core:fmt"
 
 /*
 	Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
@@ -627,20 +627,22 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 			Do we need to square?
 		*/
 		if src.used >= SQR_TOOM_CUTOFF {
-			/* Use Toom-Cook? */
-			// err = s_mp_sqr_toom(a, c);
+			/*
+				Use Toom-Cook?
+			*/
 			// fmt.printf("_private_int_sqr_toom: %v\n", src.used);
-			err = #force_inline _private_int_sqr(dest, src);
+			err = #force_inline _private_int_sqr_karatsuba(dest, src);
 		} else if src.used >= SQR_KARATSUBA_CUTOFF {
-			/* Karatsuba? */
-			// err = s_mp_sqr_karatsuba(a, c);
-			// fmt.printf("_private_int_sqr_karatsuba: %v\n", src.used);
-			err = #force_inline _private_int_sqr(dest, src);
+			/*
+				Karatsuba?
+			*/
+			err = #force_inline _private_int_sqr_karatsuba(dest, src);
 		} else if ((src.used * 2) + 1) < _WARRAY && src.used < (_MAX_COMBA / 2) {
 			/*
 				Fast comba?
 			*/
 			err = #force_inline _private_int_sqr_comba(dest, src);
+			//err = #force_inline _private_int_sqr(dest, src);
 		} else {
 			err = #force_inline _private_int_sqr(dest, src);
 		}

+ 66 - 0
core/math/big/private.odin

@@ -354,6 +354,72 @@ _private_int_sqr_comba :: proc(dest, src: ^Int, allocator := context.allocator)
 	return internal_clamp(dest);
 }
 
+/*
+	Karatsuba squaring, computes `dest` = `src` * `src` using three half-size squarings.
+ 
+ 	See comments of `_private_int_mul_karatsuba` for details.
+ 	It is essentially the same algorithm but merely tuned to perform recursive squarings.
+*/
+_private_int_sqr_karatsuba :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	x0, x1, t1, t2, x0x0, x1x1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
+	defer internal_destroy(x0, x1, t1, t2, x0x0, x1x1);
+
+	/*
+		Min # of digits, divided by two.
+	*/
+	B := src.used >> 1;
+
+	/*
+		Init temps.
+	*/
+	if err = internal_grow(x0,   B);                  err != nil { return err; }
+	if err = internal_grow(x1,   src.used - B);       err != nil { return err; }
+	if err = internal_grow(t1,   src.used * 2);       err != nil { return err; }
+	if err = internal_grow(t2,   src.used * 2);       err != nil { return err; }
+	if err = internal_grow(x0x0, B * 2       );       err != nil { return err; }
+	if err = internal_grow(x1x1, (src.used - B) * 2); err != nil { return err; }
+
+	/*
+		Now shift the digits.
+	*/
+	x0.used = B;
+	x1.used = src.used - B;
+
+	internal_copy_digits(x0, src, x0.used);
+	#force_inline mem.copy_non_overlapping(&x1.digit[0], &src.digit[B], size_of(DIGIT) * x1.used);
+	internal_clamp(x0);
+
+	/*
+		Now calc the products x0*x0 and x1*x1.
+	*/
+	if err = internal_sqr(x0x0, x0);          err != nil { return err; }
+	if err = internal_sqr(x1x1, x1);          err != nil { return err; }
+
+	/*
+		Now calc (x1+x0)^2
+	*/
+	if err = internal_add(t1, x0, x1);        err != nil { return err; }
+	if err = internal_sqr(t1, t1);            err != nil { return err; }
+
+	/*
+		Add x0y0
+	*/
+	if err = internal_add(t2, x0x0, x1x1);    err != nil { return err; }
+	if err = internal_sub(t1, t1, t2);        err != nil { return err; }
+
+	/*
+		Shift by B.
+	*/
+	if err = internal_shl_digit(t1, B);       err != nil { return err; }
+	if err = internal_shl_digit(x1x1, B * 2); err != nil { return err; }
+	if err = internal_add(t1, t1, x0x0);      err != nil { return err; }
+	if err = internal_add(dest, t1, x1x1);    err != nil { return err; }
+
+	return internal_clamp(dest);
+}
+
 /*
 	Divide by three (based on routine from MPI and the GMP manual).
 */