소스 검색

big: Add multiplication.

Jeroen van Rijn 4 년 전
부모
커밋
b4a29844e9
2개의 변경된 파일205개의 추가작업 그리고 48개의 파일을 삭제
  1. 198 39
      core/math/big/basic.odin
  2. 7 9
      core/math/big/example.odin

+ 198 - 39
core/math/big/basic.odin

@@ -13,6 +13,7 @@ package big
 
 import "core:mem"
 import "core:intrinsics"
+import "core:fmt"
 
 /*
 	===========================
@@ -467,13 +468,8 @@ shl1   :: double;
 	remainder = numerator % (1 << bits)
 */
 int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
-	remainder := remainder; numerator := numerator;
-	if err = clear_if_uninitialized(remainder); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(numerator); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(remainder); err != .None { return err; }
+	if err = clear_if_uninitialized(numerator); err != .None { return err; }
 
 	if bits  < 0 { return .Invalid_Argument; }
 	if bits == 0 { return zero(remainder); }
@@ -505,6 +501,161 @@ int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
 }
 mod_bits :: proc { int_mod_bits, };
 
+/*
+	Multiply by a DIGIT.
+*/
+int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
+	if err = clear_if_uninitialized(src ); err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+
+	if multiplier == 0 {
+		return zero(dest);
+	}
+	if multiplier == 1 {
+		return copy(dest, src);
+	}
+
+	/*
+		Power of two?
+	*/
+	if multiplier == 2 {
+		return double(dest, src);
+	}
+	if is_power_of_two(int(multiplier)) {
+		ix: int;
+		if ix, err = log_n(multiplier, 2); err != .None { return err; }
+		return shl(dest, src, ix);
+	}
+
+	/*
+		Ensure `dest` is big enough to hold `src` * `multiplier`.
+	*/
+	if err = grow(dest, max(src.used + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+
+	/*
+		Save the original used count.
+	*/
+	old_used := dest.used;
+	/*
+		Set the sign.
+	*/
+	dest.sign = src.sign;
+	/*
+		Set up carry.
+	*/
+	carry := _WORD(0);
+	/*
+		Compute columns.
+	*/
+	ix := 0;
+	for ; ix < src.used; ix += 1 {
+		/*
+			Compute product and carry sum for this term
+		*/
+		product := carry + _WORD(src.digit[ix]) * _WORD(multiplier);
+		/*
+			Mask off higher bits to get a single DIGIT.
+		*/
+		dest.digit[ix] = DIGIT(product & _WORD(_MASK));
+		/*
+			Send carry into next iteration
+		*/
+		carry = product >> _DIGIT_BITS;
+	}
+
+	/*
+		Store final carry [if any] and increment used.
+	*/
+	dest.digit[ix] = DIGIT(carry);
+	dest.used = src.used + 1;
+
+	/*
+		Zero unused digits.
+	*/
+	zero_count := old_used - dest.used;
+	if zero_count > 0 {
+		mem.zero_slice(dest.digit[zero_count:]);
+	}
+	return clamp(dest);
+}
+
+/*
+	High level multiplication (handles sign).
+*/
+int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
+	if err = clear_if_uninitialized(src);        err != .None { return err; }
+	if err = clear_if_uninitialized(dest);       err != .None { return err; }
+	if err = clear_if_uninitialized(multiplier); err != .None { return err; }
+
+	/*
+		Early out for `multiplier` is zero; Set `dest` to zero.
+	*/
+	if z, _ := is_zero(multiplier); z {
+		return zero(dest);
+	}
+
+	min_used := min(src.used, multiplier.used);
+	max_used := max(src.used, multiplier.used);
+	digits   := src.used + multiplier.used + 1;
+	neg      := src.sign != multiplier.sign;
+
+	if false && src == multiplier {
+		/*
+			Do we need to square?
+		*/
+		if        false && src.used >= _SQR_TOOM_CUTOFF {
+			/* Use Toom-Cook? */
+			// err = s_mp_sqr_toom(a, c);
+		} else if false && src.used >= _SQR_KARATSUBA_CUTOFF {
+			/* Karatsuba? */
+			// err = s_mp_sqr_karatsuba(a, c);
+		} else if false && ((src.used * 2) + 1) < _WARRAY &&
+		                   src.used < (_MAX_COMBA / 2) {
+			/* Fast comba? */
+			// err = s_mp_sqr_comba(a, c);
+		} else {
+			// err = s_mp_sqr(a, c);
+		}
+	} else {
+		/*
+			Can we use the balance method? Check sizes.
+			* The smaller one needs to be larger than the Karatsuba cut-off.
+			* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
+			* to make some sense, but it depends on architecture, OS, position of the
+			* stars... so YMMV.
+			* Using it to cut the input into slices small enough for _mul_comba
+			* was actually slower on the author's machine, but YMMV.
+		*/
+		if        false &&  min_used     >= _MUL_KARATSUBA_CUTOFF &&
+						    max_used / 2 >= _MUL_KARATSUBA_CUTOFF &&
+			/*
+				Not much effect was observed below a ratio of 1:2, but again: YMMV.
+			*/
+							max_used     >= 2 * min_used {
+			// err = s_mp_mul_balance(a,b,c);
+		} else if false && min_used >= _MUL_TOOM_CUTOFF {
+			// err = s_mp_mul_toom(a, b, c);
+		} else if false && min_used >= _MUL_KARATSUBA_CUTOFF {
+			// err = s_mp_mul_karatsuba(a, b, c);
+		} else if false && digits < _WARRAY && min_used <= _MAX_COMBA {
+			/*
+				Can we use the fast multiplier?
+				* The fast multiplier can be used if the output will
+				* have less than MP_WARRAY digits and the number of
+				* digits won't affect carry propagation
+			*/
+			// err = s_mp_mul_comba(a, b, c, digs);
+		} else {
+			fmt.println("Hai");
+			err = _int_mul(dest, src, multiplier, digits);
+		}
+	}
+	dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
+	return err;
+}
+
+mul :: proc { int_mul, int_mul_digit, };
+
 /*
 	==========================
 		Low-level routines    
@@ -688,46 +839,54 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 		}
 	}
 
-	if err = grow(dest, digits); err != .None { return err; }
-	dest.used = digits;
-
 	/*
+		Set up temporary output `Int`, which we'll swap for `dest` when done.
+	*/
 
