Browse Source

big: Add `root_n`.

Jeroen van Rijn 4 years ago
parent
commit
531c4936dd
3 changed files with 195 additions and 51 deletions
  1. 0 44
      core/math/big/basic.odin
  2. 11 7
      core/math/big/example.odin
  3. 184 0
      core/math/big/exp_log.odin

+ 0 - 44
core/math/big/basic.odin

@@ -749,50 +749,6 @@ int_sqrmod :: proc(remainder, number, modulus: ^Int) -> (err: Error) {
 }
 sqrmod :: proc { int_sqrmod, };
 
-/*
-	This function is less generic than `nth_root`, simpler and faster.
-*/
-int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
-	if err = clear_if_uninitialized(dest);			err != .None { return err; }
-	if err = clear_if_uninitialized(src);			err != .None { return err; }
-
-	/*						Must be positive. 					*/
-	if src.sign == .Negative						{ return .Invalid_Argument; }
-
-	/*			Easy out. If src is zero, so is dest.			*/
-	if z, _ := is_zero(src); 						z { return zero(dest); }
-
-	/*						Set up temporaries.					*/
-	t1, t2 := &Int{}, &Int{};
-	defer destroy(t1, t2);
-
-	if err = copy(t1, src);							err != .None { return err; }
-	if err = zero(t2);								err != .None { return err; }
-
-	/*	First approximation. Not very bad for large arguments.	*/
-	if err = shr_digit(t1, t1.used / 2);			err != .None { return err; }
-	/*							t1 > 0 							*/
-	if err = div(t2, src, t1);						err != .None { return err; }
-	if err = add(t1, t1, t2);						err != .None { return err; }
-	if err = shr(t1, t1, 1);						err != .None { return err; }
-
-	/*					And now t1 > sqrt(arg).					*/
-	for {
-		if err = div(t2, src, t1);						err != .None { return err; }
-		if err = add(t1, t1, t2);						err != .None { return err; }
-		if err = shr(t1, t1, 1);						err != .None { return err; }
-		/* t1 >= sqrt(arg) >= t2 at this point */
-
-		cm, _ := cmp_mag(t1, t2);
-		if cm != 1 { break; }
-	}
-
-	swap(dest, t1);
-	return err;
-}
-
-sqrt :: proc { int_sqrt, };
-
 /*
 	==========================
 		Low-level routines    

+ 11 - 7
core/math/big/example.odin

@@ -63,15 +63,19 @@ demo :: proc() {
 	// defer delete(string_buffer);
 
 	err = set (numerator,   1024);
-	err = int_sqrt(destination, numerator);
+	err = sqrt(destination, numerator);
 	fmt.printf("int_sqrt returned: %v\n", err);
 
-	print("destination", destination);
-	// print("source     ", source);
-	// print("quotient   ", quotient);
-	// print("remainder  ", remainder);
-	print("numerator  ", numerator);
-	// print("denominator", denominator);
+	print("num      ", numerator);
+	print("sqrt(num)", destination);
+
+	fmt.println("\n\n");
+
+	err = root_n(destination, numerator, 2);
+	fmt.printf("root_n(2) returned: %v\n", err);
+
+	print("num        ", numerator);
+	print("root_n(num)", destination);
 }
 
 main :: proc() {

+ 184 - 0
core/math/big/log.odin → core/math/big/exp_log.odin

@@ -210,6 +210,190 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) {
    	}
 }
 
+/*
+	This function is less generic than `root_n`, simpler and faster.
+*/
+int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
+	if err = clear_if_uninitialized(dest);			err != .None { return err; }
+	if err = clear_if_uninitialized(src);			err != .None { return err; }
+
+	/*						Must be positive. 					*/
+	if src.sign == .Negative						{ return .Invalid_Argument; }
+
+	/*			Easy out. If src is zero, so is dest.			*/
+	if z, _ := is_zero(src); 						z { return zero(dest); }
+
+	/*						Set up temporaries.					*/
+	t1, t2 := &Int{}, &Int{};
+	defer destroy(t1, t2);
+
+	if err = copy(t1, src);							err != .None { return err; }
+	if err = zero(t2);								err != .None { return err; }
+
+	/*	First approximation. Not very bad for large arguments.	*/
+	if err = shr_digit(t1, t1.used / 2);			err != .None { return err; }
+	/*							t1 > 0 							*/
+	if err = div(t2, src, t1);						err != .None { return err; }
+	if err = add(t1, t1, t2);						err != .None { return err; }
+	if err = shr(t1, t1, 1);						err != .None { return err; }
+
+	/*					And now t1 > sqrt(arg).					*/
+	for {
+		if err = div(t2, src, t1);						err != .None { return err; }
+		if err = add(t1, t1, t2);						err != .None { return err; }
+		if err = shr(t1, t1, 1);						err != .None { return err; }
+		/* t1 >= sqrt(arg) >= t2 at this point */
+
+		cm, _ := cmp_mag(t1, t2);
+		if cm != 1 { break; }
+	}
+
+	swap(dest, t1);
+	return err;
+}
+sqrt :: proc { int_sqrt, };
+
+
+/*
+	Find the nth root of an Integer.
+ 	Result found such that `(dest)**n <= src` and `(dest+1)**n > src`
+
+	This algorithm uses Newton's approximation `x[i+1] = x[i] - f(x[i])/f'(x[i])`,
+  	which will find the root in `log(n)` time where each step involves a fair bit.
+*/
+int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
+	/*						Fast path for n == 2 						*/
+	if n == 2 { return sqrt(dest, src); }
+
+	/*					Initialize dest + src if needed. 				*/
+	if err = clear_if_uninitialized(dest);			err != .None { return err; }
+	if err = clear_if_uninitialized(src);			err != .None { return err; }
+
+	if n < 0 || n > int(_DIGIT_MAX) {
+		return .Invalid_Argument;
+	}
+
+	neg: bool;
+	if n & 1 == 0 {
+		if neg, err = is_neg(src); neg || err != .None { return .Invalid_Argument; }
+	}
+
+	/*							Set up temporaries.						*/
+	t1, t2, t3, a := &Int{}, &Int{}, &Int{}, &Int{};
+	defer destroy(t1, t2, t3);
+
+	/*			If a is negative fudge the sign but keep track.			*/
+	a.sign  = .Zero_or_Positive;
+	a.used  = src.used;
+	a.digit = src.digit;
+
+	/*
+	  If "n" is larger than INT_MAX it is also larger than
+	  log_2(src) because the bit-length of the "src" is measured
+	  with an int and hence the root is always < 2 (two).
+	*/
+	if n > max(int) / 2 {
+		err = set(dest, 1);
+		dest.sign = a.sign;
+		return err;
+	}
+
+	/*					Compute seed: 2^(log_2(src)/n + 2)				*/
+	ilog2: int;
+	ilog2, err = count_bits(src);
+
+	/*			"src" is smaller than max(int), we can cast safely.		*/
+	if ilog2 < n {
+		err = set(dest, 1);
+		dest.sign = a.sign;
+		return err;
+	}
+
+	ilog2 /= n;
+	if ilog2 == 0 {
+		err = set(dest, 1);
+		dest.sign = a.sign;
+		return err;
+	}
+
+	/*					Start value must be larger than root.			*/
+	ilog2 += 2;
+	if err = power_of_two(t2, ilog2); err != .None { return err; }
+
+	c: int;
+	for {
+		/* t1 = t2 */
+		if err = copy(t1, t2); err != .None { return err; }
+
+		/* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */
+
+		/* t3 = t1**(b-1) */
+		if err = pow(t3, t1, n-1); err != .None { return err; }
+
+		/* numerator */
+		/* t2 = t1**b */
+		if err = mul(t2, t1, t3); err != .None { return err; }
+
+		/* t2 = t1**b - a */
+		if err = sub(t2, t2, a); err != .None { return err; }
+
+		/* denominator */
+		/* t3 = t1**(b-1) * b  */
+		if err = mul(t3, t3, DIGIT(n)); err != .None { return err; }
+
+		/* t3 = (t1**b - a)/(b * t1**(b-1)) */
+		if err = div(t3, t2, t3); err != .None { return err; }
+		if err = sub(t2, t1, t3); err != .None { return err; }
+
+		/*
+			 Number of rounds is at most log_2(root). If it is more it
+			 got stuck, so break out of the loop and do the rest manually.
+		*/
+		if ilog2 -= 1; ilog2 == 0 {
+			break;
+		}
+		if c, err = cmp(t1, t2); c == 0 { break; }
+	}
+
+	/*						Result can be off by a few so check.					*/
+	/* Loop beneath can overshoot by one if found root is smaller than actual root. */
+
+	for {
+		if err = pow(t2, t1, n); err != .None { return err; }
+
+		c, err = cmp(t2, a);
+		if c == 0 {
+			swap(dest, t1);
+			return .None;
+		} else if c == -1 {
+			if err = add(t1, t1, DIGIT(1)); err != .None { return err; }
+		} else {
+			break;
+		}
+	}
+
+	/*					Correct overshoot from above or from recurrence.			*/
+	for {
+		if err = pow(t2, t1, n); err != .None { return err; }
+
+		c, err = cmp(t2, a);
+		if c == 1 {
+			if err = sub(t1, t1, DIGIT(1)); err != .None { return err; }
+		} else {
+			break;
+		}
+	}
+
+	/*								Set the result.									*/
+	swap(dest, t1);
+
+	/* set the sign of the result */
+	dest.sign = src.sign;
+
+	return err;
+}
+root_n :: proc { int_root_n, };
+
 /*
 	Internal implementation of log.	
 */