test.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. from math import *
  2. from ctypes import *
  3. from random import *
  4. import os
  5. import platform
  6. import time
  7. from enum import Enum
  #
  # Normally, we report the number of passes and fails.
  # With EXIT_ON_FAIL set, we exit at the first fail.
  #
  EXIT_ON_FAIL = False
  #
  # Set NO_RANDOM_TESTS to skip the randomized tests altogether.
  #
  NO_RANDOM_TESTS = False
  #
  # With FAST_TESTS set, the iteration counts below are cut down: divided by 10,
  # or replaced by 5 when the count is below 100. Iteration counts only matter
  # when TIMED_TESTS is off; timed runs stop on elapsed time instead.
  #
  FAST_TESTS = True
  #
  # For timed tests each random test gets a budget of one second per
  # TIMED_BITS_PER_SECOND bits of operand size and iterates until that time is
  # used up. Otherwise, the number of iterations per bit size comes from
  # BITS_AND_ITERATIONS.
  #
  TIMED_TESTS = False
  TIMED_BITS_PER_SECOND = 20_000
  #
  # (operand size in bits, number of iterations) for each round of random tests.
  #
  BITS_AND_ITERATIONS = [
      (120, 10_000),
      (1_200, 1_000),
      (4_096, 100),
      (12_000, 10),
  ]
  if FAST_TESTS:
      BITS_AND_ITERATIONS = [
          (bits, iterations // 10 if iterations >= 100 else 5)
          for bits, iterations in BITS_AND_ITERATIONS
      ]
  if NO_RANDOM_TESTS:
      BITS_AND_ITERATIONS = []
  #
  # Where is the shared library? If missing, build it with:
  #   odin build . -build-mode:shared
  #
  _LIBRARY_FILE_NAMES = {
      "Windows": "big.dll",
      "Linux": "big.so",
      "Darwin": "big.dylib",
  }
  _system_name = platform.system()
  if _system_name not in _LIBRARY_FILE_NAMES:
      print("Platform is unsupported.")
      exit(1)
  # The library is looked up in the current working directory.
  LIB_PATH = os.getcwd() + os.sep + _LIBRARY_FILE_NAMES[_system_name]
  #
  # Loop state shared by the random-test driver and we_iterate().
  #
  TOTAL_TIME = 0   # seconds accumulated across all timed test calls
  UNTIL_TIME = 0   # timed mode: keep iterating while TOTAL_TIME is below this
  UNTIL_ITERS = 0  # untimed mode: remaining iteration budget

  def we_iterate():
      """Return True while the current random-test loop should keep going.

      In timed mode (TIMED_TESTS), keep going while TOTAL_TIME < UNTIL_TIME.
      Otherwise each call consumes one unit of UNTIL_ITERS, so a budget of N
      yields exactly N True results followed by False.
      """
      global UNTIL_ITERS
      if TIMED_TESTS:
          return TOTAL_TIME < UNTIL_TIME
      UNTIL_ITERS -= 1
      # `>= 0` rather than `!= -1`: a negative budget must stop immediately
      # instead of counting down forever.
      return UNTIL_ITERS >= 0
  #
  # Error codes as reported in the `err` field of each result struct.
  # NOTE(review): the values presumably mirror the Odin library's error enum;
  # keep the two in sync.
  #
  class Error(Enum):
      Okay = 0
      Out_Of_Memory = 1
      Invalid_Pointer = 2
      Invalid_Argument = 3
      Unknown_Error = 4
      Max_Iterations_Reached = 5
      Buffer_Overflow = 6
      Integer_Overflow = 7
      Division_by_Zero = 8
      Math_Domain_Error = 9
      # Deliberately far from the other codes.
      Unimplemented = 127
  #
  # Load the shared library that exports the `test_*` procedures.
  #
  try:
      l = cdll.LoadLibrary(LIB_PATH)
  except OSError as load_error:
      # Only a failed load is reported here. Other exceptions (bugs in this
      # script) propagate instead of being disguised as a missing library.
      print("Couldn't find or load " + LIB_PATH + ".")
      print(load_error)
      exit(1)
  def load(export_name, args, res):
      """Declare the C signature of a library procedure and return that procedure.

      `export_name` is the ctypes function object itself (e.g. `l.test_add_two`),
      not a string. `args` becomes its `argtypes` and `res` its `restype`.
      """
      proc = export_name
      proc.argtypes, proc.restype = args, res
      return proc
  #
  # Each test procedure returns this struct by value,
  # i.e. struct { res: cstring, err: Error } on the library side.
  #
  class Res(Structure):
      _fields_ = [
          # decimal result text on success, an error location otherwise
          ("res", c_char_p),
          # error code, converted with Error(...) by the checker
          ("err", c_uint64),
      ]
  # Bind the exported procedures with their C signatures. The arithmetic
  # procedures take two decimal strings; log and pow take a decimal string plus
  # a C long long. All of them except test_error_string return a Res.
  error_string = load(l.test_error_string, [c_byte], c_char_p)
  add_two = load(l.test_add_two, [c_char_p, c_char_p], Res)
  sub_two = load(l.test_sub_two, [c_char_p, c_char_p], Res)
  mul_two = load(l.test_mul_two, [c_char_p, c_char_p], Res)
  div_two = load(l.test_div_two, [c_char_p, c_char_p], Res)
  int_log = load(l.test_log, [c_char_p, c_longlong], Res)
  int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
  def _decode_c_string(raw):
      """Decode a C string returned by the library; a NULL pointer (None) stays None."""
      if raw is None:
          return None
      return raw.decode('utf-8', errors='replace')

  def test(test_name: str, res: Res, param=None, expected_error=Error.Okay, expected_result=""):
      """Check one Res returned by the library and print a message on mismatch.

      Args:
          test_name: label used in failure messages.
          res: the struct returned by the library call under test.
          param: the arguments that were passed; appended to failure messages.
          expected_error: the Error the call should report.
          expected_result: when the call reports Error.Okay, the value its
              result string must parse to (as a base-10 integer).

      Returns:
          True if the error code, and on success the result, matched.

      With EXIT_ON_FAIL set, a mismatch exits the process with the reported
      error code, or with 1 when that code is 0 (a wrong result).
      """
      passed = True
      suffix = " with params {}".format(param) if param else ""
      try:
          err = Error(res.err)
      except ValueError:
          # Unknown code: keep the raw number so it is reported instead of
          # crashing the whole run. It never equals an Error member.
          err = res.err
      if err != expected_error:
          error_loc = _decode_c_string(res.res)
          print("{}: {} in '{}'".format(test_name, err, error_loc) + suffix, flush=True)
          passed = False
      elif err == Error.Okay:
          text = _decode_c_string(res.res)
          try:
              r = int(text, 10)
          except (TypeError, ValueError):
              # NULL or non-numeric result: compare the text as-is.
              r = text
          if r != expected_result:
              error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
              print(error + suffix, flush=True)
              passed = False
      if EXIT_ON_FAIL and not passed:
          # res.err is 0 (Okay) when only the result was wrong; never exit 0 on failure.
          exit(res.err or 1)
      return passed
  def test_add_two(a = 0, b = 0, expected_error = Error.Okay):
      """Add `a` and `b` in the library and compare with Python's `a + b`."""
      params = [str(a), str(b)]
      res = add_two(*(p.encode('utf-8') for p in params))
      expected_result = a + b if expected_error == Error.Okay else None
      return test("test_add_two", res, params, expected_error, expected_result)
  def test_sub_two(a = 0, b = 0, expected_error = Error.Okay):
      """Subtract `b` from `a` in the library and compare with Python's `a - b`."""
      sa, sb = str(a), str(b)
      res = sub_two(sa.encode('utf-8'), sb.encode('utf-8'))
      expected_result = a - b if expected_error == Error.Okay else None
      # Report the str operands (not their encoded bytes), like test_add_two.
      return test("test_sub_two", res, [sa, sb], expected_error, expected_result)
  def test_mul_two(a = 0, b = 0, expected_error = Error.Okay):
      """Multiply `a` by `b` in the library and compare with Python's `a * b`."""
      sa, sb = str(a), str(b)
      res = mul_two(sa.encode('utf-8'), sb.encode('utf-8'))
      expected_result = a * b if expected_error == Error.Okay else None
      # Report the str operands (not their encoded bytes), like test_add_two.
      return test("test_mul_two", res, [sa, sb], expected_error, expected_result)
  def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
      """Divide `a` by `b` in the library and check the quotient.

      The expected quotient is truncated toward zero. Python's `//` floors
      instead, which differs by one when exactly one operand is negative and the
      division is inexact. With a zero divisor no quotient is expected (None).

      Returns False, after printing the operands, if the library call raises.
      """
      sa, sb = str(a), str(b)
      try:
          res = div_two(sa.encode('utf-8'), sb.encode('utf-8'))
      except Exception:
          print("Exception with arguments:", a, b)
          return False
      expected_result = None
      if expected_error == Error.Okay and b != 0:
          # Integer arithmetic only: true division (`/`) goes through a float,
          # which loses precision for quotients above 2**53 and overflows for
          # huge ones.
          quotient = abs(a) // abs(b)
          expected_result = -quotient if (a < 0) != (b < 0) else quotient
      # Report the str operands (not their encoded bytes), like test_add_two.
      return test("test_div_two", res, [sa, sb], expected_error, expected_result)
  def _integer_log(a, base):
      """Return floor(log_base(a)) exactly, for a >= 1 and base >= 2.

      math.log only yields a float estimate, which can fall just below an exact
      power (log(1000, 10) == 2.9999999999999996). The estimate is therefore
      corrected with exact integer powers. Invalid inputs raise from math.log,
      as before.
      """
      r = int(log(a, base))
      while r > 0 and base ** r > a:
          r -= 1
      while base ** (r + 1) <= a:
          r += 1
      return r

  def test_log(a = 0, base = 0, expected_error = Error.Okay):
      """Take the integer logarithm of `a` in the given `base` in the library and check it."""
      args = [str(a), base]
      sa_c = args[0].encode('utf-8')
      res = int_log(sa_c, base)
      expected_result = None
      if expected_error == Error.Okay:
          expected_result = _integer_log(a, base)
      return test("test_log", res, args, expected_error, expected_result)
  def test_pow(base = 0, power = 0, expected_error = Error.Okay):
      """Raise `base` to `power` in the library and check the result.

      A negative power is expected to yield 0 unless an error is expected
      (e.g. Error.Math_Domain_Error for 0 ** -1).
      """
      args = [str(base), power]
      res = int_pow(args[0].encode('utf-8'), power)
      if expected_error != Error.Okay:
          expected_result = None
      elif power < 0:
          expected_result = 0
      else:
          expected_result = int(base ** power)
      return test("test_pow", res, args, expected_error, expected_result)
  # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
  #
  # Each entry maps a test procedure to a list of argument lists. Every argument
  # list is unpacked into the procedure's positional parameters: the operands,
  # optionally followed by the expected error (defaults to Error.Okay).
  #
  # The expected result is always computed by the test procedure from Python's
  # own integer arithmetic, and is only checked when no error is expected.
  TESTS = {
      test_add_two: [
          [1234, 5432],
      ],
      test_sub_two: [
          [1234, 5432],
      ],
      test_mul_two: [
          [1234, 5432],
          [0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7],
      ],
      test_div_two: [
          [54321, 12345],
          [55431, 0, Error.Division_by_Zero],
      ],
      test_log: [
          [3192, 1, Error.Invalid_Argument],
          [-1234, 2, Error.Math_Domain_Error],
          [0, 2, Error.Math_Domain_Error],
          [1024, 2],
      ],
      test_pow: [
          [0, -1, Error.Math_Domain_Error],  # error
          [0, 0],                            # 1
          [0, 2],                            # 0
          [42, -1],                          # 0 (negative power)
          [42, 1],                           # 42
          [42, 0],                           # 1
          [42, 2],                           # 42 * 42
      ],
  }
  # Running totals across every test section.
  total_passes = 0
  total_failures = 0
  # Procedures exercised with random operands; each is called as proc(a, b).
  RANDOM_TESTS = [
      test_add_two,
      test_sub_two,
      test_mul_two,
      test_div_two,
      test_log,
      test_pow,
  ]
  # Untimed warmup: run every fixed test case once without counting the results
  # (failures are still printed). Presumably this keeps first-call costs out of
  # the timed run; confirm.
  for test_proc, cases in TESTS.items():
      for t in cases:
          res = test_proc(*t)
  def _timed_call(test_proc, args):
      """Call test_proc(*args), add the elapsed time to TOTAL_TIME, return (passed, seconds)."""
      global TOTAL_TIME
      start = time.perf_counter()
      passed = test_proc(*args)
      elapsed = time.perf_counter() - start
      TOTAL_TIME += elapsed
      return passed, elapsed

  def _random_operands(test_proc, bits):
      """Draw random operands (a, b) for one call of `test_proc` at the given bit size."""
      a = randint(-(1 << bits), 1 << bits)
      b = randint(-(1 << bits), 1 << bits)
      if test_proc == test_div_two:
          # Division by zero is covered by the fixed tests; keep the divisor non-zero.
          # (This used to read `b == 42`, a no-op comparison.)
          if b == 0:
              b = 42
      elif test_proc == test_log:
          # log's domain errors are covered by the fixed tests: a >= 1, base >= 2.
          a = randint(1, 1 << bits)
          b = randint(2, 1 << 60)
      elif test_proc == test_pow:
          b = randint(1, 10)
      else:
          # NOTE(review): add/sub/mul only get a non-negative second operand; confirm intended.
          b = randint(0, 1 << bits)
      return a, b

  def _report(name, count_pass, count_fail, seconds):
      """Print the pass/fail summary line for one test procedure."""
      print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(
          name=name, count_pass=count_pass, count_fail=count_fail, timing=seconds * 1_000))

  if __name__ == '__main__':
      print("---- math/big tests ----")
      print()
      # Fixed test cases.
      for test_proc, cases in TESTS.items():
          count_pass = count_fail = 0
          # A plain accumulator (rather than a dict keyed on test_proc) so that a
          # procedure that ran no cases reports 0 ms instead of raising KeyError.
          seconds = 0.0
          for t in cases:
              passed, elapsed = _timed_call(test_proc, t)
              seconds += elapsed
              if passed:
                  count_pass += 1
                  total_passes += 1
              else:
                  count_fail += 1
                  total_failures += 1
          _report(test_proc.__name__, count_pass, count_fail, seconds)

      # Random operands at each configured bit size.
      for bits, iterations in BITS_AND_ITERATIONS:
          print()
          print("---- math/big with two random {bits:,} bit numbers ----".format(bits=bits))
          print()
          for test_proc in RANDOM_TESTS:
              # pow is only exercised with operands of up to 1,200 bits.
              if test_proc == test_pow and bits > 1_200:
                  continue
              count_pass = count_fail = 0
              seconds = 0.0
              # Budget read by we_iterate(): an iteration count when untimed, or
              # bits / TIMED_BITS_PER_SECOND seconds of test time when timed.
              UNTIL_ITERS = iterations
              UNTIL_TIME = TOTAL_TIME + bits / TIMED_BITS_PER_SECOND
              while we_iterate():
                  a, b = _random_operands(test_proc, bits)
                  passed, elapsed = _timed_call(test_proc, (a, b))
                  seconds += elapsed
                  if passed:
                      count_pass += 1
                      total_passes += 1
                  else:
                      count_fail += 1
                      total_failures += 1
              _report(test_proc.__name__, count_pass, count_fail, seconds)

      print()
      print("---- THE END ----")
      print()
      print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(
          count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
      if total_failures:
          exit(1)