Bladeren bron

big: Squashed shl1 bug when a larger dest was reused for a smaller result.

Jeroen van Rijn 4 jaren geleden
bovenliggende
commit
4be48973ad

+ 10 - 39
core/math/big/basic.odin

@@ -152,45 +152,18 @@ int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: E
 	/*
 		Early out if neither of the results is wanted.
 	*/
-	if quotient == nil && remainder == nil 		        { return nil; }
+	if quotient == nil && remainder == nil { return nil; }
+	if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
 
-	if err = clear_if_uninitialized(numerator);			err != nil { return err; }
-	if err = clear_if_uninitialized(denominator);		err != nil { return err; }
-
-	z: bool;
-	if z, err = is_zero(denominator);                   z { return .Division_by_Zero; }
-
-	/*
-		If numerator < denominator then quotient = 0, remainder = numerator.
-	*/
-	c: int;
-	if c, err = cmp_mag(numerator, denominator); c == -1 {
-		if remainder != nil {
-			if err = copy(remainder, numerator); 		err != nil { return err; }
-		}
-		if quotient != nil {
-			zero(quotient);
-		}
-		return nil;
-	}
-
-	if false && (denominator.used > 2 * _MUL_KARATSUBA_CUTOFF) && (denominator.used <= (numerator.used/3) * 2) {
-		// err = _int_div_recursive(quotient, remainder, numerator, denominator);
-	} else {
-		err = _int_div_school(quotient, remainder, numerator, denominator);
-		/*
-			NOTE(Jeroen): We no longer need or use `_int_div_small`.
-			We'll keep it around for a bit.
-			err = _int_div_small(quotient, remainder, numerator, denominator);
-		*/
-	}
-
-	return err;
+	return #force_inline internal_int_divmod(quotient, remainder, numerator, denominator);
 }
 divmod :: proc{ int_divmod, };
 
 int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) {
-	return int_divmod(quotient, nil, numerator, denominator);
+	if quotient == nil { return .Invalid_Pointer; };
+	if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
+
+	return #force_inline internal_int_divmod(quotient, nil, numerator, denominator);
 }
 div :: proc { int_div, };
 
@@ -200,11 +173,10 @@ div :: proc { int_div, };
 	denominator < remainder <= 0 if denominator < 0
 */
 int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) {
-	if err = divmod(nil, remainder, numerator, denominator); err != nil { return err; }
+	if remainder == nil { return .Invalid_Pointer; };
+	if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
 
-	z: bool;
-	if z, err = is_zero(remainder); z || denominator.sign == remainder.sign { return nil; }
-	return add(remainder, remainder, numerator);
+	return #force_inline internal_int_mod(remainder, numerator, denominator);
 }
 
 int_mod_digit :: proc(numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) {
@@ -776,7 +748,6 @@ _int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (e
 			t2.used = 3;
 
 			if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 {
-
 				break;
 			}
 			iter += 1; if iter > 100 { return .Max_Iterations_Reached; }

+ 1 - 1
core/math/big/build.bat

@@ -1,5 +1,5 @@
 @echo off
-:odin run . -vet-more
+:odin run . -vet
 : -o:size -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:size -no-bounds-check

+ 5 - 19
core/math/big/example.odin

@@ -62,30 +62,16 @@ print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline
 }
 
 demo :: proc() {
-	err: Error;
-	as: string;
-	defer delete(as);
-
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 
-	err = factorial(a, 1224);
-	count, _ := count_bits(a);
+	foo := "686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214";
+	err := atoi(a, foo, 10);
 
-	bits :=  51;
-	be1: _WORD;
+	if err != nil do fmt.printf("atoi returned %v\n", err);
+
+	print("foo: ", a);
 
-	/*
-		Timing loop
-	*/
-	{
-		SCOPED_TIMING(.bitfield_extract);
-		for o := 0; o < count - bits; o += 1 {
-			be1, _ = int_bitfield_extract(a, o, bits);
-		}
-	}
-	SCOPED_COUNT_ADD(.bitfield_extract, count - bits - 1);
-	fmt.printf("be1: %v\n", be1);
 }
 
 main :: proc() {

+ 1 - 1
core/math/big/exp_log.odin

@@ -253,7 +253,7 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
 		swap(dest, x);
 		return err;
 	} else {
-		// return root_n(dest, src, 2);
+		return root_n(dest, src, 2);
 	}
 }
 sqrt :: proc { int_sqrt, };

