소스 검색

big: Fast square method.

Jeroen van Rijn 4 년 전
부모
커밋
c3a4d7dda2
3개의 변경된 파일92개의 추가작업 그리고 19개의 파일을 삭제
  1. 74 14
      core/math/big/basic.odin
  2. 1 2
      core/math/big/example.odin
  3. 17 3
      core/math/big/helpers.odin

+ 74 - 14
core/math/big/basic.odin

@@ -8,12 +8,11 @@ package big
 	For the theoretical underpinnings, see Knuth's The Art of Computer Programming, Volume 2, section 4.3.
 	The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
 
-	This file contains basic arithmetic operations like `add`, `sub`, `div`, ...
+	This file contains basic arithmetic operations like `add`, `sub`, `mul`, `div`, ...
 */
 
 import "core:mem"
 import "core:intrinsics"
-import "core:fmt"
 
 /*
 	===========================
@@ -26,15 +25,9 @@ import "core:fmt"
 */
 int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
 	dest := dest; x := a; y := b;
-	if err = clear_if_uninitialized(a); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(b); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(dest); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(a);    err != .None { return err; }
+	if err = clear_if_uninitialized(b);    err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
 	/*
 		All parameters have been initialized.
 		We can now safely ignore errors from comparison routines.
@@ -599,7 +592,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 	digits   := src.used + multiplier.used + 1;
 	neg      := src.sign != multiplier.sign;
 
-	if false && src == multiplier {
+	if src == multiplier {
 		/*
 			Do we need to square?
 		*/
@@ -614,7 +607,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			/* Fast comba? */
 			// err = s_mp_sqr_comba(a, c);
 		} else {
-			// err = s_mp_sqr(a, c);
+			err = _int_sqr(dest, src);
 		}
 	} else {
 		/*
@@ -646,7 +639,6 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			*/
 			// err = s_mp_mul_comba(a, b, c, digs);
 		} else {
-			fmt.println("Hai");
 			err = _int_mul(dest, src, multiplier, digits);
 		}
 	}
@@ -889,4 +881,72 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 	swap(dest, t);
 	destroy(t);
 	return clamp(dest);
+}
+
+/*
+	Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
+*/
+_int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
+	pa := src.used;
+
+	t := &Int{}; ix, iy: int;
+	/*
+		Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
+	*/
+	if err = grow(t, min((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+	t.used = (2 * pa) + 1;
+
+	for ix = 0; ix < pa; ix += 1 {
+		carry := DIGIT(0);
+		/*
+			First calculate the digit at 2*ix; calculate double precision result.
+		*/
+		r := _WORD(t.digit[ix+ix]) + _WORD(src.digit[ix] * src.digit[ix]);
+
+		/*
+			Store lower part in result.
+		*/
+		t.digit[ix+ix] = DIGIT(r & _WORD(_MASK));
+
+		/*
+			Get the carry.
+		*/
+		carry = DIGIT(r >> _DIGIT_BITS);
+
+		for iy = ix + 1; iy < pa; iy += 1 {
+			/*
+				First calculate the product.
+			*/
+			r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]);
+
+			/* Now calculate the double precision result. Nte we use
+			 * addition instead of *2 since it's easier to optimize
+			 */
+			r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry);
+
+			/*
+				Store lower part.
+			*/
+			t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+
+			/*
+				Get carry.
+			*/
+			carry = DIGIT(r >> _DIGIT_BITS);
+		}
+		/*
+			Propagate upwards.
+		*/
+		for carry != 0 {
+			r     = _WORD(t.digit[ix+iy]) + _WORD(carry);
+			t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+			carry = DIGIT(r >> _WORD(_DIGIT_BITS));
+			iy += 1;
+		}
+	}
+
+	err = clamp(t);
+	swap(dest, t);
+	destroy(t);
+	return err;
 }

+ 1 - 2
core/math/big/example.odin

@@ -64,9 +64,8 @@ demo :: proc() {
 	print("b", b, 10);
 
 	fmt.println("--- mul ---");
-	mul(c, a, b);
+	mul(c, a, a);
 	print("c", c, 10);
-
 }
 
 main :: proc() {

+ 17 - 3
core/math/big/helpers.odin

@@ -57,9 +57,7 @@ set :: proc { int_set_from_integer, int_copy };
 	Copy one `Int` to another.
 */
 int_copy :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
-	if err = clear_if_uninitialized(src); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(src); err != .None { return err; }
 	/*
 		If dest == src, do nothing
 	*/
@@ -535,6 +533,22 @@ clear_if_uninitialized :: proc(dest: ^Int, minimize := false) -> (err: Error) {
 	return .None;
 }
 
+_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) {
+	digits := digits;
+	if err = clear_if_uninitialized(src);  err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	/*
+		If dest == src, do nothing
+	*/
+	if (dest == src) {
+		return .None;
+	}
+
+	digits = min(digits, len(src.digit), len(dest.digit));
+	mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits);
+	return .None;
+}
+
 /*
 	Trim unused digits.