浏览代码

big: Correct `pow` bugs from the original.

Jeroen van Rijn 4 年之前
父节点
当前提交
5f63e3952e
共有 3 个文件被更改,包括 90 次插入38 次删除
  1. 3 0
      core/math/big/common.odin
  2. 6 7
      core/math/big/example.odin
  3. 81 31
      core/math/big/log.odin

+ 3 - 0
core/math/big/common.odin

@@ -60,6 +60,9 @@ Error :: enum byte {
 	Buffer_Overflow        = 6,
 	Integer_Overflow       = 7,
 
+	Division_by_Zero       = 8,
+	Math_Domain_Error      = 9,
+
 	Unimplemented          = 127,
 };
 

+ 6 - 7
core/math/big/example.odin

@@ -57,15 +57,14 @@ demo :: proc() {
 	a, b, c := &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c);
 
-	err = set(a, -1024);
-	err = set(b, -1024);
+	for base in -3..=3 {
+		for power in -3..=3 {
+			err = pow(a, base, power);
+			fmt.printf("err: %v | pow(%v, %v) = ", err, base, power); print("", a, 10);
+		}
+	}
 
-	print("a", a, 10);
-	print("b", b, 10);
 
-	fmt.println("--- mul ---");
-	mul(c, a, a);
-	print("c", c, 10);
 }
 
 main :: proc() {

+ 81 - 31
core/math/big/log.odin

@@ -51,43 +51,93 @@ log :: proc { int_log, int_log_digit, };
 	Calculate c = a**b  using a square-multiply algorithm.
 */
 int_pow :: proc(dest, base: ^Int, power: int) -> (err: Error) {
-	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	power := power;
 	if err = clear_if_uninitialized(base); err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	/*
+		Early outs.
+	*/
+	if z, _ := is_zero(base); z {
+		/*
+			A zero base is a special case.
+		*/
+		if power  < 0 {
+			if err = zero(dest); err != .None { return err; }
+			return .Math_Domain_Error;
+		}
+		if power == 0 { return  one(dest); }
+		if power  > 0 { return zero(dest); }
+
+	}
+	if power < 0 {
+		/*
+			Fraction, so we'll return zero.
+		*/
+		return zero(dest);
+	}
+	switch(power) {
+	case 0:
+		/*
+			Any base to the power zero is one.
+		*/
+		return one(dest);
+	case 1:
+		/*
+			Any base to the power one is itself.
+		*/
+		return copy(dest, base);
+	case 2:
+		return sqr(dest, base);
+	}
+
+	g := &Int{};
+	if err = copy(g, base); err != .None { return err; }
+
+	/*
+		Set initial result.
+	*/
+	if err = set(dest, 1); err != .None { return err; }
+
+	loop: for power > 0 {
+		/*
+			If the bit is set, multiply.
+		*/
+		if power & 1 != 0 {
+			if err = mul(dest, g, dest); err != .None {
+				break loop;
+			}
+		}
+		/*
+			Square.
+		*/
+		if power > 1 {
+			if err = sqr(g, g); err != .None {
+				break loop;
+			}
+		}
+
+		/* shift to next bit */
+		power >>= 1;
+	}
 
-// 	if ((err = mp_init_copy(&g, a)) != MP_OKAY) {
-// 		return err;
-// 	}
-
-// 	/* set initial result */
-// 	mp_set(c, 1uL);
-
-// 	while (b > 0) {
-// 		/* if the bit is set multiply */
-// 		if ((b & 1) != 0) {
-// 			if ((err = mp_mul(c, &g, c)) != MP_OKAY) {
-// 				goto LBL_ERR;
-// 			}
-// 		}
-
-// 		/* square */
-// 		if (b > 1) {
-// 			if ((err = mp_sqr(&g, &g)) != MP_OKAY) {
-// 				goto LBL_ERR;
-// 			}
-// 		}
-
-// 		/* shift to next bit */
-// 		b >>= 1;
-// 	}
-
-// LBL_ERR:
-// 	mp_clear(&g);
-// 	return err;
+	destroy(g);
 	return err;
 }
 
+/*
+	Calculate c = a**b.
+*/
+int_pow_int :: proc(dest: ^Int, base, power: int) -> (err: Error) {
+	base_t := &Int{};
+	defer destroy(base_t);
+
+	if err = set(base_t, base); err != .None { return err; }
+
+	return int_pow(dest, base_t, power);
+}
 
-pow :: proc { int_pow, };
+pow :: proc { int_pow, int_pow_int, };
+exp :: pow;
 
 /*
 	Returns the log2 of an `Int`, provided `base` is a power of two.