+ 85 - 39
core/math/big/internal.odin

@@ -491,49 +491,30 @@ internal_int_shr1 :: proc(dest, src: ^Int) -> (err: Error) {
 	dest = src << 1
 */
 internal_int_shl1 :: proc(dest, src: ^Int) -> (err: Error) {
-	old_used  := dest.used; dest.used  = src.used + 1;
-
+	if err = copy(dest, src); err != nil { return err; }
 	/*
-		Forward carry
+		Grow `dest` to accommodate the additional bits.
 	*/
-	carry := DIGIT(0);
-	#no_bounds_check for x := 0; x < src.used; x += 1 {
-		/*
-			Get what will be the *next* carry bit from the MSB of the current digit.
-		*/
-		src_digit := src.digit[x];
-		fwd_carry := src_digit >> (_DIGIT_BITS - 1);
+	digits_needed := dest.used + 1;
+	if err = grow(dest, digits_needed); err != nil { return err; }
+	dest.used = digits_needed;
 
-		/*
-			Now shift up this digit, add in the carry [from the previous]
-		*/
-		dest.digit[x] = (src_digit << 1 | carry) & _MASK;
+	mask  := (DIGIT(1) << uint(1)) - DIGIT(1);
+	shift := DIGIT(_DIGIT_BITS - 1);
+	carry := DIGIT(0);
 
-		/*
-			Update carry
-		*/
+	#no_bounds_check for x:= 0; x < dest.used; x+= 1 {		
+		fwd_carry := (dest.digit[x] >> shift) & mask;
+		dest.digit[x] = (dest.digit[x] << uint(1) | carry) & _MASK;
 		carry = fwd_carry;
 	}
 	/*
-		New leading digit?
+		Use final carry.
 	*/
 	if carry != 0 {
-		/*
-			Add a MSB which is always 1 at this point.
-		*/
-		dest.digit[dest.used] = 1;
+		dest.digit[dest.used] = carry;
+		dest.used += 1;
 	}
-	zero_count := old_used - dest.used;
-	/*
-		Zero remainder.
-	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
-	/*
-		Adjust dest.used based on leading zeroes.
-	*/
-	dest.sign = src.sign;
 	return clamp(dest);
 }
 
@@ -552,7 +533,7 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
 		Power of two?
 	*/
 	if multiplier == 2 {
-		return #force_inline shl1(dest, src);
+		return #force_inline internal_int_shl1(dest, src);
 	}
 	if is_power_of_two(int(multiplier)) {
 		ix: int;
@@ -581,7 +562,7 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
 		Compute columns.
 	*/
 	ix := 0;
-	#no_bounds_check for ; ix < src.used; ix += 1 {
+	for ; ix < src.used; ix += 1 {
 		/*
 			Compute product and carry sum for this term
 		*/
@@ -600,13 +581,15 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
 		Store final carry [if any] and increment used.
 	*/
 	dest.digit[ix] = DIGIT(carry);
+
 	dest.used = src.used + 1;
 	/*
 		Zero unused digits.
 	*/
+	//_zero_unused(dest);
 	zero_count := old_used - dest.used;
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[zero_count:]);
+	if zero_count > 0  {
+	 	mem.zero_slice(dest.digit[dest.used:][:zero_count]);
 	}
 	return clamp(dest);
 }
@@ -675,9 +658,72 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 			err = _int_mul(dest, src, multiplier, digits);
 		}
 	}
-	neg      := src.sign != multiplier.sign;
+	neg := src.sign != multiplier.sign;
 	dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
 	return err;
 }
 
