Browse Source

Fix i128 division

Ginger Bill 8 years ago
parent
commit
0fff6a2b74
4 changed files with 130 additions and 118 deletions
  1. 2 2
      core/_soft_numbers.odin
  2. 11 11
      src/check_expr.cpp
  3. 115 102
      src/integer128.cpp
  4. 2 3
      src/main.cpp

+ 2 - 2
core/_soft_numbers.odin

@@ -74,8 +74,8 @@ __i128_quo_mod :: proc(a, b: i128, rem: ^i128) -> (quo: i128) #cc_c #link_name "
 
 
 __u128_quo_mod :: proc(a, b: u128, rem: ^u128) -> (quo: u128) #cc_c #link_name "__udivmodti4" {
-	alo, ahi := u64(a), u64(a>>64);
-	blo, bhi := u64(b), u64(b>>64);
+	alo := u64(a);
+	blo := u64(b);
 	if b == 0 {
 		if rem != nil do rem^ = 0;
 		return u128(alo/blo);

+ 11 - 11
src/check_expr.cpp

@@ -1329,7 +1329,7 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node, Array<Opera
 
 
 }
-void check_union_type(Checker *c, Type *named_type, Type *union_type, AstNode *node) {
+void check_union_type(Checker *c, Type *union_type, AstNode *node) {
 	GB_ASSERT(is_type_union(union_type));
 	ast_node(ut, UnionType, node);
 
@@ -1577,7 +1577,7 @@ void check_enum_type(Checker *c, Type *enum_type, Type *named_type, AstNode *nod
 }
 
 
-void check_bit_field_type(Checker *c, Type *bit_field_type, Type *named_type, AstNode *node) {
+void check_bit_field_type(Checker *c, Type *bit_field_type, AstNode *node) {
 	ast_node(bft, BitFieldType, node);
 	GB_ASSERT(is_type_bit_field(bit_field_type));
 
@@ -3069,7 +3069,7 @@ bool check_type_internal(Checker *c, AstNode *e, Type **type, Type *named_type)
 		*type = make_type_union(c->allocator);
 		set_base_type(named_type, *type);
 		check_open_scope(c, e);
-		check_union_type(c, named_type, *type, e);
+		check_union_type(c, *type, e);
 		check_close_scope(c);
 		(*type)->Union.node = e;
 		return true;
@@ -3089,7 +3089,7 @@ bool check_type_internal(Checker *c, AstNode *e, Type **type, Type *named_type)
 		*type = make_type_bit_field(c->allocator);
 		set_base_type(named_type, *type);
 		check_open_scope(c, e);
-		check_bit_field_type(c, *type, named_type, e);
+		check_bit_field_type(c, *type, e);
 		check_close_scope(c);
 		return true;
 	case_end;
@@ -3150,13 +3150,13 @@ Type *check_type(Checker *c, AstNode *e, Type *named_type) {
 		type = t_invalid;
 	}
 
-	if (type->kind == Type_Named) {
-		if (type->Named.base == nullptr) {
-			gbString name = type_to_string(type);
-			error(e, "Invalid type definition of %s", name);
-			gb_string_free(name);
-			type->Named.base = t_invalid;
-		}
+	if (type->kind == Type_Named &&
+	    type->Named.base == nullptr) {
+		// IMPORTANT TODO(bill): Is this a serious error?!
+		#if 0
+		error(e, "Invalid type definition of `%.*s`", LIT(type->Named.name));
+		#endif
+		type->Named.base = t_invalid;
 	}
 
 	#if 0

+ 115 - 102
src/integer128.cpp

@@ -95,48 +95,48 @@ void i128_divide (i128 num, i128 den, i128 *quo, i128 *rem);
 i128 i128_quo    (i128 a, i128 b);
 i128 i128_mod    (i128 a, i128 b);
 
-bool operator==(u128 a, u128 b) { return u128_eq(a, b); }
-bool operator!=(u128 a, u128 b) { return u128_ne(a, b); }
-bool operator< (u128 a, u128 b) { return u128_lt(a, b); }
-bool operator> (u128 a, u128 b) { return u128_gt(a, b); }
-bool operator<=(u128 a, u128 b) { return u128_le(a, b); }
-bool operator>=(u128 a, u128 b) { return u128_ge(a, b); }
-
-u128 operator+(u128 a, u128 b) { return u128_add(a, b); }
-u128 operator-(u128 a, u128 b) { return u128_sub(a, b); }
-u128 operator*(u128 a, u128 b) { return u128_mul(a, b); }
-u128 operator/(u128 a, u128 b) { return u128_quo(a, b); }
-u128 operator%(u128 a, u128 b) { return u128_mod(a, b); }
-u128 operator&(u128 a, u128 b) { return u128_and(a, b); }
-u128 operator|(u128 a, u128 b) { return u128_or (a, b); }
-u128 operator^(u128 a, u128 b) { return u128_xor(a, b); }
-u128 operator~(u128 a)         { return u128_not(a); }
-u128 operator+(u128 a)         { return a; }
-u128 operator-(u128 a)         { return u128_neg(a); }
-u128 operator<<(u128 a, u32 b) { return u128_shl(a, b); }
-u128 operator>>(u128 a, u32 b) { return u128_shr(a, b); }
-
-
-bool operator==(i128 a, i128 b) { return i128_eq(a, b); }
-bool operator!=(i128 a, i128 b) { return i128_ne(a, b); }
-bool operator< (i128 a, i128 b) { return i128_lt(a, b); }
-bool operator> (i128 a, i128 b) { return i128_gt(a, b); }
-bool operator<=(i128 a, i128 b) { return i128_le(a, b); }
-bool operator>=(i128 a, i128 b) { return i128_ge(a, b); }
-
-i128 operator+(i128 a, i128 b) { return i128_add(a, b); }
-i128 operator-(i128 a, i128 b) { return i128_sub(a, b); }
-i128 operator*(i128 a, i128 b) { return i128_mul(a, b); }
-i128 operator/(i128 a, i128 b) { return i128_quo(a, b); }
-i128 operator%(i128 a, i128 b) { return i128_mod(a, b); }
-i128 operator&(i128 a, i128 b) { return i128_and(a, b); }
-i128 operator|(i128 a, i128 b) { return i128_or (a, b); }
-i128 operator^(i128 a, i128 b) { return i128_xor(a, b); }
-i128 operator~(i128 a)         { return i128_not(a); }
-i128 operator+(i128 a)         { return a; }
-i128 operator-(i128 a)         { return i128_neg(a); }
-i128 operator<<(i128 a, u32 b) { return i128_shl(a, b); }
-i128 operator>>(i128 a, u32 b) { return i128_shr(a, b); }
+bool operator==(u128 const &a, u128 const &b) { return u128_eq(a, b); }
+bool operator!=(u128 const &a, u128 const &b) { return u128_ne(a, b); }
+bool operator< (u128 const &a, u128 const &b) { return u128_lt(a, b); }
+bool operator> (u128 const &a, u128 const &b) { return u128_gt(a, b); }
+bool operator<=(u128 const &a, u128 const &b) { return u128_le(a, b); }
+bool operator>=(u128 const &a, u128 const &b) { return u128_ge(a, b); }
+
+u128 operator+ (u128 const &a, u128 const &b) { return u128_add(a, b); }
+u128 operator- (u128 const &a, u128 const &b) { return u128_sub(a, b); }
+u128 operator* (u128 const &a, u128 const &b) { return u128_mul(a, b); }
+u128 operator/ (u128 const &a, u128 const &b) { return u128_quo(a, b); }
+u128 operator% (u128 const &a, u128 const &b) { return u128_mod(a, b); }
+u128 operator& (u128 const &a, u128 const &b) { return u128_and(a, b); }
+u128 operator| (u128 const &a, u128 const &b) { return u128_or (a, b); }
+u128 operator^ (u128 const &a, u128 const &b) { return u128_xor(a, b); }
+u128 operator~ (u128 const &a)                { return u128_not(a); }
+u128 operator+ (u128 const &a)                { return a; }
+u128 operator- (u128 const &a)                { return u128_neg(a); }
+u128 operator<<(u128 const &a, u32 const &b)  { return u128_shl(a, b); }
+u128 operator>>(u128 const &a, u32 const &b)  { return u128_shr(a, b); }
+
+
+bool operator==(i128 const &a, i128 const &b) { return i128_eq(a, b); }
+bool operator!=(i128 const &a, i128 const &b) { return i128_ne(a, b); }
+bool operator< (i128 const &a, i128 const &b) { return i128_lt(a, b); }
+bool operator> (i128 const &a, i128 const &b) { return i128_gt(a, b); }
+bool operator<=(i128 const &a, i128 const &b) { return i128_le(a, b); }
+bool operator>=(i128 const &a, i128 const &b) { return i128_ge(a, b); }
+
+i128 operator+ (i128 const &a, i128 const &b) { return i128_add(a, b); }
+i128 operator- (i128 const &a, i128 const &b) { return i128_sub(a, b); }
+i128 operator* (i128 const &a, i128 const &b) { return i128_mul(a, b); }
+i128 operator/ (i128 const &a, i128 const &b) { return i128_quo(a, b); }
+i128 operator% (i128 const &a, i128 const &b) { return i128_mod(a, b); }
+i128 operator& (i128 const &a, i128 const &b) { return i128_and(a, b); }
+i128 operator| (i128 const &a, i128 const &b) { return i128_or (a, b); }
+i128 operator^ (i128 const &a, i128 const &b) { return i128_xor(a, b); }
+i128 operator~ (i128 const &a)                { return i128_not(a); }
+i128 operator+ (i128 const &a)                { return a; }
+i128 operator- (i128 const &a)                { return i128_neg(a); }
+i128 operator<<(i128 const &a, u32 b)         { return i128_shl(a, b); }
+i128 operator>>(i128 const &a, u32 b)         { return i128_shr(a, b); }
 
 ////////////////////////////////////////////////////////////////
 
@@ -482,36 +482,37 @@ u128 u128_mul(u128 a, u128 b) {
 	return res;
 }
 
-bool u128_hibit(u128 *d) { return (d->hi & BIT128_U64_HIGHBIT) != 0; }
+bool u128_hibit(u128 const &d) { return (d.hi & BIT128_U64_HIGHBIT) != 0; }
+bool i128_hibit(i128 const &d) { return d.hi < 0; }
 
-void u128_divide(u128 num, u128 den, u128 *quo, u128 *rem) {
-	if (u128_eq(den, U128_ZERO)) {
-		if (quo) *quo = u128_from_u64(num.lo/den.lo);
+void u128_divide(u128 a, u128 b, u128 *quo, u128 *rem) {
+	if (u128_eq(b, U128_ZERO)) {
+		if (quo) *quo = u128_from_u64(a.lo/b.lo);
 		if (rem) *rem = U128_ZERO;
-	} else {
-		u128 n = num;
-		u128 d = den;
-		u128 x = U128_ONE;
-		u128 r = U128_ZERO;
-
-		while (u128_ge(n, d) && !u128_hibit(&d)) {
-			x = u128_shl(x, 1);
-			d = u128_shl(d, 1);
-		}
+		return;
+	}
+	u128 r = a;
+	u128 d = b;
+	u128 x = U128_ONE;
+	u128 q = U128_ZERO;
 
-		while (u128_ne(x, U128_ZERO)) {
-			if (u128_ge(n, d)) {
-				n = u128_sub(n, d);
-				r = u128_or(r, x);
-			}
+	while (u128_ge(r, d) && !u128_hibit(d)) {
+		x = u128_shl(x, 1);
+		d = u128_shl(d, 1);
+	}
 
-			x = u128_shr(x, 1);
-			d = u128_shr(d, 1);
+	while (u128_ne(x, U128_ZERO)) {
+		if (u128_ge(r, d)) {
+			r = u128_sub(r, d);
+			q = u128_or(q, x);
 		}
 
-		if (quo) *quo = r;
-		if (rem) *rem = n;
+		x = u128_shr(x, 1);
+		d = u128_shr(d, 1);
 	}
+
+	if (quo) *quo = q;
+	if (rem) *rem = r;
 }
 
 u128 u128_quo(u128 a, u128 b) {
@@ -668,50 +669,62 @@ i128 i128_mul(i128 a, i128 b) {
 	return res;
 }
 
-void i128_divide(i128 a, i128 b, i128 *quo, i128 *rem) {
-	// TODO(bill): Which one is correct?!
-#if 1
-	i128 s = i128_shr(b, 127);
-	b = i128_sub(i128_xor(b, s), s);
-	s = i128_shr(a, 127);
-	b = i128_sub(i128_xor(a, s), s);
-
-	u128 n, r = {0};
-	u128_divide(*cast(u128 *)&a, *cast(u128 *)&b, &n, &r);
-	i128 ni = *cast(i128 *)&n;
-	i128 ri = *cast(i128 *)&r;
-
-	if (quo) *quo = i128_sub(i128_xor(ni, s), s);
-	if (rem) *rem = i128_sub(i128_xor(ri, s), s);
-#else
-	if (i128_eq(b, I128_ZERO)) {
-		if (quo) *quo = i128_from_u64(a.lo/b.lo);
-		if (rem) *rem = I128_ZERO;
+void i128_divide(i128 a, i128 b, i128 *quo_, i128 *rem_) {
+	// TODO(bill): Optimize this i128 division calculation
+	i128 iquo = {0};
+	i128 irem = {0};
+	if (a.hi == 0 && b.hi == 0) {
+		u64 q = a.lo / b.lo;
+		u64 r = a.lo % b.lo;
+		iquo = i128_from_u64(q);
+		irem = i128_from_u64(r);
+	} else if ((~a.hi) == 0 && (~b.hi) == 0) {
+		i64 x = i128_to_i64(a);
+		i64 y = i128_to_i64(b);
+		i64 q = x / y;
+		i64 r = x % y;
+		iquo = i128_from_i64(q);
+		irem = i128_from_i64(r);
+	} else if (a.hi > 0 || b.hi > 0) {
+		u128 q, r = {0};
+		u128_divide(*cast(u128 *)&a, *cast(u128 *)&b, &q, &r);
+		iquo = *cast(i128 *)&q;
+		irem = *cast(i128 *)&r;
+	} else if (i128_eq(b, I128_ZERO)) {
+		iquo = i128_from_u64(a.lo/b.lo);
 	} else {
-		i128 n = a;
-		i128 d = b;
-		i128 x = I128_ONE;
-		i128 r = I128_ZERO;
-
-		while (i128_ge(n, d) && ((i128_shr(d, 128-1).lo&1) == 0)) {
-			x = i128_shl(x, 1);
-			d = i128_shl(d, 1);
+		i32 rem_sign = 1;
+		i32 quo_sign = 1;
+		if (i128_lt(a, I128_ZERO)) {
+			a = i128_neg(a);
+			rem_sign = -1;
 		}
+		if (i128_lt(b, I128_ZERO)) {
+			b = i128_neg(b);
+			quo_sign = -1;
+		}
+		quo_sign *= rem_sign;
 
-		while (i128_ne(x, I128_ZERO)) {
-			if (i128_ge(n, d)) {
-				n = i128_sub(n, d);
-				r = i128_or(r, x);
-			}
+		iquo = a;
 
-			x = i128_shr(x, 1);
-			d = i128_shr(d, 1);
+		for (isize i = 0; i < 128; i++) {
+			irem = i128_shl(irem, 1);
+			if (i128_lt(iquo, I128_ZERO)) {
+				irem.lo |= 1;
+			}
+			iquo = i128_shl(iquo, 1);
+			if (i128_ge(irem, b)) {
+				irem = i128_sub(irem, b);
+				iquo = i128_add(iquo, I128_ONE);
+			}
 		}
 
-		if (quo) *quo = r;
-		if (rem) *rem = n;
+		if (quo_sign < 0) iquo = i128_neg(iquo);
+		if (rem_sign < 0) irem = i128_neg(irem);
 	}
-#endif
+
+	if (quo_) *quo_ = iquo;
+	if (rem_) *rem_ = irem;
 }
 
 i128 i128_quo(i128 a, i128 b) {

+ 2 - 3
src/main.cpp

@@ -451,7 +451,6 @@ int main(int arg_count, char **arg_ptr) {
 
 	Array<String> args = setup_args(arg_count, arg_ptr);
 
-
 #if 1
 
 	String init_filename = {};
@@ -655,7 +654,7 @@ int main(int arg_count, char **arg_ptr) {
 		}
 
 		exit_code = system_exec_command_line_app("msvc-link", true,
-			"link \"%.*s\".obj -OUT:\"%.*s.%s\" %s "
+			"link \"%.*s.obj\" -OUT:\"%.*s.%s\" %s "
 			"/defaultlib:libcmt "
 			// "/nodefaultlib "
 			"/nologo /incremental:no /opt:ref /subsystem:CONSOLE "
@@ -757,7 +756,7 @@ int main(int arg_count, char **arg_ptr) {
 		#endif
 
 		exit_code = system_exec_command_line_app("ld-link", true,
-			"%s \"%.*s\".o -o \"%.*s%s\" %s "
+			"%s \"%.*s.o\" -o \"%.*s%s\" %s "
 			"-lc -lm "
 			" %.*s "
 			" %s "