test.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. from math import *
  2. from ctypes import *
  3. from random import *
  4. import os
  5. import time
  # ---- Configuration ----
#
# Where is the DLL? If missing, build using: `odin build . -build-mode:dll`
# The library is looked up in the current working directory, not next to this script.
# NOTE(review): the `.dll` file name is Windows-specific; confirm the name the Odin
# build produces on other platforms before running there.
#
LIB_PATH = os.path.join(os.getcwd(), "big.dll")
#
# How many iterations of each random test do we want to run?
#
RANDOM_ITERATIONS = 10_000
  # ---- Result struct ----
#
# Result values will be passed in a struct { res: cstring, err: Error }
#
class Res(Structure):
    """ctypes mirror of the `{ res: cstring, err: Error }` struct the test procedures return.

    `res` carries the result text on success; when an error is reported, the
    harness prints it as the location of the error. `err` is one of the `E_*` codes.
    """

    _fields_ = [
        ("res", c_char_p),  # NUL-terminated text (None for a NULL pointer)
        ("err", c_byte),    # error code
    ]
  # ---- Error codes ----
#
# Error enum values, as carried in `Res.err`.
# NOTE(review): these must match the numbering of the library's Error enum; re-check
# them against the Odin source whenever that enum changes.
#
(
    E_None,
    E_Out_Of_Memory,
    E_Invalid_Pointer,
    E_Invalid_Argument,
    E_Unknown_Error,
    E_Max_Iterations_Reached,
    E_Buffer_Overflow,
    E_Integer_Overflow,
    E_Division_by_Zero,
    E_Math_Domain_Error,
) = range(10)
E_Unimplemented = 127
  # ---- Library loading ----
#
# Load the library and set up its exported procedures.
# Exit status 1: the library could not be loaded; exit status 2: a symbol is missing.
#
try:
    l = cdll.LoadLibrary(LIB_PATH)
except OSError:
    print("Couldn't find or load " + LIB_PATH + ".")
    raise SystemExit(1)


def _bind(name, argtypes, restype):
    """Look up exported procedure `name` in `l`, set its signature, and return it.

    Prints a message and exits with status 2 if the library does not export `name`.
    ctypes caches the looked-up function on `l`, so `l.<name>` later returns this
    same configured object.
    """
    try:
        proc = getattr(l, name)
    except AttributeError:
        print("Couldn't find exported function '{}'".format(name))
        raise SystemExit(2)
    proc.argtypes = argtypes
    proc.restype = restype
    return proc


#
# res = a + b, err
#
add_two = _bind("test_add_two", [c_char_p, c_char_p, c_longlong], Res)
#
# res = a - b, err
#
sub_two = _bind("test_sub_two", [c_char_p, c_char_p, c_longlong], Res)
#
# res = a * b, err
#
mul_two = _bind("test_mul_two", [c_char_p, c_char_p, c_longlong], Res)
#
# res = a / b, err
#
div_two = _bind("test_div_two", [c_char_p, c_char_p, c_longlong], Res)
#
# Error code -> descriptive text; used via `l.test_error_string` when reporting failures.
#
_bind("test_error_string", [c_byte], c_char_p)
  # ---- Result checking ----
def test(test_name: str, res: Res, param=None, expected_error=E_None, expected_result=""):
    """Check a `Res` returned by the library and print a diagnostic on mismatch.

    Args:
        test_name: Label printed in front of every diagnostic.
        res: The struct returned by an exported test procedure.
        param: Optional list of call arguments; appended to diagnostics when non-empty.
        expected_error: The `E_*` code the call should have produced.
        expected_result: The value `res.res` must evaluate to when the call
            succeeded (`res.err == E_None`). Not checked for expected errors.

    Returns:
        True when the error code matches and, for a successful call, the
        evaluated result equals `expected_result`; False otherwise.
    """
    def as_text(raw):
        # c_char_p yields None for a NULL pointer; report it instead of crashing the run.
        return raw.decode('utf-8') if raw is not None else "<null>"

    suffix = " with params {}".format(param) if param else ""

    if res.err != expected_error:
        error_type = as_text(l.test_error_string(res.err))
        error_loc = as_text(res.res)
        print("{}: '{}' error in '{}'{}".format(test_name, error_type, error_loc, suffix), flush=True)
        return False

    if res.err == E_None:
        # NOTE(review): eval() executes whatever text the library returned. Acceptable
        # for a local test harness, but int(...) would be safer if the library only
        # ever returns plain integer text — confirm before switching.
        try:
            r = eval(res.res)
        except Exception as err:
            # Malformed output is a failure of this test, not of the whole run.
            print("{}: Could not evaluate result '{}' ({!r}){}".format(
                test_name, as_text(res.res), err, suffix), flush=True)
            return False
        if r != expected_result:
            print("{}: Result was '{}', expected '{}'{}".format(
                test_name, r, expected_result, suffix), flush=True)
            return False

    return True
  # ---- Per-operation test wrappers ----
