Ver código fonte

bigint: Add substractin with immediate.

Jeroen van Rijn 4 anos atrás
pai
commit
daceaa65f5

+ 90 - 9
core/math/bigint/basic.odin

@@ -89,20 +89,20 @@ add_digit :: proc(dest, a: ^Int, digit: DIGIT) -> (err: Error) {
 		If `a` is negative and `|a|` >= `digit`, call `dest = |a| - digit`
 	*/
 	if is_neg(a) && (a.used > 1 || a.digit[0] >= digit) {
+		fmt.print("a = neg, %v\n", print_int(a));
 		/*
 			Temporarily fix `a`'s sign.
 		*/
-		a.sign = .Zero_or_Positive;
+		t := a;
+		t.sign = .Zero_or_Positive;
 		/*
 			dest = |a| - digit
 		*/
-		err = sub(dest, a, digit);
+		err = sub(dest, t, digit);
 		/*
 			Restore sign and set `dest` sign.
 		*/
 		dest.sign = .Negative;
-		a.sign    = .Negative;
-
 		clamp(dest);
 
 		return err;
@@ -208,11 +208,92 @@ sub_two_ints :: proc(dest, number, decrease: ^Int) -> (err: Error) {
 	Adds the unsigned `DIGIT` immediate to an `Int`,
 	such that the `DIGIT` doesn't have to be turned into an `Int` first.
 
-	dest = number - decrease;
+	dest = a - digit;
 */
-sub_digit :: proc(dest, number: ^Int, decrease: DIGIT) -> (err: Error) {
+sub_digit :: proc(dest, a: ^Int, digit: DIGIT) -> (err: Error) {
+	dest := dest; x := a; digit := digit;
+	assert_initialized(dest); assert_initialized(a);
+
+	/*
+		Fast paths for destination and input Int being the same.
+	*/
+	if dest == a {
+		/*
+			Fast path for `dest` is negative and unsigned addition doesn't overflow the lowest digit.
+		*/
+		if is_neg(dest) && (dest.digit[0] + digit < _DIGIT_MAX) {
+			dest.digit[0] += digit;
+			return .OK;
+		}
+		/*
+			Can be subtracted from dest.digit[0] without underflow.
+		*/
+		if is_pos(a) && (dest.digit[0] > digit) {
+			dest.digit[0] -= digit;
+			return .OK;
+		}
+	}
+
+	/*
+		Grow destination as required.
+	*/
+	err = grow(dest, a.used + 1);
+	if err != .OK {
+		return err;
+	}
+
+	/*
+		If `a` is negative, just do an unsigned addition (with fudged signs).
+	*/
+	if is_neg(a) {
+		t := a;
+		t.sign = .Zero_or_Positive;
+
+		err = add(dest, t, digit);
+		dest.sign = .Negative;
 
-	return .Unimplemented;
+		clamp(dest);
+		return err;
+	}
+
+	old_used := dest.used;
+
+	/*
+		if `a`<= digit, simply fix the single digit.
+	*/
+	if a.used == 1 && (a.digit[0] <= digit || is_zero(a)) {
+		dest.digit[0] = digit - a.digit[0] if a.used == 1 else digit;
+		dest.sign = .Negative;
+		dest.used = 1;
+	} else {
+		dest.sign = .Zero_or_Positive;
+		dest.used = a.used;
+
+		/*
+			Subtract with carry.
+		*/
+		carry := digit;
+
+		for i := 0; i < a.used; i += 1 {
+			dest.digit[i] = a.digit[i] - carry;
+			carry := dest.digit[i] >> ((size_of(DIGIT) * 8) - 1);
+			dest.digit[i] &= _MASK;
+		}
+   	}
+
+	zero_count := old_used - dest.used;
+	/*
+		Zero remainder.
+	*/
+	if zero_count > 0 {
+		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
+	}
+	/*
+		Adjust dest.used based on leading zeroes.
+	*/
+	clamp(dest);
+
+	return .OK;
 }
 
 sub :: proc{sub_two_ints, sub_digit};
@@ -333,7 +414,7 @@ _sub :: proc(dest, number, decrease: ^Int) -> (err: Error) {
 			it will propagate all the way to the MSB.
 			As a result a single shift is enough to get the carry.
 		*/
-		borrow = dest.digit[i] >> (_DIGIT_BITS - 1);
+		borrow = dest.digit[i] >> ((size_of(DIGIT) * 8) - 1);
 		/*
 			Clear borrow from dest[i].
 		*/
@@ -351,7 +432,7 @@ _sub :: proc(dest, number, decrease: ^Int) -> (err: Error) {
 			it will propagate all the way to the MSB.
 			As a result a single shift is enough to get the carry.
 		*/
-		borrow = dest.digit[i] >> (_DIGIT_BITS - 1);
+		borrow = dest.digit[i] >> ((size_of(DIGIT) * 8) - 1);
 		/*
 			Clear borrow from dest[i].
 		*/

+ 2 - 1
core/math/bigint/build.bat

@@ -1,2 +1,3 @@
 @echo off
-odin run . -vet
+odin run .
+rem -vet

+ 3 - 3
core/math/bigint/example.odin

@@ -30,11 +30,11 @@ demo :: proc() {
 	a, b, c: ^Int;
 	err:  Error;
 
-	a, err = init(21);
+	a, err = init(512);
 	defer destroy(a);
 	fmt.printf("a: %v, err: %v\n\n", print_int(a), err);
 
-	b, err = init(21);
+	b, err = init(42);
 	defer destroy(b);
 
 	fmt.printf("b: %v, err: %v\n\n", print_int(b), err);
@@ -44,7 +44,7 @@ demo :: proc() {
 	fmt.printf("c: %v\n", print_int(c, true));
 
 	fmt.println("=== Add ===");
-	err = add(c, a, DIGIT(42));
+	err = sub(a, a, DIGIT(42));
 	// err = add(c, a, b);
 	fmt.printf("Error: %v\n", err);
 	fmt.printf("a: %v\n", print_int(a));