Browse Source

big: Add tests for `log`.

Jeroen van Rijn 4 years ago
parent
commit
385b9c9922
4 changed files with 123 additions and 79 deletions
  1. 2 2
      core/math/big/example.odin
  2. 2 2
      core/math/big/exp_log.odin
  3. 28 1
      core/math/big/test.odin
  4. 91 74
      core/math/big/test.py

+ 2 - 2
core/math/big/example.odin

@@ -70,10 +70,10 @@ demo :: proc() {
 	// r := &rnd.Rand{};
 	// rnd.init(r, 12345);
 
-	// as := cstring("596360079055148742691396559496540363");
+	// as := cstring("12341234");
 	// bs := cstring("159671292010002348397151706347412301");
 
-	// res := test_div_two(as, bs);
+	// res := test_log(as, 2, 10);
 	// fmt.print(res);
 	// destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	// defer destroy(destination, source, quotient, remainder, numerator, denominator);

+ 2 - 2
core/math/big/exp_log.odin

@@ -14,8 +14,8 @@ int_log :: proc(a: ^Int, base: DIGIT) -> (res: int, err: Error) {
 		return -1, .Invalid_Argument;
 	}
 	if err = clear_if_uninitialized(a); err != .None { return -1, err; }
-	if n, _ := is_neg(a);  n { return -1, .Invalid_Argument; }
-	if z, _ := is_zero(a); z { return -1, .Invalid_Argument; }
+	if n, _ := is_neg(a);  n { return -1, .Math_Domain_Error; }
+	if z, _ := is_zero(a); z { return -1, .Math_Domain_Error; }
 
 	/*
 		Fast path for bases that are a power of two.

+ 28 - 1
core/math/big/test.odin

@@ -95,4 +95,31 @@ PyRes :: struct {
 	r, err = int_itoa_cstring(quotient, i8(radix), context.temp_allocator);
 	if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
 	return PyRes{res = r, err = .None};
-}
+}
+
+
+/*
+	res = log(a, base)
+*/
+@export test_log :: proc "c" (a: cstring, base := DIGIT(2), radix := int(10)) -> (res: PyRes) {
+	context = runtime.default_context();
+	err: Error;
+	l: int;
+
+	aa := &Int{};
+	defer destroy(aa);
+
+	if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":log:atoi(a):", err=err}; }
+	if l, err = log(aa, base);               err != .None { return PyRes{res=":log:log(a, base):", err=err}; }
+
+	zero(aa);
+	aa.digit[0] = DIGIT(l)  & _MASK;
+	aa.digit[1] = DIGIT(l) >> _DIGIT_BITS;
+	aa.used = 2;
+	clamp(aa);
+
+	r: cstring;
+	r, err = int_itoa_cstring(aa, i8(radix), context.temp_allocator);
+	if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
+	return PyRes{res = r, err = .None};
+}

+ 91 - 74
core/math/big/test.py

@@ -6,7 +6,12 @@ import platform
 import time
 
 #
-# Where is the DLL? If missing, build using: `odin build . -build-mode:dll`
+# Fast tests?
+#
+FAST_TESTS = True
+
+#
+# Where is the DLL? If missing, build using: `odin build . -build-mode:shared`
 #
 if platform.system() == "Windows":
 	LIB_PATH = os.getcwd() + os.sep + "big.dll"
@@ -16,7 +21,8 @@ elif platform.system() == "Darwin":
 	LIB_PATH = os.getcwd() + os.sep + "big.dylib"
 else:
 	print("Platform is unsupported.")
-	os.exit(1)
+	exit(1)
+
 #
 # How many iterations of each random test do we want to run?
 #
@@ -27,6 +33,14 @@ BITS_AND_ITERATIONS = [
 	(12_000,     10),
 ]
 
+#
+# Fast tests?
+#
+if FAST_TESTS:
+	for k in range(len(BITS_AND_ITERATIONS)):
+		b, i = BITS_AND_ITERATIONS[k]
+		BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5)
+
 #
 # Result values will be passed in a struct { res: cstring, err: Error }
 #