def _test_binary(test_name, proc, op, a, b, radix, expected_error, expected_result):
    """Call two-operand library procedure `proc` on `a` and `b` and check the result.

    `a` and `b` are passed as UTF-8 encoded `str()` renderings, i.e. always
    base-10 text regardless of `radix`. When `expected_result` is None, the
    expected value is computed in Python as `op(a, b)`. If the call itself
    raises, a diagnostic is printed and False is returned.
    """
    sa = str(a)
    sb = str(b)
    try:
        res = proc(sa.encode('utf-8'), sb.encode('utf-8'), radix)
    except Exception as err:
        # ctypes surfaces argument-conversion problems (and, on Windows, native
        # faults) as Python exceptions; record them as failures and keep going.
        print("{}: exception {!r} with arguments:".format(test_name, err), a, b, radix)
        return False
    if expected_result is None:
        expected_result = op(a, b)
    return test(test_name, res, [sa, sb, radix], expected_error, expected_result)


def test_add_two(a=0, b=0, radix=10, expected_error=E_None, expected_result=None):
    """Check the library's `a + b` against Python's (or against `expected_result`)."""
    return _test_binary("test_add_two", add_two, lambda x, y: x + y,
                        a, b, radix, expected_error, expected_result)


def test_sub_two(a=0, b=0, radix=10, expected_error=E_None, expected_result=None):
    """Check the library's `a - b` against Python's (or against `expected_result`)."""
    return _test_binary("test_sub_two", sub_two, lambda x, y: x - y,
                        a, b, radix, expected_error, expected_result)


def test_mul_two(a=0, b=0, radix=10, expected_error=E_None, expected_result=None):
    """Check the library's `a * b` against Python's (or against `expected_result`)."""
    return _test_binary("test_mul_two", mul_two, lambda x, y: x * y,
                        a, b, radix, expected_error, expected_result)


def test_div_two(a=0, b=0, radix=10, expected_error=E_None, expected_result=None):
    """Check the library's `a / b` against Python's `a // b` (or `expected_result`).

    For b == 0 the Python-side expectation is None, so such cases should pass
    `expected_error=E_Division_by_Zero`; the result is then not compared.
    """
    # NOTE(review): Python's `//` floors toward negative infinity. If the library
    # truncates toward zero, signed operands will disagree — confirm before adding
    # negative division cases.
    return _test_binary("test_div_two", div_two, lambda x, y: x // y if y != 0 else None,
                        a, b, radix, expected_error, expected_result)
  # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
#
# Each entry maps a test procedure to a list of argument lists. The arguments are
# passed positionally as: a, b, radix, then optionally the expected error and,
# after that, the expected result.
#
# The expected error defaults to E_None.
# By default the Odin implementation will be tested against the Python one.
# You can override that by supplying an expected result as the last argument instead.
TESTS = {
    test_add_two: [
        [1234, 5432, 10],
        [1234, 5432, 110, E_Invalid_Argument],
    ],
    test_sub_two: [
        [1234, 5432, 10],
    ],
    test_mul_two: [
        [1234, 5432, 10],
        [1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10],
    ],
    test_div_two: [
        [54321, 12345, 10],
        [55431, 0, 10, E_Division_by_Zero],
    ],
}
  # ---- Test driver ----
#
# Accumulated wall-clock seconds spent in each test procedure, keyed by the procedure.
#
TIMINGS = {}


def _run_and_time(test_proc, arg_lists):
    """Call `test_proc(*args)` for every `args` in `arg_lists`.

    The time spent in each call is added to `TIMINGS[test_proc]`.
    Returns `(count_pass, count_fail)`, judged by the truthiness of each call's result.
    """
    count_pass = 0
    count_fail = 0
    for args in arg_lists:
        start = time.perf_counter()
        passed = test_proc(*args)
        TIMINGS[test_proc] = TIMINGS.get(test_proc, 0) + (time.perf_counter() - start)
        if passed:
            count_pass += 1
        else:
            count_fail += 1
    return count_pass, count_fail


def _random_operands(test_proc):
    """Yield RANDOM_ITERATIONS random `(a, b)` pairs with 0 <= a, b <= 2**120."""
    for _ in range(RANDOM_ITERATIONS):
        a = randint(0, 1 << 120)
        b = randint(0, 1 << 120)
        # We've already tested division by zero above.
        if b == 0 and test_proc is test_div_two:
            b += 1
        yield a, b


def _print_counts(test_proc, count_pass, count_fail):
    """Print the pass/fail summary line for one test procedure."""
    print("{name}: {count_pass:,} passes, {count_fail:,} failures.".format(
        name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail))


if __name__ == '__main__':
    # Calls actually made per procedure, so the timing summary reports real counts.
    calls = {}

    print()
    print("---- core:math/big tests ----")
    print()
    for test_proc, cases in TESTS.items():
        count_pass, count_fail = _run_and_time(test_proc, cases)
        calls[test_proc] = calls.get(test_proc, 0) + count_pass + count_fail
        _print_counts(test_proc, count_pass, count_fail)

    print()
    print("---- core:math/big random tests ----")
    print()
    for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two]:
        count_pass, count_fail = _run_and_time(test_proc, _random_operands(test_proc))
        calls[test_proc] = calls.get(test_proc, 0) + count_pass + count_fail
        _print_counts(test_proc, count_pass, count_fail)

    print()
    total = 0
    for test_proc, elapsed in TIMINGS.items():
        print("{name}: {total:.3f} ms in {calls:,} calls".format(
            name=test_proc.__name__, total=elapsed * 1_000, calls=calls.get(test_proc, 0)))
        total += elapsed
    print("\ntotal: {0:.3f} ms".format(total * 1_000))