test.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. #
  2. # Copyright 2021 Jeroen van Rijn <[email protected]>.
  3. # Made available under Odin's BSD-3 license.
  4. #
  5. # A BigInt implementation in Odin.
  6. # For the theoretical underpinnings, see Knuth's The Art of Computer Programming, Volume 2, section 4.3.
  7. # The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks.
  8. #
  9. from ctypes import *
  10. from random import *
  11. import math
  12. import os
  13. import platform
  14. import time
  15. import gc
  16. from enum import Enum
  17. import argparse
  18. parser = argparse.ArgumentParser(
  19. description = "Odin core:math/big test suite",
  20. epilog = "By default we run regression and random tests with preset parameters.",
  21. formatter_class = argparse.ArgumentDefaultsHelpFormatter,
  22. )
  23. #
  24. # Normally, we report the number of passes and fails. With this option set, we exit at first fail.
  25. #
  26. parser.add_argument(
  27. "-exit-on-fail",
  28. help = "Exit when a test fails",
  29. action = "store_true",
  30. )
  31. #
  32. # We skip randomized tests altogether if this is set.
  33. #
  34. no_random = parser.add_mutually_exclusive_group()
  35. no_random.add_argument(
  36. "-no-random",
  37. help = "No random tests",
  38. action = "store_true",
  39. )
  40. #
  41. # Normally we run a given number of cycles on each test.
  42. # Timed tests budget 1 second per 20_000 bits instead.
  43. #
  44. # For timed tests we budget a second per `n` bits and iterate until we hit that time.
  45. #
  46. timed_or_fast = no_random.add_mutually_exclusive_group()
  47. timed_or_fast.add_argument(
  48. "-timed",
  49. type = bool,
  50. default = False,
  51. help = "Timed tests instead of a preset number of iterations.",
  52. )
  53. parser.add_argument(
  54. "-timed-bits",
  55. type = int,
  56. metavar = "BITS",
  57. default = 20_000,
  58. help = "Timed tests. Every `BITS` worth of input is given a second of running time.",
  59. )
  60. #
  61. # For normal tests (non-timed), `-fast-tests` cuts down on the number of iterations.
  62. #
  63. timed_or_fast.add_argument(
  64. "-fast-tests",
  65. help = "Cut down on the number of iterations of each test",
  66. action = "store_true",
  67. )
  68. args = parser.parse_args()
  69. EXIT_ON_FAIL = args.exit_on_fail
  70. #
  71. # How many iterations of each random test do we want to run?
  72. #
  73. BITS_AND_ITERATIONS = [
  74. ( 120, 10_000),
  75. ( 1_200, 1_000),
  76. ( 4_096, 100),
  77. (12_000, 10),
  78. ]
  79. if args.fast_tests:
  80. for k in range(len(BITS_AND_ITERATIONS)):
  81. b, i = BITS_AND_ITERATIONS[k]
  82. BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5)
  83. if args.no_random:
  84. BITS_AND_ITERATIONS = []
  85. #
  86. # Where is the DLL? If missing, build using: `odin build . -build-mode:shared`
  87. #
  88. if platform.system() == "Windows":
  89. LIB_PATH = os.getcwd() + os.sep + "test_library.dll"
  90. elif platform.system() == "Linux":
  91. LIB_PATH = os.getcwd() + os.sep + "test_library.so"
  92. elif platform.system() == "Darwin":
  93. LIB_PATH = os.getcwd() + os.sep + "test_library.dylib"
  94. else:
  95. print("Platform is unsupported.")
  96. exit(1)
  97. TOTAL_TIME = 0
  98. UNTIL_TIME = 0
  99. UNTIL_ITERS = 0
  100. def we_iterate():
  101. if args.timed:
  102. return TOTAL_TIME < UNTIL_TIME
  103. else:
  104. global UNTIL_ITERS
  105. UNTIL_ITERS -= 1
  106. return UNTIL_ITERS != -1
#
# Error enum values
#
class Error(Enum):
    """Error codes returned by the test library in `Res.err` (converted via `Error(res.err)`).

    NOTE(review): the numeric values presumably mirror the `Error` enum of Odin's core:math/big and
    must be kept in sync with it; an unknown code makes `Error(...)` raise ValueError. Confirm
    against the Odin source.
    """
    Okay                   = 0
    Out_Of_Memory          = 1
    Invalid_Pointer        = 2
    Invalid_Argument       = 3
    Unknown_Error          = 4
    Max_Iterations_Reached = 5
    Buffer_Overflow        = 6
    Integer_Overflow       = 7
    Division_by_Zero       = 8
    Math_Domain_Error      = 9
    # Value is not contiguous with the codes above.
    Unimplemented          = 127
