Browse Source

big: Add `_private_int_mul_balance`.

Jeroen van Rijn 4 years ago
parent
commit
737b4fde1c

+ 2 - 2
core/math/big/build.bat

@@ -1,9 +1,9 @@
 @echo off
-odin run . -vet
+:odin run . -vet
 
 set TEST_ARGS=-fast-tests
 :odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
-:odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
 :odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
 :odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
 :odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests %TEST_ARGS%

+ 0 - 9
core/math/big/example.odin

@@ -205,15 +205,6 @@ int_to_byte_little :: proc(v: ^Int) {
 demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
-
-	foo := "92232459121502451677697058974826760244863271517919321608054113675118660929276431348516553336313179167211015633639725554914519355444316239500734169769447134357534241879421978647995614218985202290368055757891124109355450669008628757662409138767505519391883751112010824030579849970582074544353971308266211776494228299586414907715854328360867232691292422194412634523666770452490676515117702116926803826546868467146319938818238521874072436856528051486567230096290549225463582766830777324099589751817442141036031904145041055454639783559905920619197290800070679733841430619962318433709503256637256772215111521321630777950145713049902839937043785039344243357384899099910837463164007565230287809026956254332260375327814271845678201";
-
-	set(a, foo);
-
-	print("a: ", a);
-
-	is_sqr, _ := internal_int_is_square(a);
-	fmt.printf("is_square: %v\n", is_sqr);
 }
 
 main :: proc() {

+ 4 - 7
core/math/big/internal.odin

@@ -659,8 +659,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 			Can we use the balance method? Check sizes.
 			* The smaller one needs to be larger than the Karatsuba cut-off.
 			* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
-			* to make some sense, but it depends on architecture, OS, position of the
-			* stars... so YMMV.
+			* to make some sense, but it depends on architecture, OS, position of the stars... so YMMV.
 			* Using it to cut the input into slices small enough for _mul_comba
 			* was actually slower on the author's machine, but YMMV.
 		*/
@@ -669,13 +668,11 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 		max_used := max(src.used, multiplier.used);
 		digits   := src.used + multiplier.used + 1;
 
-		if        false &&  min_used     >= MUL_KARATSUBA_CUTOFF &&
-							max_used / 2 >= MUL_KARATSUBA_CUTOFF &&
+		if min_used >= MUL_KARATSUBA_CUTOFF && (max_used / 2) >= MUL_KARATSUBA_CUTOFF && max_used >= (2 * min_used) {
 			/*
 				Not much effect was observed below a ratio of 1:2, but again: YMMV.
 			*/
-							max_used     >= 2 * min_used {
-			// err = s_mp_mul_balance(a,b,c);
+			err = _private_int_mul_balance(dest, src, multiplier);
 		} else if min_used >= MUL_TOOM_CUTOFF {
 			/*
 				Toom path commented out until it no longer fails Factorial 10k or 100k,
@@ -914,7 +911,7 @@ internal_int_factorial :: proc(res: ^Int, n: int, allocator := context.allocator
 	context.allocator = allocator;
 
 	if n >= FACTORIAL_BINARY_SPLIT_CUTOFF {
-		return #force_inline _private_int_factorial_binary_split(res, n);
+		return _private_int_factorial_binary_split(res, n);
 	}
 
 	i := len(_factorial_table);

+ 88 - 19
core/math/big/private.odin

@@ -113,7 +113,7 @@ _private_int_mul_toom :: proc(dest, a, b: ^Int, allocator := context.allocator)
 	context.allocator = allocator;
 
 	S1, S2, T1, a0, a1, a2, b0, b1, b2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
-	defer destroy(S1, S2, T1, a0, a1, a2, b0, b1, b2);
+	defer internal_destroy(S1, S2, T1, a0, a1, a2, b0, b1, b2);
 
 	/*
 		Init temps.
@@ -258,7 +258,7 @@ _private_int_mul_karatsuba :: proc(dest, a, b: ^Int, allocator := context.alloca
 	context.allocator = allocator;
 
 	x0, x1, y0, y1, t1, x0y0, x1y1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
-	defer destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
+	defer internal_destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
 
 	/*
 		min # of digits, divided by two.
@@ -546,8 +546,74 @@ _private_int_mul_high_comba :: proc(dest, a, b: ^Int, digits: int, allocator :=
 	return internal_clamp(dest);
 }
 
+/*
+	Single-digit multiplication with the smaller number as the single-digit.
+*/
+_private_int_mul_balance :: proc(dest, a, b: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+	a, b := a, b;
 
+	a0, tmp, r := &Int{}, &Int{}, &Int{};
+	defer internal_destroy(a0, tmp, r);
 
+	b_size   := min(a.used, b.used);
+	n_blocks := max(a.used, b.used) / b_size;
+
+	internal_grow(a0, b_size + 2) or_return;
+	internal_init_multi(tmp, r)   or_return;
+
+	/*
+		Make sure that `a` is the larger one.
+	*/
+	if a.used < b.used {
+		a, b = b, a;
+	}
+	assert(a.used >= b.used);
+
+	i, j := 0, 0;
+	for ; i < n_blocks; i += 1 {
+		/*
+			Cut a slice off of `a`.
+		*/
+
+		a0.used = b_size;
+		internal_copy_digits(a0, a, a0.used, j);
+		j += a0.used;
+		internal_clamp(a0);
+
+		/*
+			Multiply with `b`.
+		*/
+		internal_mul(tmp, a0, b)                                     or_return;
+
+		/*
+			Shift `tmp` to the correct position.
+		*/
+		internal_shl_digit(tmp, b_size * i)                          or_return;
+
+		/*
+			Add to output. No carry needed.
+		*/
+		internal_add(r, r, tmp)                                      or_return;
+	}
+
+	/*
+		The left-overs; there are always left-overs.
+	*/
+	if j < a.used {
+		a0.used = a.used - j;
+		internal_copy_digits(a0, a, a0.used, j);
+		j += a0.used;
+		internal_clamp(a0);
+
+		internal_mul(tmp, a0, b)                                     or_return;
+		internal_shl_digit(tmp, b_size * i)                          or_return;
+		internal_add(r, r, tmp)                                      or_return;
+	}
+
+	internal_swap(dest, r);
+	return;
+}
 
 /*
 	Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
@@ -1311,7 +1377,7 @@ _private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int
 
 	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
 	c: int;
-	defer destroy(ta, tb, tq, q);
+	defer internal_destroy(ta, tb, tq, q);
 
 	for {
 		internal_one(tq) or_return;
@@ -1364,31 +1430,34 @@ _private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int
 	Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html
 */
 _private_int_factorial_binary_split :: proc(res: ^Int, n: int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
 
 	inner, outer, start, stop, temp := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer internal_destroy(inner, outer, start, stop, temp);
 
-	internal_one(inner, false, allocator) or_return;
-	internal_one(outer, false, allocator) or_return;
+	internal_one(inner, false)                                       or_return;
+	internal_one(outer, false)                                       or_return;
 
 	bits_used := int(_DIGIT_TYPE_BITS - intrinsics.count_leading_zeros(n));
 
 	for i := bits_used; i >= 0; i -= 1 {
 		start := (n >> (uint(i) + 1)) + 1 | 1;
 		stop  := (n >> uint(i)) + 1 | 1;
-		_private_int_recursive_product(temp, start, stop, 0, allocator) or_return;
-		internal_mul(inner, inner, temp, allocator) or_return;
-		internal_mul(outer, outer, inner, allocator) or_return;
+		_private_int_recursive_product(temp, start, stop, 0)         or_return;
+		internal_mul(inner, inner, temp)                             or_return;
+		internal_mul(outer, outer, inner)                            or_return;
 	}
 	shift := n - intrinsics.count_ones(n);
 
-	return internal_shl(res, outer, int(shift), allocator);
+	return internal_shl(res, outer, int(shift));
 }
 
 /*
 	Recursive product used by binary split factorial algorithm.
 */
 _private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int(0), allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
 	t1, t2 := &Int{}, &Int{};
 	defer internal_destroy(t1, t2);
 
@@ -1398,28 +1467,28 @@ _private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int
 
 	num_factors := (stop - start) >> 1;
 	if num_factors == 2 {
-		internal_set(t1, start, false, allocator) or_return;
+		internal_set(t1, start, false)                               or_return;
 		when true {
-			internal_grow(t2, t1.used + 1, false, allocator) or_return;
-			internal_add(t2, t1, 2, allocator) or_return;
+			internal_grow(t2, t1.used + 1, false)                    or_return;
+			internal_add(t2, t1, 2)                                  or_return;
 		} else {
-			add(t2, t1, 2) or_return;
+			internal_add(t2, t1, 2)                                  or_return;
 		}
-		return internal_mul(res, t1, t2, allocator);
+		return internal_mul(res, t1, t2);
 	}
 
 	if num_factors > 1 {
 		mid := (start + num_factors) | 1;
-		_private_int_recursive_product(t1, start,  mid, level + 1, allocator) or_return;
-		_private_int_recursive_product(t2,   mid, stop, level + 1, allocator) or_return;
-		return internal_mul(res, t1, t2, allocator);
+		_private_int_recursive_product(t1, start,  mid, level + 1)   or_return;
+		_private_int_recursive_product(t2,   mid, stop, level + 1)   or_return;
+		return internal_mul(res, t1, t2);
 	}
 
 	if num_factors == 1 {
-		return #force_inline internal_set(res, start, true, allocator);
+		return #force_inline internal_set(res, start, true);
 	}
 
-	return #force_inline internal_one(res, true, allocator);
+	return #force_inline internal_one(res, true);
 }
 
 /*

+ 12 - 5
core/math/big/test.py

@@ -403,14 +403,21 @@ def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
 		
 	return test("test_shr_signed", res, [a, bits], expected_error, expected_result)
 
-def test_factorial(n = 0, expected_error = Error.Okay):
-	args  = [n]
-	res   = int_factorial(*args)
+def test_factorial(number = 0, expected_error = Error.Okay):
+	print("Factorial:", number)
+	args  = [number]
+	try:
+		res = int_factorial(*args)
+	except OSError as e:
+		print("{} while trying to factorial {}.".format(e, number))
+		if EXIT_ON_FAIL: exit(3)
+		return False
+
 	expected_result = None
 	if expected_error == Error.Okay:
-		expected_result = math.factorial(n)
+		expected_result = math.factorial(number)
 		
-	return test("test_factorial", res, [n], expected_error, expected_result)
+	return test("test_factorial", res, [number], expected_error, expected_result)
 
 def test_gcd(a = 0, b = 0, expected_error = Error.Okay):
 	args  = [arg_to_odin(a), arg_to_odin(b)]