瀏覽代碼

big: Finish `log`, fix `sqr`.

Jeroen van Rijn 4 年之前
父節點
當前提交
31c94bd7f8
共有 4 個文件被更改,包括 113 次插入33 次删除
  1. 3 4
      core/math/big/basic.odin
  2. 2 8
      core/math/big/example.odin
  3. 14 1
      core/math/big/helpers.odin
  4. 94 20
      core/math/big/log.odin

+ 3 - 4
core/math/big/basic.odin

@@ -898,7 +898,7 @@ _int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
 	/*
 		Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
 	*/
-	if err = grow(t, min((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+	if err = grow(t, max((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
 	t.used = (2 * pa) + 1;
 
 	for ix = 0; ix < pa; ix += 1 {
@@ -906,13 +906,12 @@ _int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
 		/*
 			First calculate the digit at 2*ix; calculate double precision result.
 		*/
-		r := _WORD(t.digit[ix+ix]) + _WORD(src.digit[ix] * src.digit[ix]);
+		r := _WORD(t.digit[ix+ix]) + (_WORD(src.digit[ix]) * _WORD(src.digit[ix]));
 
 		/*
 			Store lower part in result.
 		*/
 		t.digit[ix+ix] = DIGIT(r & _WORD(_MASK));
-
 		/*
 			Get the carry.
 		*/
@@ -924,7 +923,7 @@ _int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
 			*/
 			r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]);
 
-			/* Now calculate the double precision result. Nte we use
+			/* Now calculate the double precision result. Nóte we use
 			 * addition instead of *2 since it's easier to optimize
 			 */
 			r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry);

+ 2 - 8
core/math/big/example.odin

@@ -57,14 +57,8 @@ demo :: proc() {
 	a, b, c := &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c);
 
-	for base in -3..=3 {
-		for power in -3..=3 {
-			err = pow(a, base, power);
-			fmt.printf("err: %v | pow(%v, %v) = ", err, base, power); print("", a, 10);
-		}
-	}
-
-
+	err = set(a, 5125);
+	print("a", a);
 }
 
 main :: proc() {

+ 14 - 1
core/math/big/helpers.odin

@@ -75,7 +75,7 @@ int_copy :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error
 	/*
 		Copy everything over and zero high digits.
 	*/
-	for v, i in src.digit[:src.used+1] {
+	for v, i in src.digit[:src.used] {
 		dest.digit[i] = v;
 	}
 	dest.used = src.used;
@@ -533,6 +533,19 @@ clear_if_uninitialized :: proc(dest: ^Int, minimize := false) -> (err: Error) {
 	return .None;
 }
 
+/*
+	Allocates several `Int`s at once.
+*/
+_int_init_multi :: proc(integers: ..^Int) -> (err: Error) {
+	integers := integers;
+	for a in &integers {
+		if err = clear(a); err != .None { return err; }
+	}
+	return .None;
+}
+
+_init_multi :: proc { _int_init_multi, };
+
 _copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) {
 	digits := digits;
 	if err = clear_if_uninitialized(src);  err != .None { return err; }

+ 94 - 20
core/math/big/log.odin

@@ -13,36 +13,22 @@ int_log :: proc(a: ^Int, base: DIGIT) -> (res: int, err: Error) {
 	if base < 2 || DIGIT(base) > _DIGIT_MAX {
 		return -1, .Invalid_Argument;
 	}
-
-	if err = clear_if_uninitialized(a); err != .None {
-		return -1, err;
-	}
-	if n, _ := is_neg(a); n {
-		return -1, .Invalid_Argument;
-	}
-	if z, _ := is_zero(a); z {
-		return -1, .Invalid_Argument;
-	}
+	if err = clear_if_uninitialized(a); err != .None { return -1, err; }
+	if n, _ := is_neg(a);  n { return -1, .Invalid_Argument; }
+	if z, _ := is_zero(a); z { return -1, .Invalid_Argument; }
 
 	/*
 		Fast path for bases that are a power of two.
 	*/
-	if is_power_of_two(int(base)) {
-		return _log_power_of_two(a, base);
-	}
+	if is_power_of_two(int(base)) { return _log_power_of_two(a, base); }
 
 	/*
 		Fast path for `Int`s that fit within a single `DIGIT`.
 	*/
-	if a.used == 1 {
-		return log(a.digit[0], DIGIT(base));
-	}
+	if a.used == 1 { return log(a.digit[0], DIGIT(base)); }
 
-    // if (MP_HAS(S_MP_LOG)) {
-    //    return s_mp_log(a, (mp_digit)base, c);
-    // }
+	return _int_log(a, base);
 
-	return -1, .Unimplemented;
 }
 
 log :: proc { int_log, int_log_digit, };
@@ -222,4 +208,92 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) {
    	} else {
    		return low, .None;
    	}
+}
+
+
+/*
+	Internal implementation of log.	
+*/
+_int_log :: proc(a: ^Int, base: DIGIT) -> (res: int, err: Error) {
+	bracket_low, bracket_high, bracket_mid, t, bi_base := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
+
+	cnt := 0;
+
+	ic, _ := cmp(a, base);
+	if ic == -1 || ic == 0 {
+		return 1 if ic == 0 else 0, .None;
+	}
+
+	if err = set(bi_base, base);          err != .None { return -1, err; }
+	if err = _init_multi(bracket_mid, t); err != .None { return -1, err; }
+	if err = one(bracket_low);            err != .None { return -1, err; }
+	if err = set(bracket_high, base);     err != .None { return -1, err; }
+
+	low  := 0; high := 1;
+
+	/*
+		A kind of Giant-step/baby-step algorithm.
+		Idea shamelessly stolen from https://programmingpraxis.com/2010/05/07/integer-logarithms/2/
+		The effect is asymptotic, hence needs benchmarks to test if the Giant-step should be skipped
+		for small n.
+	*/
+
+	for {
+		/*
+			Iterate until `a` is bracketed between low + high.
+		*/
+		if bc, _ := cmp(bracket_high, a); bc != -1 {
+			break;
+		}
+
+	 	low = high;
+	 	if err = copy(bracket_low, bracket_high); err != .None {
+			destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+			return -1, err;
+	 	}
+	 	high <<= 1;
+	 	if err = sqr(bracket_high, bracket_high); err != .None {
+			destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+			return -1, err;
+	 	}
+
+	 	cnt += 1;
+	 	if cnt == 7 {
+		 	destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+			return -2, .Max_Iterations_Reached;
+	 	}
+	}
+
+	for (high - low) > 1 {
+		mid := (high + low) >> 1;
+
+		if err = pow(t, bi_base, mid - low); err != .None {
+			destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+			return -1, err;
+		}
+
+		if err = mul(bracket_mid, bracket_low, t); err != .None {
+			destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+			return -1, err;
+		}
+		mc, _ := cmp(a, bracket_mid);
+		if mc == -1 {
+			high = mid;
+			swap(bracket_mid, bracket_high);
+		}
+		if mc == 1 {
+			low = mid;
+			swap(bracket_mid, bracket_low);
+		}
+		if mc == 0 {
+			destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+			return mid, .None;
+		}
+	}
+
+	fc, _ := cmp(bracket_high, a);
+	res = high if fc == 0 else low;
+
+	destroy(bracket_low, bracket_high, bracket_mid, t, bi_base);
+	return;
 }