util.lua 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. local json = require "cjson"
  2. -- Various common routines used by the Lua CJSON package
  3. --
  4. -- Mark Pulford <[email protected]>
-- Determine whether a Lua table can be treated as an array (see is_array).
-- Explicitly returns "not an array" for very sparse arrays.
-- Returns:
--   -1  Not an array
--    0  Empty table
--   >0  Highest index in the array
  11. -- Provide unpack for Lua 5.3+ built without LUA_COMPAT_UNPACK
  12. local unpack = unpack
  13. if table.unpack then unpack = table.unpack end
  14. local function is_array(table)
  15. local max = 0
  16. local count = 0
  17. for k, v in pairs(table) do
  18. if type(k) == "number" then
  19. if k > max then max = k end
  20. count = count + 1
  21. else
  22. return -1
  23. end
  24. end
  25. if max > count * 2 then
  26. return -1
  27. end
  28. return max
  29. end
  30. local serialise_value
  31. local function serialise_table(value, indent, depth)
  32. local spacing, spacing2, indent2
  33. if indent then
  34. spacing = "\n" .. indent
  35. spacing2 = spacing .. " "
  36. indent2 = indent .. " "
  37. else
  38. spacing, spacing2, indent2 = " ", " ", false
  39. end
  40. depth = depth + 1
  41. if depth > 50 then
  42. return "Cannot serialise any further: too many nested tables"
  43. end
  44. local max = is_array(value)
  45. local comma = false
  46. local fragment = { "{" .. spacing2 }
  47. if max > 0 then
  48. -- Serialise array
  49. for i = 1, max do
  50. if comma then
  51. table.insert(fragment, "," .. spacing2)
  52. end
  53. table.insert(fragment, serialise_value(value[i], indent2, depth))
  54. comma = true
  55. end
  56. elseif max < 0 then
  57. -- Serialise table
  58. for k, v in pairs(value) do
  59. if comma then
  60. table.insert(fragment, "," .. spacing2)
  61. end
  62. table.insert(fragment,
  63. ("[%s] = %s"):format(serialise_value(k, indent2, depth),
  64. serialise_value(v, indent2, depth)))
  65. comma = true
  66. end
  67. end
  68. table.insert(fragment, spacing .. "}")
  69. return table.concat(fragment)
  70. end
  71. function serialise_value(value, indent, depth)
  72. if indent == nil then indent = "" end
  73. if depth == nil then depth = 0 end
  74. if value == json.null then
  75. return "json.null"
  76. elseif type(value) == "string" then
  77. return ("%q"):format(value)
  78. elseif type(value) == "nil" or type(value) == "number" or
  79. type(value) == "boolean" then
  80. return tostring(value)
  81. elseif type(value) == "table" then
  82. return serialise_table(value, indent, depth)
  83. else
  84. return "\"<" .. type(value) .. ">\""
  85. end
  86. end
  87. local function file_load(filename)
  88. local file
  89. if filename == nil then
  90. file = io.stdin
  91. else
  92. local err
  93. file, err = io.open(filename, "rb")
  94. if file == nil then
  95. error(("Unable to read '%s': %s"):format(filename, err))
  96. end
  97. end
  98. local data = file:read("*a")
  99. if filename ~= nil then
  100. file:close()
  101. end
  102. if data == nil then
  103. error("Failed to read " .. filename)
  104. end
  105. return data
  106. end
  107. local function file_save(filename, data)
  108. local file
  109. if filename == nil then
  110. file = io.stdout
  111. else
  112. local err
  113. file, err = io.open(filename, "wb")
  114. if file == nil then
  115. error(("Unable to write '%s': %s"):format(filename, err))
  116. end
  117. end
  118. file:write(data)
  119. if filename ~= nil then
  120. file:close()
  121. end
  122. end
  123. local function compare_values(val1, val2)
  124. local type1 = type(val1)
  125. local type2 = type(val2)
  126. if type1 ~= type2 then
  127. return false
  128. end
  129. -- Check for NaN
  130. if type1 == "number" and val1 ~= val1 and val2 ~= val2 then
  131. return true
  132. end
  133. if type1 ~= "table" then
  134. return val1 == val2
  135. end
  136. -- check_keys stores all the keys that must be checked in val2
  137. local check_keys = {}
  138. for k, _ in pairs(val1) do
  139. check_keys[k] = true
  140. end
  141. for k, v in pairs(val2) do
  142. if not check_keys[k] then
  143. return false
  144. end
  145. if not compare_values(val1[k], val2[k]) then
  146. return false
  147. end
  148. check_keys[k] = nil
  149. end
  150. for k, _ in pairs(check_keys) do
  151. -- Not the same if any keys from val1 were not found in val2
  152. return false
  153. end
  154. return true
  155. end
  156. local test_count_pass = 0
  157. local test_count_total = 0
  158. local function run_test_summary()
  159. return test_count_pass, test_count_total
  160. end
  161. local function run_test(testname, func, input, should_work, output)
  162. local function status_line(name, status, value)
  163. local statusmap = { [true] = ":success", [false] = ":error" }
  164. if status ~= nil then
  165. name = name .. statusmap[status]
  166. end
  167. print(("[%s] %s"):format(name, serialise_value(value, false)))
  168. end
  169. local result = { pcall(func, unpack(input)) }
  170. local success = table.remove(result, 1)
  171. local correct = false
  172. if success == should_work and compare_values(result, output) then
  173. correct = true
  174. test_count_pass = test_count_pass + 1
  175. end
  176. test_count_total = test_count_total + 1
  177. local teststatus = { [true] = "PASS", [false] = "FAIL" }
  178. print(("==> Test [%d] %s: %s"):format(test_count_total, testname,
  179. teststatus[correct]))
  180. status_line("Input", nil, input)
  181. if not correct then
  182. status_line("Expected", should_work, output)
  183. end
  184. status_line("Received", success, result)
  185. print()
  186. return correct, result
  187. end
  188. local function run_test_group(tests)
  189. local function run_helper(name, func, input)
  190. if type(name) == "string" and #name > 0 then
  191. print("==> " .. name)
  192. end
  193. -- Not a protected call, these functions should never generate errors.
  194. func(unpack(input or {}))
  195. print()
  196. end
  197. for _, v in ipairs(tests) do
  198. -- Run the helper if "should_work" is missing
  199. if v[4] == nil then
  200. run_helper(unpack(v))
  201. else
  202. run_test(unpack(v))
  203. end
  204. end
  205. end
  206. -- Run a Lua script in a separate environment
  207. local function run_script(script, env)
  208. local env = env or {}
  209. local func
  210. -- Use setfenv() if it exists, otherwise assume Lua 5.2 load() exists
  211. if _G.setfenv then
  212. func = loadstring(script)
  213. if func then
  214. setfenv(func, env)
  215. end
  216. else
  217. func = load(script, nil, nil, env)
  218. end
  219. if func == nil then
  220. error("Invalid syntax.")
  221. end
  222. func()
  223. return env
  224. end
  225. -- Export functions
  226. return {
  227. serialise_value = serialise_value,
  228. file_load = file_load,
  229. file_save = file_save,
  230. compare_values = compare_values,
  231. run_test_summary = run_test_summary,
  232. run_test = run_test,
  233. run_test_group = run_test_group,
  234. run_script = run_script
  235. }
  236. -- vi:ai et sw=4 ts=4: