Browse Source

Add `_mul_comba` path.

Jeroen van Rijn 4 years ago
parent
commit
a27612ec6a
3 changed files with 120 additions and 23 deletions
  1. 114 17
      core/math/big/basic.odin
  2. 3 3
      core/math/big/build.bat
  3. 3 3
      core/math/big/example.odin

+ 114 - 17
core/math/big/basic.odin

@@ -499,8 +499,7 @@ mod_bits :: proc { int_mod_bits, };
 	Multiply by a DIGIT.
 */
 int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
-	if err = clear_if_uninitialized(src ); err != .None { return err; }
-	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	if err = clear_if_uninitialized(src, dest); err != .None { return err; }
 
 	if multiplier == 0 {
 		return zero(dest);
@@ -576,9 +575,7 @@ int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
 	High level multiplication (handles sign).
 */
 int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
-	if err = clear_if_uninitialized(src);        err != .None { return err; }
-	if err = clear_if_uninitialized(dest);       err != .None { return err; }
-	if err = clear_if_uninitialized(multiplier); err != .None { return err; }
+	if err = clear_if_uninitialized(dest, src, multiplier); err != .None { return err; }
 
 	/*
 		Early out for `multiplier` is zero; Set `dest` to zero.
@@ -587,11 +584,6 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 		return zero(dest);
 	}
 
-	min_used := min(src.used, multiplier.used);
-	max_used := max(src.used, multiplier.used);
-	digits   := src.used + multiplier.used + 1;
-	neg      := src.sign != multiplier.sign;
-
 	if src == multiplier {
 		/*
 			Do we need to square?
@@ -619,6 +611,11 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			* Using it to cut the input into slices small enough for _mul_comba
 			* was actually slower on the author's machine, but YMMV.
 		*/
+
+		min_used := min(src.used, multiplier.used);
+		max_used := max(src.used, multiplier.used);
+		digits   := src.used + multiplier.used + 1;
+
 		if        false &&  min_used     >= _MUL_KARATSUBA_CUTOFF &&
 						    max_used / 2 >= _MUL_KARATSUBA_CUTOFF &&
 			/*
@@ -630,18 +627,19 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
 			// err = s_mp_mul_toom(a, b, c);
 		} else if false && min_used >= _MUL_KARATSUBA_CUTOFF {
 			// err = s_mp_mul_karatsuba(a, b, c);
-		} else if false && digits < _WARRAY && min_used <= _MAX_COMBA {
+		} else if digits < _WARRAY && min_used <= _MAX_COMBA {
 			/*
 				Can we use the fast multiplier?
 				* The fast multiplier can be used if the output will
 				* have less than MP_WARRAY digits and the number of
 				* digits won't affect carry propagation
 			*/
-			// err = s_mp_mul_comba(a, b, c, digs);
+			err = _int_mul_comba(dest, src, multiplier, digits);
 		} else {
 			err = _int_mul(dest, src, multiplier, digits);
 		}
 	}
+	neg      := src.sign != multiplier.sign;
 	dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
 	return err;
 }
@@ -1033,14 +1031,11 @@ _int_sub :: proc(dest, number, decrease: ^Int) -> (err: Error) {
 	many digits of output are created.
 */
 _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
-
 	/*
 		Can we use the fast multiplier?
 	*/
-	when false { // Have Comba?
-		if digits < _WARRAY && min(a.used, b.used) < _MAX_COMBA {
-			return _int_mul_comba(dest, a, b, digits);
-		}
+	if digits < _WARRAY && min(a.used, b.used) < _MAX_COMBA {
+		return _int_mul_comba(dest, a, b, digits);
 	}
 
 	/*
@@ -1095,6 +1090,108 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 	return clamp(dest);
 }
 
+/*
+	Fast (comba) multiplier
+
+	This is the fast column-array [comba] multiplier.  It is
+	designed to compute the columns of the product first
+	then handle the carries afterwards.  This has the effect
+	of making the nested loops that compute the columns very
+	simple and schedulable on super-scalar processors.
+
+	This has been modified to produce a variable number of
+	digits of output so if say only a half-product is required
+	you don't have to compute the upper half (a feature
+	required for fast Barrett reduction).
+
+	Based on Algorithm 14.12 on pp.595 of HAC.
+*/
+_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
+	/*
+		Set up array.
+	*/
+	W: [_WARRAY]DIGIT = ---;
+
+	/*
+		Grow the destination as required.
+	*/
+	if err = grow(dest, digits); err != .None { return err; }
+
+	/*
+		Number of output digits to produce.
+	*/
+	pa := min(digits, a.used + b.used);
+
+	/*
+		Clear the carry
+	*/
+	_W := _WORD(0);
+
+	ix: int;
+	for ix = 0; ix < pa; ix += 1 {
+		tx, ty, iy, iz: int;
+
+		/*
+			Get offsets into the two bignums.
+		*/
+		ty = min(b.used - 1, ix);
+		tx = ix - ty;
+
+		/*
+			This is the number of times the loop will iterate, essentially.
+			while (tx++ < a->used && ty-- >= 0) { ... }
+		*/
+		 
+		iy = min(a.used - tx, ty + 1);
+
+		/*
+			Execute loop.
+		*/
+		for iz = 0; iz < iy; iz += 1 {
+			_W += _WORD(a.digit[tx + iz]) * _WORD(b.digit[ty - iz]);
+		}
+
+		/*
+			Store term.
+		*/
+		W[ix] = DIGIT(_W) & _MASK;
+
+		/*
+			Make next carry.
+		*/
+		_W = _W >> _WORD(_DIGIT_BITS);
+	}
+
+	/*
+		Setup dest.
+	*/
+	old_used := dest.used;
+	dest.used = pa;
+
+	for ix = 0; ix < pa; ix += 1 {
+		/*
+			Now extract the previous digit [below the carry].
+		*/
+		dest.digit[ix] = W[ix];
+	}
+
+	/*
+		Clear unused digits [that existed in the old copy of dest].
+	*/
+	zero_count := old_used - dest.used;
+	/*
+		Zero remainder.
+	*/
+	if zero_count > 0 {
+		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
+	}
+	/*
+		Adjust dest.used based on leading zeroes.
+	*/
+
+	return clamp(dest);
+}
+
 /*
 	Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
 */

+ 3 - 3
core/math/big/build.bat

@@ -1,10 +1,10 @@
 @echo off
-odin run . -vet
+:odin run . -vet
 : -o:size -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules
 :odin build . -build-mode:shared -show-timings -o:size -use-separate-modules -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:size -use-separate-modules
-:odin build . -build-mode:shared -show-timings -o:speed -use-separate-modules -no-bounds-check
+odin build . -build-mode:shared -show-timings -o:speed -use-separate-modules -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:speed -use-separate-modules
 
-:python test.py
+python test.py

+ 3 - 3
core/math/big/example.odin

@@ -87,7 +87,7 @@ Event :: struct {
 }
 Timings := [Category]Event{};
 
-print :: proc(name: string, a: ^Int, base := i8(10), print_extra_info := false, print_name := false, newline := true) {
+print :: proc(name: string, a: ^Int, base := i8(10), print_name := false, newline := true, print_extra_info := false) {
 	s := time.tick_now();
 	as, err := itoa(a, base);
 	Timings[.itoa].t += time.tick_since(s); Timings[.itoa].c += 1;
@@ -117,10 +117,10 @@ demo :: proc() {
 	defer destroy(a, b, c, d, e, f);
 
 	s := time.tick_now();
-	err = choose(a, 1024, 255);
+	err = choose(a, 65535, 255);
 	Timings[.choose].t += time.tick_since(s); Timings[.choose].c += 1;
 
-	print("1024 choose 255", a, 10, true, true, true);
+	print("65535 choose 255", a, 10, true, true, true);
 	fmt.printf("Error: %v\n", err);
 }