-internal_mul :: proc { internal_int_mul, internal_int_mul_digit, };
+internal_mul :: proc { internal_int_mul, internal_int_mul_digit, };
+
+/*
+	divmod.
+	Both the quotient and remainder are optional and may be passed a nil.
+*/
+internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, allocator := context.allocator) -> (err: Error) {
+
+	if denominator.used == 0 { return .Division_by_Zero; }
+	/*
+		If numerator < denominator then quotient = 0, remainder = numerator.
+	*/
+	c: int;
+	if c, err = #force_inline cmp_mag(numerator, denominator); c == -1 {
+		if remainder != nil {
+			if err = copy(remainder, numerator, false, allocator); err != nil { return err; }
+		}
+		if quotient != nil {
+			zero(quotient);
+		}
+		return nil;
+	}
+
+	if false && (denominator.used > 2 * _MUL_KARATSUBA_CUTOFF) && (denominator.used <= (numerator.used/3) * 2) {
+		// err = _int_div_recursive(quotient, remainder, numerator, denominator);
+	} else {
+		when true {
+			err = _int_div_school(quotient, remainder, numerator, denominator);
+		} else {
+			/*
+				NOTE(Jeroen): We no longer need or use `_int_div_small`.
+				We'll keep it around for a bit until we're reasonably certain div_school is bug free.
+				err = _int_div_small(quotient, remainder, numerator, denominator);
+			*/
+			err = _int_div_small(quotient, remainder, numerator, denominator);
+		}
+	}
+	return;
+}
+internal_divmod :: proc { internal_int_divmod, };
+
+/*
+	Asssumes quotient, numerator and denominator to have been initialized and not to be nil.
+*/
+internal_int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) {
+	return #force_inline internal_int_divmod(quotient, nil, numerator, denominator);
+}
+internal_div :: proc { internal_int_div, };
+
+/*
+	remainder = numerator % denominator.
+	0 <= remainder < denominator if denominator > 0
+	denominator < remainder <= 0 if denominator < 0
+
+	Asssumes quotient, numerator and denominator to have been initialized and not to be nil.
+*/
+internal_int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) {
+	if err = #force_inline internal_int_divmod(nil, remainder, numerator, denominator); err != nil { return err; }
+
+	if remainder.used == 0 || denominator.sign == remainder.sign { return nil; }
+
+	return #force_inline internal_add(remainder, remainder, numerator);
+}
+internal_mod :: proc{ internal_int_mod, };

+ 2 - 2
core/math/big/test.odin

@@ -24,9 +24,9 @@ PyRes :: struct {
 	err: Error,
 }
 
-@export test_initialize_constants :: proc "c" () -> (res: int) {
+@export test_initialize_constants :: proc "c" () -> (res: u64) {
 	context = runtime.default_context();
-	return initialize_constants();
+	return u64(initialize_constants());
 }
 
 @export test_error_string :: proc "c" (err: Error) -> (res: cstring) {

+ 9 - 1
core/math/big/test.py

@@ -254,7 +254,13 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
 
 def test_sqrt(number = 0, expected_error = Error.Okay):
 	args  = [arg_to_odin(number)]
-	res   = int_sqrt(*args)
+	try:
+		res = int_sqrt(*args)
+	except OSError as e:
+		print("{} while trying to sqrt {}.".format(e, number))
+		if EXIT_ON_FAIL: exit(3)
+		return False
+
 	expected_result = None
 	if expected_error == Error.Okay:
 		if number < 0:
@@ -384,6 +390,7 @@ TESTS = {
 		[ 54321,	12345],
 		[ 55431,		0, Error.Division_by_Zero],
 		[ 12980742146337069150589594264770969721, 4611686018427387904 ],
+		[   831956404029821402159719858789932422, 243087903122332132 ],
 	],
 	test_log: [
 		[ 3192,			1, Error.Invalid_Argument],
@@ -405,6 +412,7 @@ TESTS = {
 		[  42, Error.Okay, ],
 		[  12345678901234567890, Error.Okay, ],
 		[  1298074214633706907132624082305024, Error.Okay, ],
+		[  686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214, Error.Okay, ],
 	],
 	test_root_n: [
 		[  1298074214633706907132624082305024, 2, Error.Okay, ],