Browse Source

big: Add `sqrt`.

Jeroen van Rijn 4 years ago
parent
commit
2aae1016ab
4 changed files with 77 additions and 37 deletions
  1. 45 0
      core/math/big/basic.odin
  2. 13 36
      core/math/big/example.odin
  3. 18 0
      core/math/big/logical.odin
  4. 1 1
      core/math/big/radix.odin

+ 45 - 0
core/math/big/basic.odin

@@ -13,6 +13,7 @@ package big
 
 
 import "core:mem"
 import "core:mem"
 import "core:intrinsics"
 import "core:intrinsics"
+
 /*
 /*
 	===========================
 	===========================
 		User-level routines    
 		User-level routines    
@@ -748,6 +749,50 @@ int_sqrmod :: proc(remainder, number, modulus: ^Int) -> (err: Error) {
 }
 }
 sqrmod :: proc { int_sqrmod, };
 sqrmod :: proc { int_sqrmod, };
 
 
+/*
+	This function is less generic than `nth_root`, simpler and faster.
+*/
+int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
+	if err = clear_if_uninitialized(dest);			err != .None { return err; }
+	if err = clear_if_uninitialized(src);			err != .None { return err; }
+
+	/*						Must be positive. 					*/
+	if src.sign == .Negative						{ return .Invalid_Argument; }
+
+	/*			Easy out. If src is zero, so is dest.			*/
+	if z, _ := is_zero(src); 						z { return zero(dest); }
+
+	/*						Set up temporaries.					*/
+	t1, t2 := &Int{}, &Int{};
+	defer destroy(t1, t2);
+
+	if err = copy(t1, src);							err != .None { return err; }
+	if err = zero(t2);								err != .None { return err; }
+
+	/*	First approximation. Not very bad for large arguments.	*/
+	if err = shr_digit(t1, t1.used / 2);			err != .None { return err; }
+	/*							t1 > 0 							*/
+	if err = div(t2, src, t1);						err != .None { return err; }
+	if err = add(t1, t1, t2);						err != .None { return err; }
+	if err = shr(t1, t1, 1);						err != .None { return err; }
+
+	/*					And now t1 > sqrt(arg).					*/
+	for {
+		if err = div(t2, src, t1);						err != .None { return err; }
+		if err = add(t1, t1, t2);						err != .None { return err; }
+		if err = shr(t1, t1, 1);						err != .None { return err; }
+		/* t1 >= sqrt(arg) >= t2 at this point */
+
+		cm, _ := cmp_mag(t1, t2);
+		if cm != 1 { break; }
+	}
+
+	swap(dest, t1);
+	return err;
+}
+
+sqrt :: proc { int_sqrt, };
+
 /*
 /*
 	==========================
 	==========================
 		Low-level routines    
 		Low-level routines    

+ 13 - 36
core/math/big/example.odin

@@ -51,15 +51,7 @@ print :: proc(name: string, a: ^Int, base := i8(10)) {
 	
 	
 }
 }
 
 
-num_threads :: 16;
-global_traces_indexes := [num_threads]u16{};
-@thread_local local_traces_index : ^u16;
-
-init_thread_tracing :: proc(thread_id: u8) {
-    
-    fmt.printf("%p\n", &global_traces_indexes[thread_id]);
-    fmt.printf("%p\n", local_traces_index);
-}
+@thread_local string_buffer: [1024]u8;
 
 
 demo :: proc() {
 demo :: proc() {
 	err: Error;
 	err: Error;
@@ -67,34 +59,19 @@ demo :: proc() {
 	destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(destination, source, quotient, remainder, numerator, denominator);
 	defer destroy(destination, source, quotient, remainder, numerator, denominator);
 
 
-	err = set (numerator,   2);
-	err = set (denominator, 1);
-	err = set (quotient,    u128(1 << 120));
-	err = zero(remainder);
-	err = pow(remainder, numerator, 120);
-	if err != .None {
-		fmt.printf("Error: %v\n", err);
-	} else {
-		print("numerator  ", numerator,   10);
-		print("denominator", denominator, 10);
-		print("quotient   ", quotient,    10);
-		print("remainder  ", remainder,   10);
-	}
-	if c, _ := cmp(quotient, remainder); c == 0 {
-		fmt.println("c == r");
-	} else {
-		fmt.println("c != r");
-	}
+	// string_buffer := make([]u8, 1024);
+	// defer delete(string_buffer);
 
 
-	foozle := "-1329227995784915872903807060280344576";
-	err = atoi(destination, foozle, 10);
-	if err != .None {
-		fmt.printf("Error %v while parsing `%v`", err, foozle);
-	} else {
-		print("destination", destination);
-		err = add(remainder, remainder, destination);
-		print("remainder + destination", remainder);
-	}
+	err = set (numerator,   1024);
+	err = int_sqrt(destination, numerator);
+	fmt.printf("int_sqrt returned: %v\n", err);
+
+	print("destination", destination);
+	// print("source     ", source);
+	// print("quotient   ", quotient);
+	// print("remainder  ", remainder);
+	print("numerator  ", numerator);
+	// print("denominator", denominator);
 }
 }
 
 
 main :: proc() {
 main :: proc() {

+ 18 - 0
core/math/big/logical.odin

@@ -348,6 +348,24 @@ int_shr_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
 }
 }
 shr_digit :: proc { int_shr_digit, };
 shr_digit :: proc { int_shr_digit, };
 
 
+/*
+	Shift right by a certain bit count with sign extension.
+*/
+int_shr_signed :: proc(dest, src: ^Int, bits: int) -> (err: Error) {
+	if err = clear_if_uninitialized(src);	err != .None { return err; }
+	if err = clear_if_uninitialized(dest);	err != .None { return err; }
+
+	if src.sign == .Zero_or_Positive {
+		return shr(dest, src, bits);
+	}
+	if err = add(dest, src, DIGIT(1));		err != .None { return err; }
+
+	if err = shr(dest, dest, bits);			err != .None { return err; }
+	return sub(dest, src, DIGIT(1));
+}
+
+shr_signed :: proc { int_shr_signed, };
+
 /*
 /*
 	Shift left by a certain bit count.
 	Shift left by a certain bit count.
 */
 */

+ 1 - 1
core/math/big/radix.odin

@@ -46,7 +46,7 @@ int_itoa_string :: proc(a: ^Int, radix := i8(-1), zero_terminate := false, alloc
 	/*
 	/*
 		Allocate the buffer we need.
 		Allocate the buffer we need.
 	*/
 	*/
-	buffer := make([]u8, size);
+	buffer := make([]u8, size, allocator);
 
 
 	/*
 	/*
 		Write the digits out into the buffer.
 		Write the digits out into the buffer.