@@ -58,76 +72,53 @@ except:
 	print("Couldn't find or load " + LIB_PATH + ".")
 	exit(1)
 
+def load(export_name, args, res):
+	export_name.argtypes = args
+	export_name.restype  = res
+	return export_name
+
+error_string = load(l.test_error_string, [c_byte], c_char_p)
+
 #
 # res = a + b, err
 #
-try:
-	l.test_add_two.argtypes = [c_char_p, c_char_p, c_longlong]
-	l.test_add_two.restype  = Res
-except:
-	print("Couldn't find exported function 'test_add_two'")
-	exit(2)
-
-add_two = l.test_add_two
+add_two = load(l.test_add_two, [c_char_p, c_char_p, c_longlong], Res)
 
 #
 # res = a - b, err
 #
-try:
-	l.test_sub_two.argtypes = [c_char_p, c_char_p, c_longlong]
-	l.test_sub_two.restype  = Res
-except:
-	print("Couldn't find exported function 'test_sub_two'")
-	exit(2)
-
-sub_two = l.test_sub_two
+sub_two = load(l.test_sub_two, [c_char_p, c_char_p, c_longlong], Res)
 
 #
 # res = a * b, err
 #
-try:
-	l.test_mul_two.argtypes = [c_char_p, c_char_p, c_longlong]
-	l.test_mul_two.restype  = Res
-except:
-	print("Couldn't find exported function 'test_add_two'")
-	exit(2)
-
-mul_two = l.test_mul_two
+mul_two = load(l.test_mul_two, [c_char_p, c_char_p, c_longlong], Res)
 
 #
 # res = a / b, err
 #
-try:
-	l.test_div_two.argtypes = [c_char_p, c_char_p, c_longlong]
-	l.test_div_two.restype  = Res
-except:
-	print("Couldn't find exported function 'test_div_two'")
-	exit(2)
-
-div_two = l.test_div_two
+div_two = load(l.test_div_two, [c_char_p, c_char_p, c_longlong], Res)
 
 
+#
+# res = log(a, base)
+#
+int_log = load(l.test_log, [c_char_p, c_longlong, c_longlong], Res)
 
-try:
-	l.test_error_string.argtypes = [c_byte]
-	l.test_error_string.restype  = c_char_p
-except:
-	print("Couldn't find exported function 'test_error_string'")
-	exit(2)
 
 def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_result = ""):
 	passed = True
 	r = None
 
 	if res.err != expected_error:
-		error_type = l.test_error_string(res.err).decode('utf-8')
+		error_type = error_string(res.err).decode('utf-8')
 		error_loc  = res.res.decode('utf-8')
 
-		error_string = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc)
+		error = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc)
 		if len(param):
-			error_string += " with params {}".format(param)
+			error += " with params {}".format(param)
 
-		print(error_string, flush=True)
+		print(error, flush=True)
 		passed = False
 	elif res.err == E_None:
 		try:
@@ -137,50 +128,43 @@ def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_re
 
 		r = eval(res.res)
 		if r != expected_result:
-			error_string = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
+			error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
 			if len(param):
-				error_string += " with params {}".format(param)
+				error += " with params {}".format(param)
 
-			print(error_string, flush=True)
+			print(error, flush=True)
 			passed = False
 
 	return passed
 
+
 def test_add_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
-	sa   = str(a)
-	sb   = str(b)
-	sa_c = sa.encode('utf-8')
-	sb_c = sb.encode('utf-8')
+	args = [str(a), str(b), radix]
+	sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
 	res  = add_two(sa_c, sb_c, radix)
 	if expected_result == None:
 		expected_result = a + b
-	return test("test_add_two", res, [sa, sb, radix], expected_error, expected_result)
+	return test("test_add_two", res, args, expected_error, expected_result)
 
 def test_sub_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