#
# Disable garbage collection for the rest of the run.
#
# NOTE(review): presumably done so collector pauses don't distort the per-test timings
# measured below — confirm with the author.
gc.disable()
  126. #
  127. # Set up exported procedures
  128. #
  129. try:
  130. l = cdll.LoadLibrary(LIB_PATH)
  131. except:
  132. print("Couldn't find or load " + LIB_PATH + ".")
  133. exit(1)
  134. def load(export_name, args, res):
  135. export_name.argtypes = args
  136. export_name.restype = res
  137. return export_name
#
# Result values will be passed in a struct { res: cstring, err: Error }
#
class Res(Structure):
    # `res`: NUL-terminated string holding the result (or, when `err` is not Okay, the error
    # location printed by `test`); `err`: the `Error` code, read as a 64-bit unsigned integer.
    _fields_ = [("res", c_char_p), ("err", c_uint64)]
# Called once at load time; its return value is only printed.
initialize_constants = load(l.test_initialize_constants, [], c_uint64)
print("initialize_constants: ", initialize_constants())
error_string = load(l.test_error_string, [c_byte], c_char_p)
# Arithmetic: big integers are passed in as hexadecimal cstrings (see `arg_to_odin`).
add = load(l.test_add, [c_char_p, c_char_p], Res)
sub = load(l.test_sub, [c_char_p, c_char_p], Res)
mul = load(l.test_mul, [c_char_p, c_char_p], Res)
sqr = load(l.test_sqr, [c_char_p], Res)
div = load(l.test_div, [c_char_p, c_char_p], Res)
# Powers and such
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
# Logical operations
int_shl_digit = load(l.test_shl_digit, [c_char_p, c_longlong], Res)
int_shr_digit = load(l.test_shr_digit, [c_char_p, c_longlong], Res)
int_shl = load(l.test_shl, [c_char_p, c_longlong], Res)
int_shr = load(l.test_shr, [c_char_p, c_longlong], Res)
int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
# Number theory: factorial takes a plain unsigned integer rather than a cstring.
int_factorial = load(l.test_factorial, [c_uint64], Res)
int_gcd = load(l.test_gcd, [c_char_p, c_char_p], Res)
int_lcm = load(l.test_lcm, [c_char_p, c_char_p], Res)
is_square = load(l.test_is_square, [c_char_p], Res)
  166. def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = "", radix=16):
  167. passed = True
  168. r = None
  169. err = Error(res.err)
  170. if err != expected_error:
  171. error_loc = res.res.decode('utf-8')
  172. error = "{}: {} in '{}'".format(test_name, err, error_loc)
  173. if len(param):
  174. error += " with params {}".format(param)
  175. print(error, flush=True)
  176. passed = False
  177. elif err == Error.Okay:
  178. r = None
  179. try:
  180. r = res.res.decode('utf-8')
  181. r = int(res.res, radix)
  182. except:
  183. pass
  184. if r != expected_result:
  185. error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
  186. if len(param):
  187. error += " with params {}".format(param)
  188. print(error, flush=True)
  189. passed = False
  190. if EXIT_ON_FAIL and not passed: exit(res.err)
  191. return passed
  192. def arg_to_odin(a):
  193. if a >= 0:
  194. s = hex(a)[2:]
  195. else:
  196. s = '-' + hex(a)[3:]
  197. return s.encode('utf-8')
  198. def big_integer_sqrt(src):
  199. # The Python version on Github's CI doesn't offer math.isqrt.
  200. # We implement our own
  201. count = src.bit_length()
  202. a, b = count >> 1, count & 1
  203. x = 1 << (a + b)
  204. while True:
  205. # y = (x + n // x) // 2
  206. t1 = src // x
  207. t2 = t1 + x
  208. y = t2 >> 1
  209. if y >= x:
  210. return x
  211. x, y = y, x
  212. def big_integer_lcm(a, b):
  213. # Computes least common multiple as `|a*b|/gcd(a,b)`
  214. # Divide the smallest by the GCD.
  215. if a == 0 or b == 0:
  216. return 0
  217. if abs(a) < abs(b):
  218. # Store quotient in `t2` such that `t2 * b` is the LCM.
  219. lcm = a // math.gcd(a, b)
  220. return abs(b * lcm)
  221. else:
  222. # Store quotient in `t2` such that `t2 * a` is the LCM.
  223. lcm = b // math.gcd(a, b)
  224. return abs(a * lcm)
  225. def test_add(a = 0, b = 0, expected_error = Error.Okay):
  226. args = [arg_to_odin(a), arg_to_odin(b)]
  227. res = add(*args)
  228. expected_result = None
  229. if expected_error == Error.Okay:
  230. expected_result = a + b
  231. return test("test_add", res, [a, b], expected_error, expected_result)
  232. def test_sub(a = 0, b = 0, expected_error = Error.Okay):
  233. args = [arg_to_odin(a), arg_to_odin(b)]
  234. res = sub(*args)
  235. expected_result = None
  236. if expected_error == Error.Okay:
  237. expected_result = a - b
  238. return test("test_sub", res, [a, b], expected_error, expected_result)
  239. def test_mul(a = 0, b = 0, expected_error = Error.Okay):
  240. args = [arg_to_odin(a), arg_to_odin(b)]
  241. try:
  242. res = mul(*args)
  243. except OSError as e:
  244. print("{} while trying to multiply {} x {}.".format(e, a, b))
  245. if EXIT_ON_FAIL: exit(3)
  246. return False
  247. expected_result = None
  248. if expected_error == Error.Okay:
  249. expected_result = a * b
  250. return test("test_mul", res, [a, b], expected_error, expected_result)
  251. def test_sqr(a = 0, b = 0, expected_error = Error.Okay):
  252. args = [arg_to_odin(a)]
  253. try:
  254. res = sqr(*args)
  255. except OSError as e:
  256. print("{} while trying to square {}.".format(e, a))
  257. if EXIT_ON_FAIL: exit(3)
  258. return False
  259. expected_result = None
  260. if expected_error == Error.Okay:
  261. expected_result = a * a
  262. return test("test_sqr", res, [a], expected_error, expected_result)
  263. def test_div(a = 0, b = 0, expected_error = Error.Okay):
  264. args = [arg_to_odin(a), arg_to_odin(b)]
  265. try:
  266. res = div(*args)
  267. except OSError as e:
  268. print("{} while trying divide to {} / {}.".format(e, a, b))
  269. if EXIT_ON_FAIL: exit(3)
  270. return False
  271. expected_result = None
  272. if expected_error == Error.Okay:
  273. #
  274. # We don't round the division results, so if one component is negative, we're off by one.
  275. #
  276. if a < 0 and b > 0:
  277. expected_result = int(-(abs(a) // b))
  278. elif b < 0 and a > 0:
  279. expected_result = int(-(a // abs((b))))
  280. else:
  281. expected_result = a // b if b != 0 else None
  282. return test("test_div", res, [a, b], expected_error, expected_result)
  283. def test_log(a = 0, base = 0, expected_error = Error.Okay):
  284. args = [arg_to_odin(a), base]
  285. res = int_log(*args)
  286. expected_result = None
  287. if expected_error == Error.Okay:
  288. expected_result = int(math.log(a, base))
  289. return test("test_log", res, [a, base], expected_error, expected_result)
  290. def test_pow(base = 0, power = 0, expected_error = Error.Okay):
  291. args = [arg_to_odin(base), power]
  292. res = int_pow(*args)
  293. expected_result = None
  294. if expected_error == Error.Okay:
  295. if power < 0:
  296. expected_result = 0
  297. else:
  298. # NOTE(Jeroen): Don't use `math.pow`, it's a floating point approximation.
  299. # Use built-in `pow` or `a**b` instead.
  300. expected_result = pow(base, power)
  301. return test("test_pow", res, [base, power], expected_error, expected_result)
  302. def test_sqrt(number = 0, expected_error = Error.Okay):
  303. args = [arg_to_odin(number)]
  304. try:
  305. res = int_sqrt(*args)
  306. except OSError as e:
  307. print("{} while trying to sqrt {}.".format(e, number))
  308. if EXIT_ON_FAIL: exit(3)
  309. return False
  310. expected_result = None
  311. if expected_error == Error.Okay:
  312. if number < 0:
  313. expected_result = 0
  314. else:
  315. expected_result = big_integer_sqrt(number)
  316. return test("test_sqrt", res, [number], expected_error, expected_result)
  317. def root_n(number, root):
  318. u, s = number, number + 1
  319. while u < s:
  320. s = u
  321. t = (root-1) * s + number // pow(s, root - 1)
  322. u = t // root
  323. return s
  324. def test_root_n(number = 0, root = 0, expected_error = Error.Okay):
  325. args = [arg_to_odin(number), root]
  326. res = int_root_n(*args)
  327. expected_result = None
  328. if expected_error == Error.Okay:
  329. if number < 0:
  330. expected_result = 0
  331. else:
  332. expected_result = root_n(number, root)
  333. return test("test_root_n", res, [number, root], expected_error, expected_result)
  334. def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
  335. args = [arg_to_odin(a), digits]
  336. res = int_shl_digit(*args)
  337. expected_result = None
  338. if expected_error == Error.Okay:
  339. expected_result = a << (digits * 60)
  340. return test("test_shl_digit", res, [a, digits], expected_error, expected_result)
  341. def test_shr_digit(a = 0, digits = 0, expected_error = Error.Okay):
  342. args = [arg_to_odin(a), digits]
  343. res = int_shr_digit(*args)
  344. expected_result = None
  345. if expected_error == Error.Okay:
  346. if a < 0:
  347. # Don't pass negative numbers. We have a shr_signed.
  348. return False
  349. else:
  350. expected_result = a >> (digits * 60)
  351. return test("test_shr_digit", res, [a, digits], expected_error, expected_result)
  352. def test_shl(a = 0, bits = 0, expected_error = Error.Okay):
  353. args = [arg_to_odin(a), bits]
  354. res = int_shl(*args)
  355. expected_result = None
  356. if expected_error == Error.Okay:
  357. expected_result = a << bits
  358. return test("test_shl", res, [a, bits], expected_error, expected_result)
  359. def test_shr(a = 0, bits = 0, expected_error = Error.Okay):
  360. args = [arg_to_odin(a), bits]
  361. res = int_shr(*args)
  362. expected_result = None
  363. if expected_error == Error.Okay:
  364. if a < 0:
  365. # Don't pass negative numbers. We have a shr_signed.
  366. return False
  367. else:
  368. expected_result = a >> bits
  369. return test("test_shr", res, [a, bits], expected_error, expected_result)
  370. def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
  371. args = [arg_to_odin(a), bits]
  372. res = int_shr_signed(*args)
  373. expected_result = None
  374. if expected_error == Error.Okay:
  375. expected_result = a >> bits
  376. return test("test_shr_signed", res, [a, bits], expected_error, expected_result)
  377. def test_factorial(number = 0, expected_error = Error.Okay):
  378. args = [number]
  379. try:
  380. res = int_factorial(*args)
  381. except OSError as e:
  382. print("{} while trying to factorial {}.".format(e, number))
  383. if EXIT_ON_FAIL: exit(3)
  384. return False
  385. expected_result = None
  386. if expected_error == Error.Okay:
  387. expected_result = math.factorial(number)
  388. return test("test_factorial", res, [number], expected_error, expected_result)
  389. def test_gcd(a = 0, b = 0, expected_error = Error.Okay):
  390. args = [arg_to_odin(a), arg_to_odin(b)]
  391. res = int_gcd(*args)
  392. expected_result = None
  393. if expected_error == Error.Okay:
  394. expected_result = math.gcd(a, b)
  395. return test("test_gcd", res, [a, b], expected_error, expected_result)
  396. def test_lcm(a = 0, b = 0, expected_error = Error.Okay):
  397. args = [arg_to_odin(a), arg_to_odin(b)]
  398. res = int_lcm(*args)
  399. expected_result = None
  400. if expected_error == Error.Okay:
  401. expected_result = big_integer_lcm(a, b)
  402. return test("test_lcm", res, [a, b], expected_error, expected_result)
  403. def test_is_square(a = 0, b = 0, expected_error = Error.Okay):
  404. args = [arg_to_odin(a)]
  405. res = is_square(*args)
  406. expected_result = None
  407. if expected_error == Error.Okay:
  408. expected_result = str(big_integer_sqrt(a) ** 2 == a) if a > 0 else "False"
  409. return test("test_is_square", res, [a], expected_error, expected_result)
# TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
#
# Each test procedure maps to a list of argument lists; every list is passed as `test_proc(*args)`.
#
# The argument following the operands, when present, is the expected error; it defaults to `Error.Okay`.
# The expected result is always computed by the Python reference code inside each test procedure;
# none of the test procedures accept an explicit expected result.
TESTS = {
    test_add: [
        [ 1234,   5432],
    ],
    test_sub: [
        [ 1234,   5432],
    ],
    test_mul: [
        [ 1234,   5432],
        [ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7 ],
        [ 1 << 21_105, 1 << 21_501 ],
    ],
    test_sqr: [
        [ 5432],
        [ 0xd3b4e926aaba3040e1c12b5ea553b5 ],
    ],
    test_div: [
        [ 54321, 12345],
        [ 55431,     0, Error.Division_by_Zero],
        [ 12980742146337069150589594264770969721, 4611686018427387904 ],
        [ 831956404029821402159719858789932422, 243087903122332132 ],
    ],
    test_log: [
        [ 3192,    1, Error.Invalid_Argument],
        [ -1234,   2, Error.Math_Domain_Error],
        [ 0,       2, Error.Math_Domain_Error],
        [ 1024,    2],
    ],
    test_pow: [
        [ 0,  -1, Error.Math_Domain_Error ], # Math domain error
        [ 0,   0 ], # 1
        [ 0,   2 ], # 0
        [ 42, -1,], # 0
        [ 42,  1 ], # 42
        [ 42,  0 ], # 1
        [ 42,  2 ], # 42*42
    ],
    test_sqrt: [
        [ -1, Error.Invalid_Argument, ],
        [ 42, Error.Okay, ],
        [ 12345678901234567890, Error.Okay, ],
        [ 1298074214633706907132624082305024, Error.Okay, ],
        [ 686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214, Error.Okay, ],
    ],
    test_root_n: [
        [ 1298074214633706907132624082305024, 2, Error.Okay, ],
    ],
    test_shl_digit: [
        [ 3192,    1 ],
        [ 1298074214633706907132624082305024, 2 ],
        [ 1024,    3 ],
    ],
    test_shr_digit: [
        [ 3680125442705055547392, 1 ],
        [ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
        [ 219504133884436710204395031992179571, 2 ],
    ],
    test_shl: [
        [ 3192,    1 ],
        [ 1298074214633706907132624082305024, 2 ],
        [ 1024,    3 ],
    ],
    test_shr: [
        [ 3680125442705055547392, 1 ],
        [ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
        [ 219504133884436710204395031992179571, 2 ],
    ],
    test_shr_signed: [
        [ -611105530635358368578155082258244262, 12 ],
        [ -149195686190273039203651143129455, 12 ],
        [ 611105530635358368578155082258244262, 12 ],
        [ 149195686190273039203651143129455, 12 ],
    ],
    test_factorial: [
        [  6_000 ],  # Regular factorial, see cutoff in common.odin.
        [ 12_345 ],  # Binary split factorial
    ],
    test_gcd: [
        [  23,  25, ],
        [ 125,  25, ],
        [ 125,   0, ],
        [   0,   0, ],
        [   0, 125,],
    ],
    test_lcm: [
        [  23,  25,],
        [ 125,  25, ],
        [ 125,   0, ],
        [   0,   0, ],
        [   0, 125,],
    ],
    test_is_square: [
        [ 12, ],
        [ 92232459121502451677697058974826760244863271517919321608054113675118660929276431348516553336313179167211015633639725554914519355444316239500734169769447134357534241879421978647995614218985202290368055757891124109355450669008628757662409138767505519391883751112010824030579849970582074544353971308266211776494228299586414907715854328360867232691292422194412634523666770452490676515117702116926803826546868467146319938818238521874072436856528051486567230096290549225463582766830777324099589751817442141036031904145041055454639783559905920619197290800070679733841430619962318433709503256637256772215111521321630777950145713049902839937043785039344243357384899099910837463164007565230287809026956254332260375327814271845678201, ]
    ],
}
  513. if not args.fast_tests:
  514. TESTS[test_factorial].append(
  515. # This one on its own takes around 800ms, so we exclude it for FAST_TESTS
  516. [ 10_000 ],
  517. )
  518. total_passes = 0
  519. total_failures = 0
  520. #
  521. # test_shr_signed also tests shr, so we're not going to test shr randomly.
  522. #
  523. RANDOM_TESTS = [
  524. test_add, test_sub, test_mul, test_sqr, test_div,
  525. test_log, test_pow, test_sqrt, test_root_n,
  526. test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
  527. test_gcd, test_lcm, test_is_square,
  528. ]
  529. SKIP_LARGE = [
  530. test_pow, test_root_n, # test_gcd,
  531. ]
  532. SKIP_LARGEST = []
  533. # Untimed warmup.
  534. for test_proc in TESTS:
  535. for t in TESTS[test_proc]:
  536. res = test_proc(*t)
  537. if __name__ == '__main__':
  538. print("\n---- math/big tests ----")
  539. print()
  540. max_name = 0
  541. for test_proc in TESTS:
  542. max_name = max(max_name, len(test_proc.__name__))
  543. fmt_string = "{name:>{max_name}}: {count_pass:7,} passes and {count_fail:7,} failures in {timing:9.3f} ms."
  544. fmt_string = fmt_string.replace("{max_name}", str(max_name))
  545. for test_proc in TESTS:
  546. count_pass = 0
  547. count_fail = 0
  548. TIMINGS = {}
  549. for t in TESTS[test_proc]:
  550. start = time.perf_counter()
  551. res = test_proc(*t)
  552. diff = time.perf_counter() - start
  553. TOTAL_TIME += diff
  554. if test_proc not in TIMINGS:
  555. TIMINGS[test_proc] = diff
  556. else:
  557. TIMINGS[test_proc] += diff
  558. if res:
  559. count_pass += 1
  560. total_passes += 1
  561. else:
  562. count_fail += 1
  563. total_failures += 1
  564. print(fmt_string.format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  565. for BITS, ITERATIONS in BITS_AND_ITERATIONS:
  566. print()
  567. print("---- math/big with two random {bits:,} bit numbers ----".format(bits=BITS))
  568. print()
  569. #
  570. # We've already tested up to the 10th root.
  571. #
  572. TEST_ROOT_N_PARAMS = [2, 3, 4, 5, 6]
  573. for test_proc in RANDOM_TESTS:
  574. if BITS > 1_200 and test_proc in SKIP_LARGE: continue
  575. if BITS > 4_096 and test_proc in SKIP_LARGEST: continue
  576. count_pass = 0
  577. count_fail = 0
  578. TIMINGS = {}
  579. UNTIL_ITERS = ITERATIONS
  580. if test_proc == test_root_n and BITS == 1_200:
  581. UNTIL_ITERS /= 10
  582. UNTIL_TIME = TOTAL_TIME + BITS / args.timed_bits
  583. # We run each test for a second per 20k bits
  584. index = 0
  585. while we_iterate():
  586. a = randint(-(1 << BITS), 1 << BITS)
  587. b = randint(-(1 << BITS), 1 << BITS)
  588. if test_proc == test_div:
  589. # We've already tested division by zero above.
  590. bits = int(BITS * 0.6)
  591. b = randint(-(1 << bits), 1 << bits)
  592. if b == 0:
  593. b == 42
  594. elif test_proc == test_log:
  595. # We've already tested log's domain errors.
  596. a = randint(1, 1 << BITS)
  597. b = randint(2, 1 << 60)
  598. elif test_proc == test_pow:
  599. b = randint(1, 10)
  600. elif test_proc == test_sqrt:
  601. a = randint(1, 1 << BITS)
  602. b = Error.Okay
  603. elif test_proc == test_root_n:
  604. a = randint(1, 1 << BITS)
  605. b = TEST_ROOT_N_PARAMS[index]
  606. index = (index + 1) % len(TEST_ROOT_N_PARAMS)
  607. elif test_proc == test_shl_digit:
  608. b = randint(0, 10);
  609. elif test_proc == test_shr_digit:
  610. a = abs(a)
  611. b = randint(0, 10);
  612. elif test_proc == test_shl:
  613. b = randint(0, min(BITS, 120))
  614. elif test_proc == test_shr_signed:
  615. b = randint(0, min(BITS, 120))
  616. elif test_proc == test_is_square:
  617. a = randint(0, 1 << BITS)
  618. elif test_proc == test_lcm:
  619. smallest = min(a, b)
  620. biggest = max(a, b)
  621. # Randomly swap biggest and smallest
  622. if randint(1, 11) % 2 == 0:
  623. smallest, biggest = biggest, smallest
  624. a, b = smallest, biggest
  625. else:
  626. b = randint(0, 1 << BITS)
  627. res = None
  628. start = time.perf_counter()
  629. res = test_proc(a, b)
  630. diff = time.perf_counter() - start
  631. TOTAL_TIME += diff
  632. if test_proc not in TIMINGS:
  633. TIMINGS[test_proc] = diff
  634. else:
  635. TIMINGS[test_proc] += diff
  636. if res:
  637. count_pass += 1; total_passes += 1
  638. else:
  639. count_fail += 1; total_failures += 1
  640. print(fmt_string.format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
  641. print()
  642. print("---- THE END ----")
  643. print()
  644. print(fmt_string.format(name="total", count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
  645. if total_failures:
  646. exit(1)