test.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from math import *
  2. from ctypes import *
  3. from random import *
  4. import os
  5. import platform
  6. import time
  7. from enum import Enum
  8. #
  9. # How many iterations of each random test do we want to run?
  10. #
  11. BITS_AND_ITERATIONS = [
  12. ( 120, 10_000),
  13. ( 1_200, 1_000),
  14. ( 4_096, 100),
  15. (12_000, 10),
  16. ]
  17. #
  18. # For timed tests we budget a second per `n` bits and iterate until we hit that time.
  19. # Otherwise, we specify the number of iterations per bit depth in BITS_AND_ITERATIONS.
  20. #
  21. TIMED_TESTS = False
  22. TIMED_BITS_PER_SECOND = 20_000
  23. #
  24. # If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
  25. # See below.
  26. #
  27. FAST_TESTS = True
  28. if FAST_TESTS:
  29. for k in range(len(BITS_AND_ITERATIONS)):
  30. b, i = BITS_AND_ITERATIONS[k]
  31. BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5)
  32. #
  33. # Where is the DLL? If missing, build using: `odin build . -build-mode:shared`
  34. #
  35. if platform.system() == "Windows":
  36. LIB_PATH = os.getcwd() + os.sep + "big.dll"
  37. elif platform.system() == "Linux":
  38. LIB_PATH = os.getcwd() + os.sep + "big.so"
  39. elif platform.system() == "Darwin":
  40. LIB_PATH = os.getcwd() + os.sep + "big.dylib"
  41. else:
  42. print("Platform is unsupported.")
  43. exit(1)
  44. TOTAL_TIME = 0
  45. UNTIL_TIME = 0
  46. UNTIL_ITERS = 0
  47. def we_iterate():
  48. if TIMED_TESTS:
  49. return TOTAL_TIME < UNTIL_TIME
  50. else:
  51. global UNTIL_ITERS
  52. UNTIL_ITERS -= 1
  53. return UNTIL_ITERS != -1
  54. #
  55. # Error enum values
  56. #
  57. class Error(Enum):
  58. Okay = 0
  59. Out_Of_Memory = 1
  60. Invalid_Pointer = 2
  61. Invalid_Argument = 3
  62. Unknown_Error = 4
  63. Max_Iterations_Reached = 5
  64. Buffer_Overflow = 6
  65. Integer_Overflow = 7
  66. Division_by_Zero = 8
  67. Math_Domain_Error = 9
  68. Unimplemented = 127
  69. #
  70. # Set up exported procedures
  71. #
  72. try:
  73. l = cdll.LoadLibrary(LIB_PATH)
  74. except:
  75. print("Couldn't find or load " + LIB_PATH + ".")
  76. exit(1)
  77. def load(export_name, args, res):
  78. export_name.argtypes = args
  79. export_name.restype = res
  80. return export_name
  81. #
  82. # Result values will be passed in a struct { res: cstring, err: Error }
  83. #
  84. class Res(Structure):
  85. _fields_ = [("res", c_char_p), ("err", c_uint64)]
  86. error_string = load(l.test_error_string, [c_byte], c_char_p)
  87. add_two = load(l.test_add_two, [c_char_p, c_char_p, c_longlong], Res)
  88. sub_two = load(l.test_sub_two, [c_char_p, c_char_p, c_longlong], Res)
  89. mul_two = load(l.test_mul_two, [c_char_p, c_char_p, c_longlong], Res)
  90. div_two = load(l.test_div_two, [c_char_p, c_char_p, c_longlong], Res)
  91. int_log = load(l.test_log, [c_char_p, c_longlong, c_longlong], Res)
  92. def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = ""):
  93. passed = True
  94. r = None
  95. err = Error(res.err)
  96. if err != expected_error:
  97. error_loc = res.res.decode('utf-8')
  98. error = "{}: {} in '{}'".format(test_name, err, error_loc)
  99. if len(param):
  100. error += " with params {}".format(param)
  101. print(error, flush=True)
  102. passed = False
  103. elif err == Error.Okay:
  104. r = None
  105. try:
  106. r = res.res.decode('utf-8')
  107. r = int(res.res, 10)
  108. except:
  109. pass
  110. if r != expected_result:
  111. error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
  112. if len(param):
  113. error += " with params {}".format(param)
  114. print(error, flush=True)
  115. passed = False
  116. if not passed:
  117. exit()
  118. return passed
  119. def test_add_two(a = 0, b = 0, radix = 10, expected_error = Error.Okay):
  120. args = [str(a), str(b), radix]
  121. sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
  122. res = add_two(sa_c, sb_c, radix)
  123. expected_result = None
  124. if expected_error == Error.Okay:
  125. expected_result = a + b
  126. return test("test_add_two", res, args, expected_error, expected_result)
  127. def test_sub_two(a = 0, b = 0, radix = 10, expected_error = Error.Okay):
  128. sa, sb = str(a), str(b)
  129. sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
  130. res = sub_two(sa_c, sb_c, radix)
  131. expected_result = None
  132. if expected_error == Error.Okay:
  133. expected_result = a - b
  134. return test("test_sub_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
  135. def test_mul_two(a = 0, b = 0, radix = 10, expected_error = Error.Okay):
  136. sa, sb = str(a), str(b)
  137. sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
  138. res = mul_two(sa_c, sb_c, radix)
  139. expected_result = None
  140. if expected_error == Error.Okay:
  141. expected_result = a * b
  142. return test("test_mul_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
  143. def test_div_two(a = 0, b = 0, radix = 10, expected_error = Error.Okay):
  144. sa, sb = str(a), str(b)
  145. sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
  146. try:
  147. res = div_two(sa_c, sb_c, radix)
  148. except:
  149. print("Exception with arguments:", a, b, radix)
  150. return False
  151. expected_result = None
  152. if expected_error == Error.Okay:
  153. #
  154. # We don't round the division results, so if one component is negative, we're off by one.
  155. #
  156. if a < 0 and b > 0:
  157. expected_result = int(-(abs(a) / b))
  158. elif b < 0 and a > 0:
  159. expected_result = int(-(a / abs((b))))
  160. else:
  161. expected_result = a // b if b != 0 else None
  162. return test("test_div_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
  163. def test_log(a = 0, base = 0, radix = 10, expected_error = Error.Okay):
  164. args = [str(a), base, radix]
  165. sa_c = args[0].encode('utf-8')
  166. res = int_log(sa_c, base, radix)
  167. expected_result = None
  168. if expected_error == Error.Okay:
  169. expected_result = int(log(a, base))
  170. return test("test_log", res, args, expected_error, expected_result)
  171. # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
  172. #
  173. # The last two arguments in tests are the expected error and expected result.
  174. #
  175. # The expected error defaults to None.
  176. # By default the Odin implementation will be tested against the Python one.
  177. # You can override that by supplying an expected result as the last argument instead.
  178. TESTS = {
  179. test_add_two: [
  180. [ 1234, 5432, 10, ],
  181. [ 1234, 5432, 110, Error.Invalid_Argument],
  182. ],
  183. test_sub_two: [
  184. [ 1234, 5432, 10, ],
  185. ],
  186. test_mul_two: [
  187. [ 1234, 5432, 10, ],
  188. [ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7, 10, ]
  189. ],
  190. test_div_two: [
  191. [ 54321, 12345, 10, ],
  192. [ 55431, 0, 10, Error.Division_by_Zero],
  193. ],
  194. test_log: [
  195. [ 3192, 1, 10, Error.Invalid_Argument],
  196. [ -1234, 2, 10, Error.Math_Domain_Error],
  197. [ 0, 2, 10, Error.Math_Domain_Error],
  198. [ 1024, 2, 10, ],
  199. ],
  200. }
  201. RANDOM_TESTS = [test_add_two, test_sub_two, test_mul_two, test_div_two, test_log]
  202. total_passes = 0
  203. total_failures = 0
  204. if __name__ == '__main__':
  205. test_log(1234, 2, 10)
  206. print("---- core:math/big tests ----")
  207. print()
  208. for test_proc in TESTS:
  209. count_pass = 0
  210. count_fail = 0
  211. TIMINGS = {}
  212. for t in TESTS[test_proc]:
  213. start = time.perf_counter()
  214. res = test_proc(*t)
  215. diff = time.perf_counter() - start
  216. TOTAL_TIME += diff
  217. if test_proc not in TIMINGS:
  218. TIMINGS[test_proc] = diff
  219. else:
  220. TIMINGS[test_proc] += diff
  221. if res:
  222. count_pass += 1
  223. total_passes += 1
  224. else:
  225. count_fail += 1
  226. total_failures += 1
  227. print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  228. for BITS, ITERATIONS in BITS_AND_ITERATIONS:
  229. print()
  230. print("---- core:math/big with two random {bits:,} bit numbers ----".format(bits=BITS))
  231. print()
  232. for test_proc in RANDOM_TESTS:
  233. count_pass = 0
  234. count_fail = 0
  235. TIMINGS = {}
  236. UNTIL_ITERS = ITERATIONS
  237. UNTIL_TIME = TOTAL_TIME + BITS / TIMED_BITS_PER_SECOND
  238. # We run each test for a second per 20k bits
  239. while we_iterate():
  240. a = randint(-(1 << BITS), 1 << BITS)
  241. b = randint(-(1 << BITS), 1 << BITS)
  242. if test_proc == test_div_two:
  243. # We've already tested division by zero above.
  244. if b == 0:
  245. b == 42
  246. elif test_proc == test_log:
  247. # We've already tested log's domain errors.
  248. a = randint(1, 1 << BITS)
  249. b = randint(2, 1 << 60)
  250. else:
  251. b = randint(0, 1 << BITS)
  252. res = None
  253. start = time.perf_counter()
  254. res = test_proc(a, b)
  255. diff = time.perf_counter() - start
  256. TOTAL_TIME += diff
  257. if test_proc not in TIMINGS:
  258. TIMINGS[test_proc] = diff
  259. else:
  260. TIMINGS[test_proc] += diff
  261. if res:
  262. count_pass += 1; total_passes += 1
  263. else:
  264. count_fail += 1; total_failures += 1
  265. print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  266. print()
  267. print("---- THE END ----")
  268. print()
  269. print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
  270. if total_failures:
  271. exit(1)