Browse Source

big: Improved `zero_unused` helper.

Jeroen van Rijn 4 years ago
parent
commit
9890e7cfeb

+ 27 - 100
core/math/big/basic.odin

@@ -155,17 +155,33 @@ int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: E
 	if quotient == nil && remainder == nil { return nil; }
 	if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
 
-	return #force_inline internal_int_divmod(quotient, remainder, numerator, denominator);
+	return #force_inline internal_divmod(quotient, remainder, numerator, denominator);
 }
-divmod :: proc{ int_divmod, };
+
+int_divmod_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) {
+	if quotient == nil { return 0, .Invalid_Pointer; };
+	if err = clear_if_uninitialized(numerator); err != nil { return 0, err; }
+
+	return #force_inline internal_divmod(quotient, numerator, denominator);
+}
+divmod :: proc{ int_divmod, int_divmod_digit, };
 
 int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) {
 	if quotient == nil { return .Invalid_Pointer; };
 	if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
 
-	return #force_inline internal_int_divmod(quotient, nil, numerator, denominator);
+	return #force_inline internal_divmod(quotient, nil, numerator, denominator);
+}
+
+int_div_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT) -> (err: Error) {
+	if quotient == nil { return .Invalid_Pointer; };
+	if err = clear_if_uninitialized(numerator); err != nil { return err; }
+
+	remainder: DIGIT;
+	remainder, err = #force_inline internal_divmod(quotient, numerator, denominator);
+	return err;
 }
-div :: proc { int_div, };
+div :: proc { int_div, int_div_digit, };
 
 /*
 	remainder = numerator % denominator.
@@ -180,7 +196,7 @@ int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) {
 }
 
 int_mod_digit :: proc(numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) {
-	return _int_div_digit(nil, numerator, denominator);
+	return #force_inline internal_divmod(nil, numerator, denominator);
 }
 
 mod :: proc { int_mod, int_mod_digit, };
@@ -490,13 +506,8 @@ _int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 	/*
 		Clear unused digits [that existed in the old copy of dest].
 	*/
-	zero_count := old_used - dest.used;
-	/*
-		Zero remainder.
-	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	zero_unused(dest, old_used);
+
 	/*
 		Adjust dest.used based on leading zeroes.
 	*/
@@ -848,92 +859,6 @@ _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (er
 	return err;
 }
 
-/*
-	Single digit division (based on routine from MPI).
-*/
-_int_div_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) {
-	q := &Int{};
-	ix: int;
-
-	/*
-		Cannot divide by zero.
-	*/
-	if denominator == 0 {
-		return 0, .Division_by_Zero;
-	}
-
-	/*
-		Quick outs.
-	*/
-	if denominator == 1 || numerator.used == 0 {
-		err = nil;
-		if quotient != nil {
-			err = copy(quotient, numerator);
-		}
-		return 0, err;
-	}
-	/*
-		Power of two?
-	*/
-	if denominator == 2 {
-		if odd, _ := is_odd(numerator); odd {
-			remainder = 1;
-		}
-		if quotient == nil {
-			return remainder, nil;
-		}
-		return remainder, shr(quotient, numerator, 1);
-	}
-
-	if is_power_of_two(int(denominator)) {
-		ix = 1;
-		for ix < _DIGIT_BITS && denominator != (1 << uint(ix)) {
-			ix += 1;
-		}
-		remainder = numerator.digit[0] & ((1 << uint(ix)) - 1);
-		if quotient == nil {
-			return remainder, nil;
-		}
-
-		return remainder, shr(quotient, numerator, int(ix));
-	}
-
-	/*
-		Three?
-	*/
-	if denominator == 3 {
-		return _int_div_3(quotient, numerator);
-	}
-
-	/*
-		No easy answer [c'est la vie].  Just division.
-	*/
-	if err = grow(q, numerator.used); err != nil { return 0, err; }
-
-	q.used = numerator.used;
-	q.sign = numerator.sign;
-
-	w := _WORD(0);
-
-	for ix = numerator.used - 1; ix >= 0; ix -= 1 {
-		t := DIGIT(0);
-		w = (w << _WORD(_DIGIT_BITS) | _WORD(numerator.digit[ix]));
-		if w >= _WORD(denominator) {
-			t = DIGIT(w / _WORD(denominator));
-			w -= _WORD(t) * _WORD(denominator);
-		}
-		q.digit[ix] = t;
-	}
-	remainder = DIGIT(w);
-
-	if quotient != nil {
-		clamp(q);
-		swap(q, quotient);
-	}
-	destroy(q);
-	return remainder, nil;
-}
-
 /*
 	Function computing both GCD and (if target isn't `nil`) also LCM.
 */
