test.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. from ctypes import *
  2. from random import *
  3. import math
  4. import os
  5. import platform
  6. import time
  7. from enum import Enum
  8. #
  9. # Normally, we report the number of passes and fails.
  10. # With EXIT_ON_FAIL set, we exit at the first fail.
  11. #
  12. EXIT_ON_FAIL = True
  13. EXIT_ON_FAIL = False
  14. #
  15. # We skip randomized tests altogether if NO_RANDOM_TESTS is set.
  16. #
  17. NO_RANDOM_TESTS = True
  18. NO_RANDOM_TESTS = False
  19. #
  20. # If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
  21. # See below.
  22. #
  23. FAST_TESTS = True
  24. #
  25. # For timed tests we budget a second per `n` bits and iterate until we hit that time.
  26. # Otherwise, we specify the number of iterations per bit depth in BITS_AND_ITERATIONS.
  27. #
  28. TIMED_TESTS = False
  29. TIMED_BITS_PER_SECOND = 20_000
  30. #
  31. # How many iterations of each random test do we want to run?
  32. #
  33. BITS_AND_ITERATIONS = [
  34. ( 120, 10_000),
  35. ( 1_200, 1_000),
  36. ( 4_096, 100),
  37. (12_000, 10),
  38. ]
  39. if FAST_TESTS:
  40. for k in range(len(BITS_AND_ITERATIONS)):
  41. b, i = BITS_AND_ITERATIONS[k]
  42. BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5)
  43. if NO_RANDOM_TESTS:
  44. BITS_AND_ITERATIONS = []
  45. #
  46. # Where is the DLL? If missing, build using: `odin build . -build-mode:shared`
  47. #
  48. if platform.system() == "Windows":
  49. LIB_PATH = os.getcwd() + os.sep + "big.dll"
  50. elif platform.system() == "Linux":
  51. LIB_PATH = os.getcwd() + os.sep + "big.so"
  52. elif platform.system() == "Darwin":
  53. LIB_PATH = os.getcwd() + os.sep + "big.dylib"
  54. else:
  55. print("Platform is unsupported.")
  56. exit(1)
  57. TOTAL_TIME = 0
  58. UNTIL_TIME = 0
  59. UNTIL_ITERS = 0
  60. def we_iterate():
  61. if TIMED_TESTS:
  62. return TOTAL_TIME < UNTIL_TIME
  63. else:
  64. global UNTIL_ITERS
  65. UNTIL_ITERS -= 1
  66. return UNTIL_ITERS != -1
  67. #
  68. # Error enum values
  69. #
  70. class Error(Enum):
  71. Okay = 0
  72. Out_Of_Memory = 1
  73. Invalid_Pointer = 2
  74. Invalid_Argument = 3
  75. Unknown_Error = 4
  76. Max_Iterations_Reached = 5
  77. Buffer_Overflow = 6
  78. Integer_Overflow = 7
  79. Division_by_Zero = 8
  80. Math_Domain_Error = 9
  81. Unimplemented = 127
  82. #
  83. # Set up exported procedures
  84. #
  85. try:
  86. l = cdll.LoadLibrary(LIB_PATH)
  87. except:
  88. print("Couldn't find or load " + LIB_PATH + ".")
  89. exit(1)
  90. def load(export_name, args, res):
  91. export_name.argtypes = args
  92. export_name.restype = res
  93. return export_name
  94. #
  95. # Result values will be passed in a struct { res: cstring, err: Error }
  96. #
  97. class Res(Structure):
  98. _fields_ = [("res", c_char_p), ("err", c_uint64)]
  99. initialize_constants = load(l.test_initialize_constants, [], c_uint64)
  100. initialize_constants()
  101. error_string = load(l.test_error_string, [c_byte], c_char_p)
  102. add = load(l.test_add, [c_char_p, c_char_p], Res)
  103. sub = load(l.test_sub, [c_char_p, c_char_p], Res)
  104. mul = load(l.test_mul, [c_char_p, c_char_p], Res)
  105. div = load(l.test_div, [c_char_p, c_char_p], Res)
  106. # Powers and such
  107. int_log = load(l.test_log, [c_char_p, c_longlong], Res)
  108. int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
  109. int_sqrt = load(l.test_sqrt, [c_char_p], Res)
  110. int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
  111. # Logical operations
  112. int_shl_digit = load(l.test_shl_digit, [c_char_p, c_longlong], Res)
  113. int_shr_digit = load(l.test_shr_digit, [c_char_p, c_longlong], Res)
  114. int_shl = load(l.test_shl, [c_char_p, c_longlong], Res)
  115. int_shr = load(l.test_shr, [c_char_p, c_longlong], Res)
  116. int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
  117. int_factorial = load(l.test_factorial, [c_uint64], Res)
  118. int_gcd = load(l.test_gcd, [c_char_p, c_char_p], Res)
  119. int_lcm = load(l.test_lcm, [c_char_p, c_char_p], Res)
  120. def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = "", radix=16):
  121. passed = True
  122. r = None
  123. err = Error(res.err)
  124. if err != expected_error:
  125. error_loc = res.res.decode('utf-8')
  126. error = "{}: {} in '{}'".format(test_name, err, error_loc)
  127. if len(param):
  128. error += " with params {}".format(param)
  129. print(error, flush=True)
  130. passed = False
  131. elif err == Error.Okay:
  132. r = None
  133. try:
  134. r = res.res.decode('utf-8')
  135. r = int(res.res, radix)
  136. except:
  137. pass
  138. if r != expected_result:
  139. error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
  140. if len(param):
  141. error += " with params {}".format(param)
  142. print(error, flush=True)
  143. passed = False
  144. if EXIT_ON_FAIL and not passed: exit(res.err)
  145. return passed
  146. def arg_to_odin(a):
  147. if a >= 0:
  148. s = hex(a)[2:]
  149. else:
  150. s = '-' + hex(a)[3:]
  151. return s.encode('utf-8')
  152. def test_add(a = 0, b = 0, expected_error = Error.Okay):
  153. args = [arg_to_odin(a), arg_to_odin(b)]
  154. res = add(*args)
  155. expected_result = None
  156. if expected_error == Error.Okay:
  157. expected_result = a + b
  158. return test("test_add", res, [a, b], expected_error, expected_result)
  159. def test_sub(a = 0, b = 0, expected_error = Error.Okay):
  160. args = [arg_to_odin(a), arg_to_odin(b)]
  161. res = sub(*args)
  162. expected_result = None
  163. if expected_error == Error.Okay:
  164. expected_result = a - b
  165. return test("test_sub", res, [a, b], expected_error, expected_result)
  166. def test_mul(a = 0, b = 0, expected_error = Error.Okay):
  167. args = [arg_to_odin(a), arg_to_odin(b)]
  168. try:
  169. res = mul(*args)
  170. except OSError as e:
  171. print("{} while trying to multiply {} x {}.".format(e, a, b))
  172. if EXIT_ON_FAIL: exit(3)
  173. return False
  174. expected_result = None
  175. if expected_error == Error.Okay:
  176. expected_result = a * b
  177. return test("test_mul", res, [a, b], expected_error, expected_result)
  178. def test_div(a = 0, b = 0, expected_error = Error.Okay):
  179. args = [arg_to_odin(a), arg_to_odin(b)]
  180. res = div(*args)
  181. expected_result = None
  182. if expected_error == Error.Okay:
  183. #
  184. # We don't round the division results, so if one component is negative, we're off by one.
  185. #
  186. if a < 0 and b > 0:
  187. expected_result = int(-(abs(a) // b))
  188. elif b < 0 and a > 0:
  189. expected_result = int(-(a // abs((b))))
  190. else:
  191. expected_result = a // b if b != 0 else None
  192. return test("test_div", res, [a, b], expected_error, expected_result)
  193. def test_log(a = 0, base = 0, expected_error = Error.Okay):
  194. args = [arg_to_odin(a), base]
  195. res = int_log(*args)
  196. expected_result = None
  197. if expected_error == Error.Okay:
  198. expected_result = int(math.log(a, base))
  199. return test("test_log", res, [a, base], expected_error, expected_result)
  200. def test_pow(base = 0, power = 0, expected_error = Error.Okay):
  201. args = [arg_to_odin(base), power]
  202. res = int_pow(*args)
  203. expected_result = None
  204. if expected_error == Error.Okay:
  205. if power < 0:
  206. expected_result = 0
  207. else:
  208. # NOTE(Jeroen): Don't use `math.pow`, it's a floating point approximation.
  209. # Use built-in `pow` or `a**b` instead.
  210. expected_result = pow(base, power)
  211. return test("test_pow", res, [base, power], expected_error, expected_result)
  212. def test_sqrt(number = 0, expected_error = Error.Okay):
  213. args = [arg_to_odin(number)]
  214. try:
  215. res = int_sqrt(*args)
  216. except OSError as e:
  217. print("{} while trying to sqrt {}.".format(e, number))
  218. if EXIT_ON_FAIL: exit(3)
  219. return False
  220. expected_result = None
  221. if expected_error == Error.Okay:
  222. if number < 0:
  223. expected_result = 0
  224. else:
  225. expected_result = int(math.isqrt(number))
  226. return test("test_sqrt", res, [number], expected_error, expected_result)
  227. def root_n(number, root):
  228. u, s = number, number + 1
  229. while u < s:
  230. s = u
  231. t = (root-1) * s + number // pow(s, root - 1)
  232. u = t // root
  233. return s
  234. def test_root_n(number = 0, root = 0, expected_error = Error.Okay):
  235. args = [arg_to_odin(number), root]
  236. res = int_root_n(*args)
  237. expected_result = None
  238. if expected_error == Error.Okay:
  239. if number < 0:
  240. expected_result = 0
  241. else:
  242. expected_result = root_n(number, root)
  243. return test("test_root_n", res, [number, root], expected_error, expected_result)
  244. def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
  245. args = [arg_to_odin(a), digits]
  246. res = int_shl_digit(*args)
  247. expected_result = None
  248. if expected_error == Error.Okay:
  249. expected_result = a << (digits * 60)
  250. return test("test_shl_digit", res, [a, digits], expected_error, expected_result)
  251. def test_shr_digit(a = 0, digits = 0, expected_error = Error.Okay):
  252. args = [arg_to_odin(a), digits]
  253. res = int_shr_digit(*args)
  254. expected_result = None
  255. if expected_error == Error.Okay:
  256. if a < 0:
  257. # Don't pass negative numbers. We have a shr_signed.
  258. return False
  259. else:
  260. expected_result = a >> (digits * 60)
  261. return test("test_shr_digit", res, [a, digits], expected_error, expected_result)
  262. def test_shl(a = 0, bits = 0, expected_error = Error.Okay):
  263. args = [arg_to_odin(a), bits]
  264. res = int_shl(*args)
  265. expected_result = None
  266. if expected_error == Error.Okay:
  267. expected_result = a << bits
  268. return test("test_shl", res, [a, bits], expected_error, expected_result)
  269. def test_shr(a = 0, bits = 0, expected_error = Error.Okay):
  270. args = [arg_to_odin(a), bits]
  271. res = int_shr(*args)
  272. expected_result = None
  273. if expected_error == Error.Okay:
  274. if a < 0:
  275. # Don't pass negative numbers. We have a shr_signed.
  276. return False
  277. else:
  278. expected_result = a >> bits
  279. return test("test_shr", res, [a, bits], expected_error, expected_result)
  280. def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
  281. args = [arg_to_odin(a), bits]
  282. res = int_shr_signed(*args)
  283. expected_result = None
  284. if expected_error == Error.Okay:
  285. expected_result = a >> bits
  286. return test("test_shr_signed", res, [a, bits], expected_error, expected_result)
  287. def test_factorial(n = 0, expected_error = Error.Okay):
  288. args = [n]
  289. res = int_factorial(*args)
  290. expected_result = None
  291. if expected_error == Error.Okay:
  292. expected_result = math.factorial(n)
  293. return test("test_factorial", res, [n], expected_error, expected_result)
  294. def test_gcd(a = 0, b = 0, expected_error = Error.Okay):
  295. args = [arg_to_odin(a), arg_to_odin(b)]
  296. res = int_gcd(*args)
  297. expected_result = None
  298. if expected_error == Error.Okay:
  299. expected_result = math.gcd(a, b)
  300. return test("test_gcd", res, [a, b], expected_error, expected_result)
  301. def test_lcm(a = 0, b = 0, expected_error = Error.Okay):
  302. args = [arg_to_odin(a), arg_to_odin(b)]
  303. res = int_lcm(*args)
  304. expected_result = None
  305. if expected_error == Error.Okay:
  306. expected_result = math.lcm(a, b)
  307. return test("test_lcm", res, [a, b], expected_error, expected_result)
  308. # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
  309. #
  310. # The last two arguments in tests are the expected error and expected result.
  311. #
  312. # The expected error defaults to None.
  313. # By default the Odin implementation will be tested against the Python one.
  314. # You can override that by supplying an expected result as the last argument instead.
  315. TESTS = {
  316. test_add: [
  317. [ 1234, 5432],
  318. ],
  319. test_sub: [
  320. [ 1234, 5432],
  321. ],
  322. test_mul: [
  323. [ 1234, 5432],
  324. [ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7],
  325. ],
  326. test_div: [
  327. [ 54321, 12345],
  328. [ 55431, 0, Error.Division_by_Zero],
  329. [ 12980742146337069150589594264770969721, 4611686018427387904 ],
  330. [ 831956404029821402159719858789932422, 243087903122332132 ],
  331. ],
  332. test_log: [
  333. [ 3192, 1, Error.Invalid_Argument],
  334. [ -1234, 2, Error.Math_Domain_Error],
  335. [ 0, 2, Error.Math_Domain_Error],
  336. [ 1024, 2],
  337. ],
  338. test_pow: [
  339. [ 0, -1, Error.Math_Domain_Error ], # Math
  340. [ 0, 0 ], # 1
  341. [ 0, 2 ], # 0
  342. [ 42, -1,], # 0
  343. [ 42, 1 ], # 1
  344. [ 42, 0 ], # 42
  345. [ 42, 2 ], # 42*42
  346. ],
  347. test_sqrt: [
  348. [ -1, Error.Invalid_Argument, ],
  349. [ 42, Error.Okay, ],
  350. [ 12345678901234567890, Error.Okay, ],
  351. [ 1298074214633706907132624082305024, Error.Okay, ],
  352. [ 686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214, Error.Okay, ],
  353. ],
  354. test_root_n: [
  355. [ 1298074214633706907132624082305024, 2, Error.Okay, ],
  356. ],
  357. test_shl_digit: [
  358. [ 3192, 1 ],
  359. [ 1298074214633706907132624082305024, 2 ],
  360. [ 1024, 3 ],
  361. ],
  362. test_shr_digit: [
  363. [ 3680125442705055547392, 1 ],
  364. [ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
  365. [ 219504133884436710204395031992179571, 2 ],
  366. ],
  367. test_shl: [
  368. [ 3192, 1 ],
  369. [ 1298074214633706907132624082305024, 2 ],
  370. [ 1024, 3 ],
  371. ],
  372. test_shr: [
  373. [ 3680125442705055547392, 1 ],
  374. [ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
  375. [ 219504133884436710204395031992179571, 2 ],
  376. ],
  377. test_shr_signed: [
  378. [ -611105530635358368578155082258244262, 12 ],
  379. [ -149195686190273039203651143129455, 12 ],
  380. [ 611105530635358368578155082258244262, 12 ],
  381. [ 149195686190273039203651143129455, 12 ],
  382. ],
  383. test_factorial: [
  384. [ 6_000 ], # Regular factorial, see cutoff in common.odin.
  385. [ 12_345 ], # Binary split factorial
  386. ],
  387. test_gcd: [
  388. [ 23, 25, ],
  389. [ 125, 25, ],
  390. [ 125, 0, ],
  391. [ 0, 0, ],
  392. [ 0, 125,],
  393. ],
  394. test_lcm: [
  395. [ 23, 25,],
  396. [ 125, 25, ],
  397. [ 125, 0, ],
  398. [ 0, 0, ],
  399. [ 0, 125,],
  400. ],
  401. }
  402. if not FAST_TESTS:
  403. TESTS[test_factorial].append(
  404. # This one on its own takes around 800ms, so we exclude it for FAST_TESTS
  405. [ 100_000 ],
  406. )
  407. total_passes = 0
  408. total_failures = 0
  409. #
  410. # test_shr_signed also tests shr, so we're not going to test shr randomly.
  411. #
  412. RANDOM_TESTS = [
  413. test_add, test_sub, test_mul, test_div,
  414. test_log, test_pow, test_sqrt, test_root_n,
  415. test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
  416. test_gcd, test_lcm,
  417. ]
  418. SKIP_LARGE = [
  419. test_pow, test_root_n, # test_gcd,
  420. ]
  421. SKIP_LARGEST = []
  422. # Untimed warmup.
  423. for test_proc in TESTS:
  424. for t in TESTS[test_proc]:
  425. res = test_proc(*t)
  426. if __name__ == '__main__':
  427. print("---- math/big tests ----")
  428. print()
  429. for test_proc in TESTS:
  430. count_pass = 0
  431. count_fail = 0
  432. TIMINGS = {}
  433. for t in TESTS[test_proc]:
  434. start = time.perf_counter()
  435. res = test_proc(*t)
  436. diff = time.perf_counter() - start
  437. TOTAL_TIME += diff
  438. if test_proc not in TIMINGS:
  439. TIMINGS[test_proc] = diff
  440. else:
  441. TIMINGS[test_proc] += diff
  442. if res:
  443. count_pass += 1
  444. total_passes += 1
  445. else:
  446. count_fail += 1
  447. total_failures += 1
  448. print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  449. for BITS, ITERATIONS in BITS_AND_ITERATIONS:
  450. print()
  451. print("---- math/big with two random {bits:,} bit numbers ----".format(bits=BITS))
  452. print()
  453. for test_proc in RANDOM_TESTS:
  454. if BITS > 1_200 and test_proc in SKIP_LARGE: continue
  455. if BITS > 4_096 and test_proc in SKIP_LARGEST: continue
  456. count_pass = 0
  457. count_fail = 0
  458. TIMINGS = {}
  459. UNTIL_ITERS = ITERATIONS
  460. if test_proc == test_root_n and BITS == 1_200:
  461. UNTIL_ITERS /= 10
  462. UNTIL_TIME = TOTAL_TIME + BITS / TIMED_BITS_PER_SECOND
  463. # We run each test for a second per 20k bits
  464. while we_iterate():
  465. a = randint(-(1 << BITS), 1 << BITS)
  466. b = randint(-(1 << BITS), 1 << BITS)
  467. if test_proc == test_div:
  468. # We've already tested division by zero above.
  469. bits = int(BITS * 0.6)
  470. b = randint(-(1 << bits), 1 << bits)
  471. if b == 0:
  472. b == 42
  473. elif test_proc == test_log:
  474. # We've already tested log's domain errors.
  475. a = randint(1, 1 << BITS)
  476. b = randint(2, 1 << 60)
  477. elif test_proc == test_pow:
  478. b = randint(1, 10)
  479. elif test_proc == test_sqrt:
  480. a = randint(1, 1 << BITS)
  481. b = Error.Okay
  482. elif test_proc == test_root_n:
  483. a = randint(1, 1 << BITS)
  484. b = randint(1, 10);
  485. elif test_proc == test_shl_digit:
  486. b = randint(0, 10);
  487. elif test_proc == test_shr_digit:
  488. a = abs(a)
  489. b = randint(0, 10);
  490. elif test_proc == test_shl:
  491. b = randint(0, min(BITS, 120));
  492. elif test_proc == test_shr_signed:
  493. b = randint(0, min(BITS, 120));
  494. else:
  495. b = randint(0, 1 << BITS)
  496. res = None
  497. start = time.perf_counter()
  498. res = test_proc(a, b)
  499. diff = time.perf_counter() - start
  500. TOTAL_TIME += diff
  501. if test_proc not in TIMINGS:
  502. TIMINGS[test_proc] = diff
  503. else:
  504. TIMINGS[test_proc] += diff
  505. if res:
  506. count_pass += 1; total_passes += 1
  507. else:
  508. count_fail += 1; total_failures += 1
  509. print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  510. print()
  511. print("---- THE END ----")
  512. print()
  513. print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
  514. if total_failures:
  515. exit(1)