Browse Source

big: Add `choose`.

Jeroen van Rijn 4 years ago
parent
commit
cd0ce7b76e
2 changed files with 43 additions and 26 deletions
  1. 35 0
      core/math/big/basic.odin
  2. 8 26
      core/math/big/example.odin

+ 35 - 0
core/math/big/basic.odin

@@ -773,8 +773,43 @@ int_factorial :: proc(res: ^Int, n: DIGIT) -> (err: Error) {
 }
 }
 factorial :: proc { int_factorial, };
 factorial :: proc { int_factorial, };
 
 
+/*
+	Number of ways to choose `k` items from `n` items.
+	Also known as the binomial coefficient.
+
+	TODO: Speed up.
+
+	Could be done faster by reusing code from factorial and reusing the common "prefix" results for n!, k! and n-k!
+	We know that n >= k, otherwise we early out with res = 0.
+
+	So:
+		n-k, keep result
+		n, start from previous result
+		k, start from previous result
+
+*/
+int_choose_digit :: proc(res: ^Int, n, k: DIGIT) -> (err: Error) {
+	if res == nil  { return .Invalid_Pointer; }
+	if err = clear_if_uninitialized(res); err != .None { return err; }
 
 
+	if k > n { return zero(res); }
 
 
+	/*
+		res = n! / (k! * (n - k)!)
+	*/
+	n_fac, k_fac, n_minus_k_fac := &Int{}, &Int{}, &Int{};
+	defer destroy(n_fac, k_fac, n_minus_k_fac);
+
+	if err = factorial(n_minus_k_fac, n - k);  err != .None { return err; }
+	if err = factorial(k_fac, k);              err != .None { return err; }
+	if err = mul(k_fac, k_fac, n_minus_k_fac); err != .None { return err; }
+
+	if err = factorial(n_fac, n);              err != .None { return err; }
+	if err = div(res, n_fac, k_fac);           err != .None { return err; }
+
+	return err;	
+}
+choose :: proc { int_choose_digit, };
 
 
 /*
 /*
 	==========================
 	==========================

+ 8 - 26
core/math/big/example.odin

@@ -54,7 +54,7 @@ print_timings :: proc() {
 			case avg < time.Millisecond:
 			case avg < time.Millisecond:
 				avg_s = fmt.tprintf("%v µs", time.duration_microseconds(avg));
 				avg_s = fmt.tprintf("%v µs", time.duration_microseconds(avg));
 			case:
 			case:
-				avg_s = fmt.tprintf("%v", time.duration_milliseconds(avg));
+				avg_s = fmt.tprintf("%v ms", time.duration_milliseconds(avg));
 			}
 			}
 
 
 			total_s: string;
 			total_s: string;
@@ -64,7 +64,7 @@ print_timings :: proc() {
 			case v.t < time.Millisecond:
 			case v.t < time.Millisecond:
 				total_s = fmt.tprintf("%v µs", time.duration_microseconds(v.t));
 				total_s = fmt.tprintf("%v µs", time.duration_microseconds(v.t));
 			case:
 			case:
-				total_s = fmt.tprintf("%v", time.duration_milliseconds(v.t));
+				total_s = fmt.tprintf("%v ms", time.duration_milliseconds(v.t));
 			}
 			}
 
 
 			fmt.printf("\t%v: %s (avg), %s (total, %v calls)\n", i, avg_s, total_s, v.c);
 			fmt.printf("\t%v: %s (avg), %s (total, %v calls)\n", i, avg_s, total_s, v.c);
@@ -76,6 +76,7 @@ Category :: enum {
 	itoa,
 	itoa,
 	atoi,
 	atoi,
 	factorial,
 	factorial,
+	choose,
 	lsb,
 	lsb,
 	ctz,
 	ctz,
 };
 };
@@ -114,30 +115,11 @@ demo :: proc() {
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c, d, e, f);
 	defer destroy(a, b, c, d, e, f);
 
 
-	set(a, 125);
-	set(b, 75);
-
-	err = gcd_lcm(c, d, a, b);
-	fmt.printf("gcd_lcm(");
-	print("a =",   a, 10, false, true, false);
-	print(", b =", b, 10, false, true, false);
-	print("), gcd =",   c, 10, false, true, false);
-	print(", lcm =",   d, 10, false, true, false);
-	fmt.printf(" (err = %v)\n", err);
-
-	err = gcd(c, a, b);
-	fmt.printf("gcd(");
-	print("a =",   a, 10, false, true, false);
-	print(", b =", b, 10, false, true, false);
-	print(") =",   c, 10, false, true, false);
-	fmt.printf(" (err = %v)\n", err);
-
-	err = lcm(c, a, b);
-	fmt.printf("lcm(");
-	print("a =",   a, 10, false, true, false);
-	print(", b =", b, 10, false, true, false);
-	print(") =",   c, 10, false, true, false);
-	fmt.printf(" (err = %v)\n", err);
+	s := time.tick_now();
+	err = choose(a, 65535, 255);
+	Timings[.choose].t += time.tick_since(s); Timings[.choose].c += 1;
+	print("choose", a);
+	fmt.println(err);
 }
 }
 
 
 main :: proc() {
 main :: proc() {