@@ -1176,9 +1101,11 @@ int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
 	/*
 		Zero digits above the last digit of the modulus.
 	*/
-	zero_count := (bits / _DIGIT_BITS) + 0 if (bits % _DIGIT_BITS == 0) else 1;
+	zero_count := (bits / _DIGIT_BITS);
+	zero_count += 0 if (bits % _DIGIT_BITS == 0) else 1;
+
 	/*
-		Zero remainder.
+		Zero remainder. Special case, can't use `zero_unused`.
 	*/
 	if zero_count > 0 {
 		mem.zero_slice(remainder.digit[zero_count:]);

+ 1 - 1
core/math/big/build.bat

@@ -4,7 +4,7 @@
 :odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:size -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:size
-odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check
+:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check
 :odin build . -build-mode:shared -show-timings -o:speed
 
 python test.py

+ 13 - 7
core/math/big/example.odin

@@ -65,13 +65,19 @@ demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 
-	foo := "686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214";
-	err := atoi(a, foo, 10);
-
-	if err != nil do fmt.printf("atoi returned %v\n", err);
-
-	print("foo: ", a);
-
+	N :: 12345;
+	D :: 4;
+
+	set(a, N);
+	print("a: ", a);
+	div(b, a, D);
+	rem, _ := mod(a, D);
+	print("b: ", b);
+	fmt.printf("rem: %v\n", rem);
+
+	mul(b, b, D);
+	add(b, b, rem);
+	print("b: ", b);
 }
 
 main :: proc() {

+ 6 - 11
core/math/big/helpers.odin

@@ -56,7 +56,7 @@ int_set_from_integer :: proc(dest: ^Int, src: $T, minimize := false, allocator :
 		dest.used += 1;
 		src >>= _DIGIT_BITS;
 	}
-	_zero_unused(dest);
+	zero_unused(dest);
 	return nil;
 }
 
@@ -94,7 +94,7 @@ int_copy :: proc(dest, src: ^Int, minimize := false, allocator := context.alloca
 	dest.sign  = src.sign;
 	dest.flags = src.flags &~ {.Immutable};
 
-	_zero_unused(dest);
+	zero_unused(dest);
 	return nil;
 }
 copy :: proc { int_copy, };
@@ -583,16 +583,11 @@ assert_initialized :: proc(a: ^Int, loc := #caller_location) {
 	assert(is_initialized(a), "`Int` was not properly initialized.", loc);
 }
 
