浏览代码

big: Fast square method.

Jeroen van Rijn 4 年之前
父节点
当前提交
c3a4d7dda2
共有 3 个文件被更改,包括 92 次插入19 次删除
  1. 74 14
      core/math/big/basic.odin
  2. 1 2
      core/math/big/example.odin
  3. 17 3
      core/math/big/helpers.odin

+ 74 - 14
core/math/big/basic.odin

@@ -8,12 +8,11 @@ package big
 	For the theoretical underpinnings, see Knuth's The Art of Computer Programming, Volume 2, section 4.3.
 	For the theoretical underpinnings, see Knuth's The Art of Computer Programming, Volume 2, section 4.3.
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 
 
-	This file contains basic arithmetic operations like `add`, `sub`, `div`, ...
+	This file contains basic arithmetic operations like `add`, `sub`, `mul`, `div`, ...
 */
 */
 
 
 import "core:mem"
 import "core:mem"
 import "core:intrinsics"
 import "core:intrinsics"
-import "core:fmt"
 
 
 /*
 /*
 	===========================
 	===========================
@@ -26,15 +25,9 @@ import "core:fmt"
 */
 */
 int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 	dest := dest; x := a; y := b;
 	dest := dest; x := a; y := b;
-	if err = clear_if_uninitialized(a); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(b); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(dest); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(a);    err != .None { return err; }
+	if err = clear_if_uninitialized(b);    err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
 	/*
 	/*
 		All parameters have been initialized.
 		All parameters have been initialized.
 		We can now safely ignore errors from comparison routines.
 		We can now safely ignore errors from comparison routines.
@@ -599,7 +592,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 	digits   := src.used + multiplier.used + 1;
 	digits   := src.used + multiplier.used + 1;
 	neg      := src.sign != multiplier.sign;
 	neg      := src.sign != multiplier.sign;
 
 
-	if false && src == multiplier {
+	if src == multiplier {
 		/*
 		/*
 			Do we need to square?
 			Do we need to square?
 		*/
 		*/
@@ -614,7 +607,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			/* Fast comba? */
 			/* Fast comba? */
 			// err = s_mp_sqr_comba(a, c);
 			// err = s_mp_sqr_comba(a, c);
 		} else {
 		} else {
-			// err = s_mp_sqr(a, c);
+			err = _int_sqr(dest, src);
 		}
 		}
 	} else {
 	} else {
 		/*
 		/*
@@ -646,7 +639,6 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			*/
 			*/
 			// err = s_mp_mul_comba(a, b, c, digs);
 			// err = s_mp_mul_comba(a, b, c, digs);
 		} else {
 		} else {
-			fmt.println("Hai");
 			err = _int_mul(dest, src, multiplier, digits);
 			err = _int_mul(dest, src, multiplier, digits);
 		}
 		}
 	}
 	}
@@ -889,4 +881,72 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 	swap(dest, t);
 	swap(dest, t);
 	destroy(t);
 	destroy(t);
 	return clamp(dest);
 	return clamp(dest);
+}
+
+/*
+	Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
+*/
+_int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
+	pa := src.used;
+
+	t := &Int{}; ix, iy: int;
+	/*
+		Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
+	*/
+	if err = grow(t, min((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+	t.used = (2 * pa) + 1;
+
+	for ix = 0; ix < pa; ix += 1 {
+		carry := DIGIT(0);
+		/*
+			First calculate the digit at 2*ix; calculate double precision result.
+		*/
+		r := _WORD(t.digit[ix+ix]) + _WORD(src.digit[ix] * src.digit[ix]);
+
+		/*
+			Store lower part in result.
+		*/
+		t.digit[ix+ix] = DIGIT(r & _WORD(_MASK));
+
+		/*
+			Get the carry.
+		*/
+		carry = DIGIT(r >> _DIGIT_BITS);
+
+		for iy = ix + 1; iy < pa; iy += 1 {
+			/*
+				First calculate the product.
+			*/
+			r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]);
+
+			/* Now calculate the double precision result. Nte we use
+			 * addition instead of *2 since it's easier to optimize
+			 */
+			r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry);
+
+			/*
+				Store lower part.
+			*/
+			t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+
+			/*
+				Get carry.
+			*/
+			carry = DIGIT(r >> _DIGIT_BITS);
+		}
+		/*
+			Propagate upwards.
+		*/
+		for carry != 0 {
+			r     = _WORD(t.digit[ix+iy]) + _WORD(carry);
+			t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+			carry = DIGIT(r >> _WORD(_DIGIT_BITS));
+			iy += 1;
+		}
+	}
+
+	err = clamp(t);
+	swap(dest, t);
+	destroy(t);
+	return err;
 }
 }

+ 1 - 2
core/math/big/example.odin

@@ -64,9 +64,8 @@ demo :: proc() {
 	print("b", b, 10);
 	print("b", b, 10);
 
 
 	fmt.println("--- mul ---");
 	fmt.println("--- mul ---");
-	mul(c, a, b);
+	mul(c, a, a);
 	print("c", c, 10);
 	print("c", c, 10);
-
 }
 }
 
 
 main :: proc() {
 main :: proc() {

+ 17 - 3
core/math/big/helpers.odin

@@ -57,9 +57,7 @@ set :: proc { int_set_from_integer, int_copy };
 	Copy one `Int` to another.
 	Copy one `Int` to another.
 */
 */
 int_copy :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
 int_copy :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
-	if err = clear_if_uninitialized(src); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(src); err != .None { return err; }
 	/*
 	/*
 		If dest == src, do nothing
 		If dest == src, do nothing
 	*/
 	*/
@@ -535,6 +533,22 @@ clear_if_uninitialized :: proc(dest: ^Int, minimize := false) -> (err: Error) {
 	return .None;
 	return .None;
 }
 }
 
 
+_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) {
+	digits := digits;
+	if err = clear_if_uninitialized(src);  err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	/*
+		If dest == src, do nothing
+	*/
+	if (dest == src) {
+		return .None;
+	}
+
+	digits = min(digits, len(src.digit), len(dest.digit));
+	mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits);
+	return .None;
+}
+
 /*
 /*
 	Trim unused digits.
 	Trim unused digits.