-	/* compute the digits of the product directly */
-	pa = a->used;
-	for (ix = 0; ix < pa; ix++) {
-		int iy, pb;
-		mp_digit u = 0;
+	t := &Int{};
 
-		/* limit ourselves to making digs digits of output */
-		pb = MP_MIN(b->used, digs - ix);
+	if err = grow(t, max(digits, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+	t.used = digits;
 
-		/* compute the columns of the output and propagate the carry */
-		for (iy = 0; iy < pb; iy++) {
-			/* compute the column as a mp_word */
-			mp_word r = (mp_word)t.dp[ix + iy] +
-							((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
-							(mp_word)u;
+	/*
+		Compute the digits of the product directly.
+	*/
+	pa := a.used;
+	for ix := 0; ix < pa; ix += 1 {
+		/*
+			Limit ourselves to `digits` DIGITs of output.
+		*/
+		pb    := min(b.used, digits - ix);
+		carry := DIGIT(0);
+		iy    := 0;
+		/*
+			Compute the column of the output and propagate the carry.
+		*/
+		for iy = 0; iy < pb; iy += 1 {
+			/*
+				Compute the column as a _WORD.
+			*/
+			column := t.digit[ix + iy] + a.digit[ix] * b.digit[iy] + carry;
 
-			/* the new column is the lower part of the result */
-			t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
+			/*
+				The new column is the lower part of the result.
+			*/
+			t.digit[ix + iy] = column & _MASK;
 
-			/* get the carry word from the result */
-			u       = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
+			/*
+				Get the carry word from the result.
+			*/
+			carry = column >> _DIGIT_BITS;
 		}
-		/* set carry if it is placed below digs */
-		if ((ix + iy) < digs) {
-			t.dp[ix + pb] = u;
+		/*
+			Set carry if it is placed below digits
+		*/
+		if ix + iy < digits {
+			t.digit[ix + pb] = carry;
 		}
 	}
 
-	mp_clamp(&t);
-	mp_exch(&t, c);
-
-	mp_clear(&t);
-	return MP_OKAY;
-}
-
-*/
-	return .None;
+	swap(dest, t);
+	destroy(t);
+	return clamp(dest);
 }

+ 7 - 9
core/math/big/example.odin

@@ -57,17 +57,15 @@ demo :: proc() {
 	a, b, c := &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c);
 
-	err = set(a, -512);
-	err = set(b, 1024);
+	err = set(a, -1024);
+	err = set(b, -1024);
 
-	print("a", a, 16);
-	print("b", b, 16);
+	print("a", a, 10);
+	print("b", b, 10);
 
-	fmt.println("--- swap ---");
-	foo(a, b);
-
-	print("a", a, 16);
-	print("b", b, 16);
+	fmt.println("--- mul ---");
+	mul(c, a, b);
+	print("c", c, 10);
 
 }