-	sa   = str(a)
-	sb   = str(b)
-	sa_c = sa.encode('utf-8')
-	sb_c = sb.encode('utf-8')
+	sa,     sb = str(a), str(b)
+	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	res  = sub_two(sa_c, sb_c, radix)
 	if expected_result == None:
 		expected_result = a - b
 	return test("test_sub_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
 
 def test_mul_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
-	sa   = str(a)
-	sb   = str(b)
-	sa_c = sa.encode('utf-8')
-	sb_c = sb.encode('utf-8')
+	sa,     sb = str(a), str(b)
+	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	res  = mul_two(sa_c, sb_c, radix)
 	if expected_result == None:
 		expected_result = a * b
 	return test("test_mul_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
 
 def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
-	sa   = str(a)
-	sb   = str(b)
-	sa_c = sa.encode('utf-8')
-	sb_c = sb.encode('utf-8')
+	sa,     sb = str(a), str(b)
+	sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
 	try:
 		res  = div_two(sa_c, sb_c, radix)
 	except:
@@ -190,6 +174,17 @@ def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_res
 		expected_result = a // b if b != 0 else None
 	return test("test_div_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
 
+
+def test_log(a = 0, base = 0, radix = 10, expected_error = E_None, expected_result = None):
+	args  = [str(a), base, radix]
+	sa_c  = args[0].encode('utf-8')
+	res   = int_log(sa_c, base, radix)
+
+	if expected_result == None:
+		expected_result = int(log(a, base))
+	return test("test_log", res, args, expected_error, expected_result)
+
+
 # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
 #
 # The last two arguments in tests are the expected error and expected result.
@@ -211,15 +206,26 @@ TESTS = {
 		[ 1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10, ]
 	],
 	test_div_two: [
-		[ 54321, 12345,    10, ],
-		[ 55431,     0,    10, E_Division_by_Zero, ],
+		[ 54321,	12345,		10, ],
+		[ 55431,		0,		10,		E_Division_by_Zero, ],
+	],
+	test_log: [
+		[ 3192,			1,		10,		E_Invalid_Argument,		":log:log(a, base):"],
+		[ -1234,		2,		10,		E_Math_Domain_Error,	":log:log(a, base):"],
+		[ 0,			2,		10,		E_Math_Domain_Error, 	":log:log(a, base):"],
+		[ 1024,			2,		10, ],
 	],
 }
 
 TOTAL_TIME     = 0
+total_passes   = 0
 total_failures = 0
 
+
 if __name__ == '__main__':
+
+	test_log(1234, 2, 10)
+
 	print("---- core:math/big tests ----")
 	print()
 
@@ -240,6 +246,7 @@ if __name__ == '__main__':
 
 			if res:
 				count_pass     += 1
+				total_passes   += 1
 			else:
 				count_fail     += 1
 				total_failures += 1
@@ -251,19 +258,25 @@ if __name__ == '__main__':
 		print("---- core:math/big with two random {bits:,} bit numbers ----".format(bits=BITS))
 		print()
 
-		for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two]:
+		for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two, test_log]:
 			count_pass = 0
 			count_fail = 0
 			TIMINGS    = {}
 
 			for i in range(ITERATIONS):
 				a = randint(0, 1 << BITS)
-				b = randint(0, 1 << BITS)
-				res = None
 
-				# We've already tested division by zero above.
-				if b == 0 and test_proc == test_div_two:
-					b = b + 1
+				if test_proc == test_div_two:
+					# We've already tested division by zero above.
+					b = randint(1, 1 << BITS)
+				elif test_proc == test_log:
+					# We've already tested log's domain errors.
+					a = randint(1, 1 << BITS)
+					b = randint(2, 1 << 60)
+				else:
+					b = randint(0, 1 << BITS)					
+
+				res = None
 
 				start = time.perf_counter()
 				res   = test_proc(a, b)
@@ -277,13 +290,17 @@ if __name__ == '__main__':
 
 				if res:
 					count_pass     += 1
+					total_passes   += 1
 				else:
 					count_fail     += 1
 					total_failures += 1
 
 			print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
 
-	print("\ntotal: {0:.3f} ms".format(TOTAL_TIME * 1_000))
+	print()		
+	print("---- THE END ----")
+	print()
+	print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
 
 	if total_failures:
-		os.exit(1)
+		exit(1)