test.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from math import *
  2. from ctypes import *
  3. from random import *
  4. import os
  5. import platform
  6. import time
  7. #
  8. # Fast tests?
  9. #
  10. FAST_TESTS = True
  11. #
  12. # Where is the DLL? If missing, build using: `odin build . -build-mode:shared`
  13. #
  14. if platform.system() == "Windows":
  15. LIB_PATH = os.getcwd() + os.sep + "big.dll"
  16. elif platform.system() == "Linux":
  17. LIB_PATH = os.getcwd() + os.sep + "big.so"
  18. elif platform.system() == "Darwin":
  19. LIB_PATH = os.getcwd() + os.sep + "big.dylib"
  20. else:
  21. print("Platform is unsupported.")
  22. exit(1)
  23. #
  24. # How many iterations of each random test do we want to run?
  25. #
  26. BITS_AND_ITERATIONS = [
  27. ( 120, 10_000),
  28. ( 1_200, 1_000),
  29. ( 4_096, 100),
  30. (12_000, 10),
  31. ]
  32. #
  33. # Fast tests?
  34. #
  35. if FAST_TESTS:
  36. for k in range(len(BITS_AND_ITERATIONS)):
  37. b, i = BITS_AND_ITERATIONS[k]
  38. BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5)
  39. #
  40. # Result values will be passed in a struct { res: cstring, err: Error }
  41. #
  42. class Res(Structure):
  43. _fields_ = [("res", c_char_p), ("err", c_byte)]
  44. #
  45. # Error enum values
  46. #
  47. E_None = 0
  48. E_Out_Of_Memory = 1
  49. E_Invalid_Pointer = 2
  50. E_Invalid_Argument = 3
  51. E_Unknown_Error = 4
  52. E_Max_Iterations_Reached = 5
  53. E_Buffer_Overflow = 6
  54. E_Integer_Overflow = 7
  55. E_Division_by_Zero = 8
  56. E_Math_Domain_Error = 9
  57. E_Unimplemented = 127
  58. #
  59. # Set up exported procedures
  60. #
  61. try:
  62. l = cdll.LoadLibrary(LIB_PATH)
  63. except:
  64. print("Couldn't find or load " + LIB_PATH + ".")
  65. exit(1)
  66. def load(export_name, args, res):
  67. export_name.argtypes = args
  68. export_name.restype = res
  69. return export_name
  70. error_string = load(l.test_error_string, [c_byte], c_char_p)
  71. #
  72. # res = a + b, err
  73. #
  74. add_two = load(l.test_add_two, [c_char_p, c_char_p, c_longlong], Res)
  75. #
  76. # res = a - b, err
  77. #
  78. sub_two = load(l.test_sub_two, [c_char_p, c_char_p, c_longlong], Res)
  79. #
  80. # res = a * b, err
  81. #
  82. mul_two = load(l.test_mul_two, [c_char_p, c_char_p, c_longlong], Res)
  83. #
  84. # res = a / b, err
  85. #
  86. div_two = load(l.test_div_two, [c_char_p, c_char_p, c_longlong], Res)
  87. #
  88. # res = log(a, base)
  89. #
  90. int_log = load(l.test_log, [c_char_p, c_longlong, c_longlong], Res)
  91. def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_result = ""):
  92. passed = True
  93. r = None
  94. if res.err != expected_error:
  95. error_type = error_string(res.err).decode('utf-8')
  96. error_loc = res.res.decode('utf-8')
  97. error = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc)
  98. if len(param):
  99. error += " with params {}".format(param)
  100. print(error, flush=True)
  101. passed = False
  102. elif res.err == E_None:
  103. try:
  104. r = res.res.decode('utf-8')
  105. except:
  106. pass
  107. r = eval(res.res)
  108. if r != expected_result:
  109. error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
  110. if len(param):
  111. error += " with params {}".format(param)
  112. print(error, flush=True)
  113. passed = False
  114. return passed
  115. def test_add_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
  116. args = [str(a), str(b), radix]
  117. sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
  118. res = add_two(sa_c, sb_c, radix)
  119. if expected_result == None:
  120. expected_result = a + b
  121. return test("test_add_two", res, args, expected_error, expected_result)
  122. def test_sub_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
  123. sa, sb = str(a), str(b)
  124. sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
  125. res = sub_two(sa_c, sb_c, radix)
  126. if expected_result == None:
  127. expected_result = a - b
  128. return test("test_sub_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
  129. def test_mul_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
  130. sa, sb = str(a), str(b)
  131. sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
  132. res = mul_two(sa_c, sb_c, radix)
  133. if expected_result == None:
  134. expected_result = a * b
  135. return test("test_mul_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
  136. def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
  137. sa, sb = str(a), str(b)
  138. sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
  139. try:
  140. res = div_two(sa_c, sb_c, radix)
  141. except:
  142. print("Exception with arguments:", a, b, radix)
  143. return False
  144. if expected_result == None:
  145. expected_result = a // b if b != 0 else None
  146. return test("test_div_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
  147. def test_log(a = 0, base = 0, radix = 10, expected_error = E_None, expected_result = None):
  148. args = [str(a), base, radix]
  149. sa_c = args[0].encode('utf-8')
  150. res = int_log(sa_c, base, radix)
  151. if expected_result == None:
  152. expected_result = int(log(a, base))
  153. return test("test_log", res, args, expected_error, expected_result)
  154. # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
  155. #
  156. # The last two arguments in tests are the expected error and expected result.
  157. #
  158. # The expected error defaults to None.
  159. # By default the Odin implementation will be tested against the Python one.
  160. # You can override that by supplying an expected result as the last argument instead.
  161. TESTS = {
  162. test_add_two: [
  163. [ 1234, 5432, 10, ],
  164. [ 1234, 5432, 110, E_Invalid_Argument, ],
  165. ],
  166. test_sub_two: [
  167. [ 1234, 5432, 10, ],
  168. ],
  169. test_mul_two: [
  170. [ 1234, 5432, 10, ],
  171. [ 1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10, ]
  172. ],
  173. test_div_two: [
  174. [ 54321, 12345, 10, ],
  175. [ 55431, 0, 10, E_Division_by_Zero, ],
  176. ],
  177. test_log: [
  178. [ 3192, 1, 10, E_Invalid_Argument, ":log:log(a, base):"],
  179. [ -1234, 2, 10, E_Math_Domain_Error, ":log:log(a, base):"],
  180. [ 0, 2, 10, E_Math_Domain_Error, ":log:log(a, base):"],
  181. [ 1024, 2, 10, ],
  182. ],
  183. }
  184. TOTAL_TIME = 0
  185. total_passes = 0
  186. total_failures = 0
  187. if __name__ == '__main__':
  188. test_log(1234, 2, 10)
  189. print("---- core:math/big tests ----")
  190. print()
  191. for test_proc in TESTS:
  192. count_pass = 0
  193. count_fail = 0
  194. TIMINGS = {}
  195. for t in TESTS[test_proc]:
  196. start = time.perf_counter()
  197. res = test_proc(*t)
  198. diff = time.perf_counter() - start
  199. TOTAL_TIME += diff
  200. if test_proc not in TIMINGS:
  201. TIMINGS[test_proc] = diff
  202. else:
  203. TIMINGS[test_proc] += diff
  204. if res:
  205. count_pass += 1
  206. total_passes += 1
  207. else:
  208. count_fail += 1
  209. total_failures += 1
  210. print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  211. for BITS, ITERATIONS in BITS_AND_ITERATIONS:
  212. print()
  213. print("---- core:math/big with two random {bits:,} bit numbers ----".format(bits=BITS))
  214. print()
  215. for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two, test_log]:
  216. count_pass = 0
  217. count_fail = 0
  218. TIMINGS = {}
  219. for i in range(ITERATIONS):
  220. a = randint(0, 1 << BITS)
  221. if test_proc == test_div_two:
  222. # We've already tested division by zero above.
  223. b = randint(1, 1 << BITS)
  224. elif test_proc == test_log:
  225. # We've already tested log's domain errors.
  226. a = randint(1, 1 << BITS)
  227. b = randint(2, 1 << 60)
  228. else:
  229. b = randint(0, 1 << BITS)
  230. res = None
  231. start = time.perf_counter()
  232. res = test_proc(a, b)
  233. diff = time.perf_counter() - start
  234. TOTAL_TIME += diff
  235. if test_proc not in TIMINGS:
  236. TIMINGS[test_proc] = diff
  237. else:
  238. TIMINGS[test_proc] += diff
  239. if res:
  240. count_pass += 1
  241. total_passes += 1
  242. else:
  243. count_fail += 1
  244. total_failures += 1
  245. print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  246. print()
  247. print("---- THE END ----")
  248. print()
  249. print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
  250. if total_failures:
  251. exit(1)