Explorar el Código

big: Add `div`.

Jeroen van Rijn hace 4 años
padre
commit
c2255c6c19
Se han modificado 2 ficheros con 79 adiciones y 43 borrados
  1. 63 43
      core/math/big/basic.odin
  2. 16 0
      core/math/big/example.odin

+ 63 - 43
core/math/big/basic.odin

@@ -652,6 +652,24 @@ sqr :: proc(dest, src: ^Int) -> (err: Error) {
 	return mul(dest, src, src);
 }
 
+/*
+	divmod.
+	Both the quotient and remainder are optional and may be passed a nil.
+*/
+int_div :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
+	/*
+		Early out if neither of the results is wanted.
+	*/
+	if quotient == nil && remainder == nil 		        { return .None; }
+
+
+	if err = clear_if_uninitialized(numerator);			err != .None { return err; }
+	if err = clear_if_uninitialized(denominator);		err != .None { return err; }
+
+	return _int_div(quotient, remainder, numerator, denominator);
+}
+div :: proc{ int_div, };
+
 
 /*
 	==========================
@@ -1014,47 +1032,49 @@ _int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: int, err: Error) {
 */
 _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
 
-// 	mp_int ta, tb, tq, q;
-// 	int n;
-// 	bool neg;
-// 	mp_err err;
-
-// 	/* init our temps */
-// 	if ((err = mp_init_multi(&ta, &tb, &tq, &q, NULL)) != MP_OKAY) {
-// 		return err;
-// 	}
-
-// 	mp_set(&tq, 1uL);
-// 	n = mp_count_bits(a) - mp_count_bits(b);
-// 	if ((err = mp_abs(a, &ta)) != MP_OKAY)                         goto LBL_ERR;
-// 	if ((err = mp_abs(b, &tb)) != MP_OKAY)                         goto LBL_ERR;
-// 	if ((err = mp_mul_2d(&tb, n, &tb)) != MP_OKAY)                 goto LBL_ERR;
-// 	if ((err = mp_mul_2d(&tq, n, &tq)) != MP_OKAY)                 goto LBL_ERR;
-
-// 	while (n-- >= 0) {
-// 		if (mp_cmp(&tb, &ta) != MP_GT) {
-// 			if ((err = mp_sub(&ta, &tb, &ta)) != MP_OKAY)            goto LBL_ERR;
-// 			if ((err = mp_add(&q, &tq, &q)) != MP_OKAY)              goto LBL_ERR;
-// 		}
-// 		if ((err = mp_div_2d(&tb, 1, &tb, NULL)) != MP_OKAY)        goto LBL_ERR;
-// 		if ((err = mp_div_2d(&tq, 1, &tq, NULL)) != MP_OKAY)        goto LBL_ERR;
-// 	}
-
-// 	/* now q == quotient and ta == remainder */
-
-// 	neg = (a->sign != b->sign);
-// 	if (c != NULL) {
-// 		mp_exch(c, &q);
-// 		c->sign = ((neg && !mp_iszero(c)) ? MP_NEG : MP_ZPOS);
-// 	}
-// 	if (d != NULL) {
-// 		mp_exch(d, &ta);
-// 		d->sign = (mp_iszero(d) ? MP_ZPOS : a->sign);
-// 	}
-// LBL_ERR:
-// 	mp_clear_multi(&ta, &tb, &tq, &q, NULL);
-// 	return err;
-
-
-	return .None;
+	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
+
+	goto_end: for {
+		if err = one(tq);									err != .None { break goto_end; }
+
+		num_bits, _ := count_bits(numerator);
+		den_bits, _ := count_bits(denominator);
+		n := num_bits - den_bits;
+
+		if err = abs(ta, numerator);						err != .None { break goto_end; }
+		if err = abs(tb, denominator);						err != .None { break goto_end; }
+
+		if err = shl(tb, tb, n);							err != .None { break goto_end; }
+		if err = shl(tq, tq, n);							err != .None { break goto_end; }
+
+		for ; n >= 0; n -= 1 {
+			c: int;
+			if c, err = cmp(tb, ta);						err != .None { break goto_end; }
+			if c != 1 {
+				if err = sub(ta, ta, tb);					err != .None { break goto_end; }
+				if err = add( q, tq,  q);					err != .None { break goto_end; }
+			}
+			if err = shr1(tb, tb);							err != .None { break goto_end; }
+			if err = shr1(tq, tq);							err != .None { break goto_end; }
+		}		
+
+		/*
+			Now q == quotient and ta == remainder.
+		*/
+		neg := numerator.sign != denominator.sign;
+		if quotient != nil {
+			swap(quotient, q);
+			z, _ := is_zero(quotient);
+			quotient.sign = .Negative if neg && !z else .Zero_or_Positive;
+		}
+		if remainder != nil {
+			swap(remainder, ta);
+			z, _ := is_zero(numerator);
+			remainder.sign = .Zero_or_Positive if z else numerator.sign;
+		}
+
+		break goto_end;
+	}
+	destroy(ta, tb, tq, q);
+	return err;
 }

+ 16 - 0
core/math/big/example.odin

@@ -68,6 +68,22 @@ demo :: proc() {
 	print("quotient   ", quotient,    10);
 	fmt.println("remainder  ", i);
 	fmt.println("error", err);
+
+	fmt.println(); fmt.println();
+
+	err = set (numerator,   15625);
+	err = set (denominator,     3);
+	err = zero(quotient);
+
+	print("numerator  ", numerator,   10);
+	print("denominator", denominator, 10);
+
+	err = _int_div_small(quotient, remainder, numerator, denominator);
+
+	print("quotient   ", quotient,    10);
+	print("remainder  ", remainder,   10);
+
+
 }
 
 main :: proc() {