-_zero_unused :: proc(a: ^Int) {
-	if a == nil {
-		return;
-	} else if !is_initialized(a) {
-		return;
-	}
+zero_unused :: proc(dest: ^Int, old_used := -1) {
+	if dest == nil { return; }
+	if ! #force_inline is_initialized(dest) { return; }
 
-	if a.used < len(a.digit) {
-		mem.zero_slice(a.digit[a.used:]);
-	}
+	internal_zero_unused(dest, old_used);
 }
 
 clear_if_uninitialized_single :: proc(arg: ^Int) -> (err: Error) {

+ 127 - 29
core/math/big/internal.odin

@@ -96,13 +96,11 @@ internal_int_add_unsigned :: proc(dest, a, b: ^Int, allocator := context.allocat
 		Add remaining carry.
 	*/
 	dest.digit[i] = carry;
-	zero_count := old_used - dest.used;
+
 	/*
 		Zero remainder.
 	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	internal_zero_unused(dest, old_used);
 	/*
 		Adjust dest.used based on leading zeroes.
 	*/
@@ -237,13 +235,11 @@ internal_int_add_digit :: proc(dest, a: ^Int, digit: DIGIT) -> (err: Error) {
 	*/
 	dest.sign = .Zero_or_Positive;
 
-	zero_count := old_used - dest.used;
 	/*
 		Zero remainder.
 	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	internal_zero_unused(dest, old_used);
+
 	/*
 		Adjust dest.used based on leading zeroes.
 	*/
@@ -307,13 +303,11 @@ internal_int_sub_unsigned :: proc(dest, number, decrease: ^Int, allocator := con
 		dest.digit[i] &= _MASK;
 	}
 
-	zero_count := old_used - dest.used;
 	/*
 		Zero remainder.
 	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	internal_zero_unused(dest, old_used);
+
 	/*
 		Adjust dest.used based on leading zeroes.
 	*/
@@ -430,13 +424,11 @@ internal_int_sub_digit :: proc(dest, number: ^Int, digit: DIGIT) -> (err: Error)
 		}
 	}
 
-	zero_count := old_used - dest.used;
 	/*
 		Zero remainder.
 	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	internal_zero_unused(dest, old_used);
+
 	/*
 		Adjust dest.used based on leading zeroes.
 	*/
@@ -472,13 +464,11 @@ internal_int_shr1 :: proc(dest, src: ^Int) -> (err: Error) {
 		fwd_carry = carry;
 	}
 
-	zero_count := old_used - dest.used;
 	/*
 		Zero remainder.
 	*/
-	if zero_count > 0 {
-		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	internal_zero_unused(dest, old_used);
+
 	/*
 		Adjust dest.used based on leading zeroes.
 	*/
@@ -522,6 +512,8 @@ internal_int_shl1 :: proc(dest, src: ^Int) -> (err: Error) {
 	Multiply by a DIGIT.
 */
 internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator := context.allocator) -> (err: Error) {
+	assert(dest != nil && src != nil);
+
 	if multiplier == 0 {
 		return zero(dest);
 	}
@@ -581,16 +573,13 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
 		Store final carry [if any] and increment used.
 	*/
 	dest.digit[ix] = DIGIT(carry);
-
 	dest.used = src.used + 1;
+
 	/*
-		Zero unused digits.
+		Zero remainder.
 	*/
-	//_zero_unused(dest);
-	zero_count := old_used - dest.used;
-	if zero_count > 0  {
-	 	mem.zero_slice(dest.digit[dest.used:][:zero_count]);
-	}
+	internal_zero_unused(dest, old_used);
+
 	return clamp(dest);
 }
 
@@ -702,7 +691,93 @@ internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, a
 	}
 	return;
 }
-internal_divmod :: proc { internal_int_divmod, };
+
+/*
+	Single digit division (based on routine from MPI).
+	The quotient is optional and may be passed a nil.
+*/
+internal_int_divmod_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) {
+	/*
+		Cannot divide by zero.
+	*/
+	if denominator == 0 { return 0, .Division_by_Zero; }
+
+	/*
+		Quick outs.
+	*/
+	if denominator == 1 || numerator.used == 0 {
+		if quotient != nil {
+			return 0, copy(quotient, numerator);
+		}
+		return 0, err;
+	}
+	/*
+		Power of two?
+	*/
+	if denominator == 2 {
+		if numerator.used > 0 && numerator.digit[0] & 1 != 0 {
+			// Remainder is 1 if numerator is odd.
+			remainder = 1;
+		}
+		if quotient == nil {
+			return remainder, nil;
+		}
+		return remainder, shr(quotient, numerator, 1);
+	}
+
+	ix: int;
+	if is_power_of_two(int(denominator)) {
+		ix = 1;
+		for ix < _DIGIT_BITS && denominator != (1 << uint(ix)) {
+			ix += 1;
+		}
+		remainder = numerator.digit[0] & ((1 << uint(ix)) - 1);
+		if quotient == nil {
+			return remainder, nil;
+		}
+
+		return remainder, shr(quotient, numerator, int(ix));
+	}
+
+	/*
+		Three?
+	*/
+	if denominator == 3 {
+		return _int_div_3(quotient, numerator);
+	}
+
+	/*
+		No easy answer [c'est la vie].  Just division.
+	*/
+	q := &Int{};
+
+	if err = grow(q, numerator.used); err != nil { return 0, err; }
+
+	q.used = numerator.used;
+	q.sign = numerator.sign;
+
+	w := _WORD(0);
+
+	for ix = numerator.used - 1; ix >= 0; ix -= 1 {
+		t := DIGIT(0);
+		w = (w << _WORD(_DIGIT_BITS) | _WORD(numerator.digit[ix]));
+		if w >= _WORD(denominator) {
+			t = DIGIT(w / _WORD(denominator));
+			w -= _WORD(t) * _WORD(denominator);
+		}
+		q.digit[ix] = t;
+	}
+	remainder = DIGIT(w);
+
+	if quotient != nil {
+		clamp(q);
+		swap(q, quotient);
+	}
+	destroy(q);
+	return remainder, nil;
+}
+
+internal_divmod :: proc { internal_int_divmod, internal_int_divmod_digit, };
 
 /*
 	Asssumes quotient, numerator and denominator to have been initialized and not to be nil.
@@ -726,4 +801,27 @@ internal_int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error
 
 	return #force_inline internal_add(remainder, remainder, numerator);
 }
-internal_mod :: proc{ internal_int_mod, };
+internal_mod :: proc{ internal_int_mod, };
+
+
+
+internal_int_zero_unused :: #force_inline proc(dest: ^Int, old_used := -1) {
+	/*
+		If we don't pass the number of previously used DIGITs, we zero all remaining ones.
+	*/
+	zero_count: int;
+	if old_used == -1 {
+		zero_count = len(dest.digit) - dest.used;
+	} else {
+		zero_count = old_used - dest.used;
+	}
+
+	/*
+		Zero remainder.
+	*/
+	if zero_count > 0 && dest.used < len(dest.digit) {
+		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
+	}
+}
+
+internal_zero_unused :: proc { internal_int_zero_unused, };

