Browse Source

big: Fast square method.

Jeroen van Rijn 4 years ago
parent
commit
c3a4d7dda2
3 changed files with 92 additions and 19 deletions
  1. 74 14
      core/math/big/basic.odin
  2. 1 2
      core/math/big/example.odin
  3. 17 3
      core/math/big/helpers.odin

+ 74 - 14
core/math/big/basic.odin

@@ -8,12 +8,11 @@ package big
 	For the theoretical underpinnings, see Knuth's The Art of Computer Programming, Volume 2, section 4.3.
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 
-	This file contains basic arithmetic operations like `add`, `sub`, `div`, ...
+	This file contains basic arithmetic operations like `add`, `sub`, `mul`, `div`, ...
 */
 
 import "core:mem"
 import "core:intrinsics"
-import "core:fmt"
 
 /*
 	===========================
@@ -26,15 +25,9 @@ import "core:fmt"
 */
 int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 	dest := dest; x := a; y := b;
-	if err = clear_if_uninitialized(a); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(b); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(dest); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(a);    err != .None { return err; }
+	if err = clear_if_uninitialized(b);    err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
 	/*
 		All parameters have been initialized.
 		We can now safely ignore errors from comparison routines.
@@ -599,7 +592,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 	digits   := src.used + multiplier.used + 1;
 	neg      := src.sign != multiplier.sign;
 
-	if false && src == multiplier {
+	if src == multiplier {
 		/*
 			Do we need to square?
 		*/
@@ -614,7 +607,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			/* Fast comba? */
 			// err = s_mp_sqr_comba(a, c);
 		} else {
-			// err = s_mp_sqr(a, c);
+			err = _int_sqr(dest, src);
 		}
 	} else {
 		/*
@@ -646,7 +639,6 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			*/
 			// err = s_mp_mul_comba(a, b, c, digs);
 		} else {
-			fmt.println("Hai");
 			err = _int_mul(dest, src, multiplier, digits);
 		}
 	}
@@ -889,4 +881,72 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 	swap(dest, t);
 	destroy(t);
 	return clamp(dest);
+}
+
+/*
+	Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
+*/
+_int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
+	pa := src.used;
+
+	t := &Int{}; ix, iy: int;
+	/*
+		Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
+	*/
+	if err = grow(t, min((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+	t.used = (2 * pa) + 1;
+
+	for ix = 0; ix < pa; ix += 1 {
+		carry := DIGIT(0);
+		/*
+			First calculate the digit at 2*ix; calculate double precision result.
+		*/
+		r := _WORD(t.digit[ix+ix]) + _WORD(src.digit[ix] * src.digit[ix]);
+
+		/*
+			Store lower part in result.
+		*/
+		t.digit[ix+ix] = DIGIT(r & _WORD(_MASK));
+
+		/*
+			Get the carry.
+		*/
+		carry = DIGIT(r >> _DIGIT_BITS);
+
+		for iy = ix + 1; iy < pa; iy += 1 {
+			/*
+				First calculate the product.
+			*/
+			r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]);
+
+			/* Now calculate the double precision result. Nte we use
+			 * addition instead of *2 since it's easier to optimize
+			 */
+			r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry);
+
+			/*
+				Store lower part.
+			*/
+			t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+
+			/*
+				Get carry.
+			*/
+			carry = DIGIT(r >> _DIGIT_BITS);
+		}
+		/*
+			Propagate upwards.
+		*/
+		for carry != 0 {
+			r     = _WORD(t.digit[ix+iy]) + _WORD(carry);
+			t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+			carry = DIGIT(r >> _WORD(_DIGIT_BITS));
+			iy += 1;
+		}
+	}
+
+	err = clamp(t);
+	swap(dest, t);
+	destroy(t);
+	return err;
 }

+ 1 - 2
core/math/big/example.odin

@@ -64,9 +64,8 @@ demo :: proc() {
 	print("b", b, 10);
 
 	fmt.println("--- mul ---");
-	mul(c, a, b);
+	mul(c, a, a);
 	print("c", c, 10);
-
 }
 
 main :: proc() {

+ 17 - 3
core/math/big/helpers.odin

@@ -57,9 +57,7 @@ set :: proc { int_set_from_integer, int_copy };
 	Copy one `Int` to another.
 */
 int_copy :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
-	if err = clear_if_uninitialized(src); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(src); err != .None { return err; }
 	/*
 		If dest == src, do nothing
 	*/
@@ -535,6 +533,22 @@ clear_if_uninitialized :: proc(dest: ^Int, minimize := false) -> (err: Error) {
 	return .None;
 }
 
+_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) {
+	digits := digits;
+	if err = clear_if_uninitialized(src);  err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	/*
+		If dest == src, do nothing
+	*/
+	if (dest == src) {
+		return .None;
+	}
+
+	digits = min(digits, len(src.digit), len(dest.digit));
+	mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits);
+	return .None;
+}
+
 /*
 	Trim unused digits.