+ 1 - 1
core/math/big/logical.odin

@@ -323,7 +323,7 @@ int_shr_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
     	quotient.digit[x] = quotient.digit[x + digits];
 	}
 	quotient.used -= digits;
-	_zero_unused(quotient);
+	zero_unused(quotient);
 	return clamp(quotient);
 }
 shr_digit :: proc { int_shr_digit, };

+ 1 - 1
core/math/big/radix.odin

@@ -452,7 +452,7 @@ _itoa_raw_full :: proc(a: ^Int, radix: i8, buffer: []u8, zero_terminate := false
 
 	remainder: DIGIT;
 	for {
-		if remainder, err = _int_div_digit(temp, temp, DIGIT(radix)); err != nil {
+		if remainder, err = #force_inline internal_divmod(temp, temp, DIGIT(radix)); err != nil {
 			destroy(temp, denominator);
 			return len(buffer) - available, err;
 		}

+ 7 - 1
core/math/big/tune.odin

@@ -43,7 +43,13 @@ print_timings :: proc() {
 		}
 	}
 
-	fmt.println("\nTimings:");
+	for v in Timings {
+		if v.count > 0 {
+			fmt.println("Timings:");
+			break;
+		}
+	}
+
 	for v, i in Timings {
 		if v.count > 0 {
 			avg_ticks  := time.Duration(f64(v.ticks) / f64(v.count));