2
0

create_vulkan_odin_wrapper.py 20 KB


  1. import re
  2. import urllib.request as req
  3. from tokenize import tokenize
  4. from io import BytesIO
  5. import string
  6. import os.path
  7. import math
  8. file_and_urls = [
  9. ("vk_platform.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vk_platform.h', True),
  10. ("vulkan_core.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vulkan_core.h', False),
  11. ("vk_layer.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vk_layer.h', True),
  12. ("vk_icd.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vk_icd.h', True),
  13. ("vulkan_win32.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vulkan_win32.h', False),
  14. ("vulkan_metal.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vulkan_metal.h', False),
  15. ("vulkan_macos.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vulkan_macos.h', False),
  16. ("vulkan_ios.h", 'https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vulkan_ios.h', False),
  17. ]
  18. for file, url, _ in file_and_urls:
  19. if not os.path.isfile(file):
  20. with open(file, 'w', encoding='utf-8') as f:
  21. f.write(req.urlopen(url).read().decode('utf-8'))
  22. src = ""
  23. for file, _, skip in file_and_urls:
  24. if skip: continue
  25. with open(file, 'r', encoding='utf-8') as f:
  26. src += f.read()
  27. def no_vk(t):
  28. t = t.replace('Vk', '')
  29. t = t.replace('PFN_vk_icd', 'Procicd')
  30. t = t.replace('PFN_vk', 'Proc')
  31. t = t.replace('PFN_', 'Proc')
  32. t = t.replace('PFN_', 'Proc')
  33. t = t.replace('VK_', '')
  34. return t
  35. def convert_type(t, prev_name, curr_name):
  36. table = {
  37. "Bool32": 'b32',
  38. "float": 'f32',
  39. "double": 'f64',
  40. "uint32_t": 'u32',
  41. "uint64_t": 'u64',
  42. "size_t": 'int',
  43. 'int32_t': 'i32',
  44. 'int64_t': 'i64',
  45. 'int': 'c.int',
  46. 'uint8_t': 'u8',
  47. "uint16_t": 'u16',
  48. "char": "byte",
  49. "void": "void",
  50. "void*": "rawptr",
  51. "void *": "rawptr",
  52. "char*": 'cstring',
  53. "const uint32_t* const*": "^[^]u32",
  54. "const void*": 'rawptr',
  55. "const char*": 'cstring',
  56. "const char* const*": '[^]cstring',
  57. "const ObjectTableEntryNVX* const*": "^^ObjectTableEntryNVX",
  58. "const void* const *": "[^]rawptr",
  59. "const AccelerationStructureGeometryKHR* const*": "^[^]AccelerationStructureGeometryKHR",
  60. "const AccelerationStructureBuildRangeInfoKHR* const*": "^[^]AccelerationStructureBuildRangeInfoKHR",
  61. "struct BaseOutStructure": "BaseOutStructure",
  62. "struct BaseInStructure": "BaseInStructure",
  63. 'v': '',
  64. }
  65. if t in table.keys():
  66. return table[t]
  67. if t == "":
  68. return t
  69. elif t.endswith("*"):
  70. elem = ""
  71. pointer = "^"
  72. if t.startswith("const"):
  73. ttype = t[6:len(t)-1]
  74. elem = convert_type(ttype, prev_name, curr_name)
  75. else:
  76. ttype = t[:len(t)-1]
  77. elem = convert_type(ttype, prev_name, curr_name)
  78. if curr_name.endswith("s") or curr_name.endswith("Table"):
  79. if prev_name.endswith("Count") or prev_name.endswith("Counts"):
  80. pointer = "[^]"
  81. elif curr_name.startswith("pp"):
  82. if elem.startswith("[^]"):
  83. pass
  84. else:
  85. pointer = "[^]"
  86. elif curr_name.startswith("p"):
  87. pointer = "[^]"
  88. if curr_name and elem.endswith("Flags"):
  89. pointer = "[^]"
  90. return "{}{}".format(pointer, elem)
  91. elif t[0].isupper():
  92. return t
  93. return t
  94. def parse_array(n, t):
  95. name, length = n.split('[', 1)
  96. length = no_vk(length[:-1])
  97. type_ = "[{}]{}".format(length, do_type(t))
  98. return name, type_
  99. def remove_prefix(text, prefix):
  100. if text.startswith(prefix):
  101. return text[len(prefix):]
  102. return text
  103. def remove_suffix(text, suffix):
  104. if text.endswith(suffix):
  105. return text[:-len(suffix)]
  106. return text
  107. def to_snake_case(name):
  108. s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  109. return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
  110. ext_suffixes = ["KHR", "EXT", "AMD", "NV", "NVX", "GOOGLE"]
  111. ext_suffixes_title = [ext.title() for ext in ext_suffixes]
  112. def fix_arg(arg):
  113. name = arg
  114. # Remove useless pointer identifier in field name
  115. for p in ('s_', 'p_', 'pp_', 'pfn_'):
  116. if name.startswith(p):
  117. name = name[len(p)::]
  118. name = name.replace("__", "_")
  119. return name
  120. def fix_ext_suffix(name):
  121. for ext in ext_suffixes_title:
  122. if name.endswith(ext):
  123. start = name[:-len(ext)]
  124. end = name[-len(ext):].upper()
  125. return start+end
  126. return name
  127. def to_int(x):
  128. if x.startswith('0x'):
  129. return int(x, 16)
  130. return int(x)
  131. def is_int(x):
  132. try:
  133. int(x)
  134. return True
  135. except ValueError:
  136. return False
  137. def fix_enum_arg(name, is_flag_bit=False):
  138. # name = name.title()
  139. name = fix_ext_suffix(name)
  140. if len(name) > 0 and name[0].isdigit() and not name.startswith("0x") and not is_int(name):
  141. if name[1] == "D":
  142. name = name[1] + name[0] + (name[2:] if len(name) > 2 else "")
  143. else:
  144. name = "_"+name
  145. if is_flag_bit:
  146. name = name.replace("_BIT", "")
  147. return name
  148. def do_type(t, prev_name="", name=""):
  149. return convert_type(no_vk(t), prev_name, name).replace("FlagBits", "Flags")
  150. def parse_handles_def(f):
  151. f.write("// Handles types\n")
  152. handles = [h for h in re.findall(r"VK_DEFINE_HANDLE\(Vk(\w+)\)", src, re.S)]
  153. max_len = max(len(h) for h in handles)
  154. for h in handles:
  155. f.write("{} :: distinct Handle\n".format(h.ljust(max_len)))
  156. handles_non_dispatchable = [h for h in re.findall(r"VK_DEFINE_NON_DISPATCHABLE_HANDLE\(Vk(\w+)\)", src, re.S)]
  157. max_len = max(len(h) for h in handles_non_dispatchable)
  158. for h in handles_non_dispatchable:
  159. f.write("{} :: distinct NonDispatchableHandle\n".format(h.ljust(max_len)))
  160. flags_defs = set()
  161. def parse_flags_def(f):
  162. names = [n for n in re.findall(r"typedef VkFlags Vk(\w+?);", src)]
  163. global flags_defs
  164. flags_defs = set(names)
  165. class FlagError(ValueError):
  166. pass
  167. class IgnoreFlagError(ValueError):
  168. pass
  169. def fix_enum_name(name, prefix, suffix, is_flag_bit):
  170. name = remove_prefix(name, prefix)
  171. if suffix:
  172. name = remove_suffix(name, suffix)
  173. if name.startswith("0x"):
  174. if is_flag_bit:
  175. i = int(name, 16)
  176. if i == 0:
  177. raise IgnoreFlagError(i)
  178. v = int(math.log2(i))
  179. if 2**v != i:
  180. raise FlagError(i)
  181. return str(v)
  182. return name
  183. elif is_flag_bit:
  184. ignore = False
  185. try:
  186. if int(name) == 0:
  187. ignore = True
  188. except:
  189. pass
  190. if ignore:
  191. raise IgnoreFlagError()
  192. return fix_enum_arg(name, is_flag_bit)
  193. def fix_enum_value(value, prefix, suffix, is_flag_bit):
  194. v = no_vk(value)
  195. g = tokenize(BytesIO(v.encode('utf-8')).readline)
  196. tokens = [val for _, val, _, _, _ in g]
  197. assert len(tokens) > 2
  198. token = ''.join([t for t in tokens[1:-1] if t])
  199. token = fix_enum_name(token, prefix, suffix, is_flag_bit)
  200. return token
  201. def parse_constants(f):
  202. f.write("// General Constants\n")
  203. all_data = re.findall(r"#define VK_(\w+)\s*(.*?)U?\n", src, re.S)
  204. allowed_names = (
  205. "HEADER_VERSION",
  206. "MAX_DRIVER_NAME_SIZE",
  207. "MAX_DRIVER_INFO_SIZE",
  208. )
  209. allowed_data = [nv for nv in all_data if nv[0] in allowed_names]
  210. max_len = max(len(name) for name, value in allowed_data)
  211. for name, value in allowed_data:
  212. f.write("{}{} :: {}\n".format(name, "".rjust(max_len-len(name)), value))
  213. f.write("\n// Vendor Constants\n")
  214. data = re.findall(r"#define VK_((?:"+'|'.join(ext_suffixes)+r")\w+)\s*(.*?)\n", src, re.S)
  215. max_len = max(len(name) for name, value in data)
  216. for name, value in data:
  217. f.write("{}{} :: {}\n".format(name, "".rjust(max_len-len(name)), value))
  218. f.write("\n")
  219. def parse_enums(f):
  220. f.write("import \"core:c\"\n\n")
  221. f.write("// Enums\n")
  222. data = re.findall(r"typedef enum Vk(\w+) {(.+?)} \w+;", src, re.S)
  223. data.sort(key=lambda x: x[0])
  224. generated_flags = set()
  225. for name, fields in data:
  226. enum_name = name
  227. is_flag_bit = False
  228. if "FlagBits" in enum_name:
  229. is_flag_bit = True
  230. flags_name = enum_name.replace("FlagBits", "Flags")
  231. enum_name = enum_name.replace("FlagBits", "Flag")
  232. generated_flags.add(flags_name)
  233. f.write("{} :: distinct bit_set[{}; Flags]\n".format(flags_name, enum_name))
  234. if is_flag_bit:
  235. f.write("{} :: enum Flags {{\n".format(name.replace("FlagBits", "Flag")))
  236. else:
  237. f.write("{} :: enum c.int {{\n".format(name))
  238. prefix = to_snake_case(name).upper()
  239. suffix = None
  240. for ext in ext_suffixes:
  241. prefix_new = remove_suffix(prefix, "_"+ext)
  242. assert suffix is None
  243. if prefix_new != prefix:
  244. suffix = "_"+ext
  245. prefix = prefix_new
  246. break
  247. prefix = prefix.replace("_FLAG_BITS", "")
  248. prefix += "_"
  249. ff = []
  250. names_and_values = re.findall(r"VK_(\w+?) = (.*?)(?:,|})", fields, re.S)
  251. groups = []
  252. flags = {}
  253. for name, value in names_and_values:
  254. n = fix_enum_name(name, prefix, suffix, is_flag_bit)
  255. try:
  256. v = fix_enum_value(value, prefix, suffix, is_flag_bit)
  257. except FlagError as e:
  258. v = int(str(e))
  259. groups.append((n, v))
  260. continue
  261. except IgnoreFlagError as e:
  262. groups.append((n, 0))
  263. continue
  264. if n == v:
  265. continue
  266. try:
  267. flags[int(v)] = n
  268. except ValueError as e:
  269. pass
  270. if v == "NONE":
  271. continue
  272. ff.append((n, v))
  273. max_flag_value = max([int(v) for n, v in ff if is_int(v)] + [0])
  274. max_group_value = max([int(v) for n, v in groups if is_int(v)] + [0])
  275. if max_flag_value < max_group_value:
  276. if (1<<max_flag_value)+1 < max_group_value:
  277. ff.append(('_MAX', 31))
  278. flags[31] = '_MAX'
  279. pass
  280. max_len = max([len(n) for n, v in ff] + [0])
  281. flag_names = set([n for n, v in ff])
  282. for n, v in ff:
  283. if is_flag_bit and not is_int(v) and v not in flag_names:
  284. print("Ignoring", n, "=", v)
  285. continue
  286. f.write("\t{} = {},".format(n.ljust(max_len), v))
  287. if n == "_MAX":
  288. f.write(" // Needed for the *_ALL bit set")
  289. f.write("\n")
  290. f.write("}\n\n")
  291. for n, v in groups:
  292. used_flags = []
  293. for i in range(0, 32):
  294. if 1<<i & v != 0:
  295. if i in flags:
  296. used_flags.append('.'+flags[i])
  297. else:
  298. used_flags.append('{}({})'.format(enum_name, i))
  299. s = "{enum_name}s_{n} :: {enum_name}s{{".format(enum_name=enum_name, n=n)
  300. s += ', '.join(used_flags)
  301. s += "}\n"
  302. f.write(s)
  303. if len(groups) > 0:
  304. f.write("\n\n")
  305. unused_flags = [flag for flag in flags_defs if flag not in generated_flags]
  306. unused_flags.sort()
  307. max_len = max(len(flag) for flag in unused_flags)
  308. for flag in unused_flags:
  309. flag_name = flag.replace("Flags", "Flag")
  310. f.write("{} :: distinct bit_set[{}; Flags]\n".format(flag.ljust(max_len), flag_name))
  311. f.write("{} :: enum u32 {{}}\n".format(flag_name.ljust(max_len)))
  312. def parse_structs(f):
  313. data = re.findall(r"typedef (struct|union) Vk(\w+?) {(.+?)} \w+?;", src, re.S)
  314. for _type, name, fields in data:
  315. fields = re.findall(r"\s+(.+?)\s+([_a-zA-Z0-9[\]]+);", fields)
  316. f.write("{} :: struct ".format(name))
  317. if _type == "union":
  318. f.write("#raw_union ")
  319. f.write("{\n")
  320. prev_name = ""
  321. ffields = []
  322. for type_, fname in fields:
  323. if '[' in fname:
  324. fname, type_ = parse_array(fname, type_)
  325. comment = None
  326. n = fix_arg(fname)
  327. if "Flag_Bits" in type_:
  328. comment = " // only single bit set"
  329. t = do_type(type_, prev_name, fname)
  330. if n == "matrix":
  331. n = "mat"
  332. ffields.append(tuple([n, t, comment]))
  333. prev_name = fname
  334. max_len = max(len(n) for n, _, _ in ffields)
  335. for n, t, comment in ffields:
  336. k = max_len-len(n)+len(t)
  337. f.write("\t{}: {},{}\n".format(n, t.rjust(k), comment or ""))
  338. f.write("}\n\n")
  339. f.write("// Aliases\n")
  340. data = re.findall(r"typedef Vk(\w+?) Vk(\w+?);", src, re.S)
  341. aliases = []
  342. for _type, name in data:
  343. if _type == "Flags":
  344. continue
  345. name = name.replace("FlagBits", "Flag")
  346. _type = _type.replace("FlagBits", "Flag")
  347. aliases.append((name, _type))
  348. max_len = max([len(n) for n, _ in aliases] + [0])
  349. for n, t in aliases:
  350. k = max_len
  351. f.write("{} :: {}\n".format(n.ljust(k), t))
  352. procedure_map = {}
  353. def parse_procedures(f):
  354. data = re.findall(r"typedef (\w+\*?) \(\w+ \*(\w+)\)\((.+?)\);", src, re.S)
  355. ff = []
  356. for rt, name, fields in data:
  357. proc_name = no_vk(name)
  358. pf = []
  359. prev_name = ""
  360. for type_, fname in re.findall(r"(?:\s*|)(.+?)\s*(\w+)(?:,|$)", fields):
  361. curr_name = fix_arg(fname)
  362. pf.append((do_type(type_, prev_name, curr_name), curr_name))
  363. prev_name = curr_name
  364. data_fields = ', '.join(["{}: {}".format(n, t) for t, n in pf if t != ""])
  365. ts = "proc \"c\" ({})".format(data_fields)
  366. rt_str = do_type(rt)
  367. if rt_str != "void":
  368. ts += " -> {}".format(rt_str)
  369. procedure_map[proc_name] = ts
  370. ff.append( (proc_name, ts) )
  371. max_len = max(len(n) for n, t in ff)
  372. f.write("import \"core:c\"\n\n")
  373. f.write("// Procedure Types\n\n");
  374. for n, t in ff:
  375. f.write("{} :: #type {}\n".format(n.ljust(max_len), t.replace('"c"', '"system"')))
  376. def group_functions(f):
  377. data = re.findall(r"typedef (\w+\*?) \(\w+ \*(\w+)\)\((.+?)\);", src, re.S)
  378. group_map = {"Instance":[], "Device":[], "Loader":[]}
  379. for rt, vkname, fields in data:
  380. fields_types_name = [do_type(t) for t in re.findall(r"(?:\s*|)(.+?)\s*\w+(?:,|$)", fields)]
  381. table_name = fields_types_name[0]
  382. name = no_vk(vkname)
  383. nn = (fix_arg(name), fix_ext_suffix(name))
  384. if table_name in ('Device', 'Queue', 'CommandBuffer') and name != 'GetDeviceProcAddr':
  385. group_map["Device"].append(nn)
  386. elif table_name in ('Instance', 'PhysicalDevice') or name == 'GetDeviceProcAddr':
  387. group_map["Instance"].append(nn)
  388. elif table_name in ('rawptr', '', 'DebugReportFlagsEXT') or name == 'GetInstanceProcAddr':
  389. # Skip the allocation function and the dll entry point
  390. pass
  391. else:
  392. group_map["Loader"].append(nn)
  393. for group_name, group_lines in group_map.items():
  394. f.write("// {} Procedures\n".format(group_name))
  395. max_len = max(len(name) for name, _ in group_lines)
  396. for name, vk_name in group_lines:
  397. type_str = procedure_map[vk_name]
  398. f.write('{}: {}\n'.format(remove_prefix(name, "Proc"), name.rjust(max_len)))
  399. f.write("\n")
  400. f.write("load_proc_addresses :: proc(set_proc_address: SetProcAddressType) {\n")
  401. for group_name, group_lines in group_map.items():
  402. f.write("\t// {} Procedures\n".format(group_name))
  403. max_len = max(len(name) for name, _ in group_lines)
  404. for name, vk_name in group_lines:
  405. k = max_len - len(name)
  406. f.write('\tset_proc_address(&{}, {}"vk{}")\n'.format(
  407. remove_prefix(name, 'Proc'),
  408. "".ljust(k),
  409. remove_prefix(vk_name, 'Proc'),
  410. ))
  411. f.write("\n")
  412. f.write("}\n")
  413. BASE = """
  414. //
  415. // Vulkan wrapper generated from "https://raw.githubusercontent.com/KhronosGroup/Vulkan-Headers/master/include/vulkan/vulkan_core.h"
  416. //
  417. package vulkan
  418. """[1::]
# Generate the four Odin source files. The paths are relative to the current
# working directory, not to this script's location.
# Order matters:
#   - parse_flags_def (core.odin) sets the global `flags_defs` that
#     parse_enums (enums.odin) reads.
#   - parse_procedures fills `procedure_map`, which group_functions reads.
with open("../core.odin", 'w', encoding='utf-8') as f:
    f.write(BASE)
    # Hand-written base types and constants that are not extracted from the
    # headers.
    f.write("""
API_VERSION_1_0 :: (1<<22) | (0<<12) | (0)
MAKE_VERSION :: proc(major, minor, patch: u32) -> u32 {
return (major<<22) | (minor<<12) | (patch)
}
// Base types
Flags :: distinct u32
Flags64 :: distinct u64
DeviceSize :: distinct u64
DeviceAddress :: distinct u64
SampleMask :: distinct u32
Handle :: distinct rawptr
NonDispatchableHandle :: distinct u64
SetProcAddressType :: #type proc(p: rawptr, name: cstring)
RemoteAddressNV :: distinct rawptr // Declared inline before MemoryGetRemoteAddressInfoNV
// Base constants
LOD_CLAMP_NONE :: 1000.0
REMAINING_MIP_LEVELS :: ~u32(0)
REMAINING_ARRAY_LAYERS :: ~u32(0)
WHOLE_SIZE :: ~u64(0)
ATTACHMENT_UNUSED :: ~u32(0)
TRUE :: 1
FALSE :: 0
QUEUE_FAMILY_IGNORED :: ~u32(0)
SUBPASS_EXTERNAL :: ~u32(0)
MAX_PHYSICAL_DEVICE_NAME_SIZE :: 256
UUID_SIZE :: 16
MAX_MEMORY_TYPES :: 32
MAX_MEMORY_HEAPS :: 16
MAX_EXTENSION_NAME_SIZE :: 256
MAX_DESCRIPTION_SIZE :: 256
MAX_DEVICE_GROUP_SIZE_KHX :: 32
MAX_DEVICE_GROUP_SIZE :: 32
LUID_SIZE_KHX :: 8
LUID_SIZE_KHR :: 8
LUID_SIZE :: 8
MAX_DRIVER_NAME_SIZE_KHR :: 256
MAX_DRIVER_INFO_SIZE_KHR :: 256
MAX_QUEUE_FAMILY_EXTERNAL :: ~u32(0)-1
MAX_GLOBAL_PRIORITY_SIZE_EXT :: 16
"""[1::])
    parse_constants(f)
    parse_handles_def(f)
    f.write("\n\n")
    # Writes nothing; only populates the global `flags_defs` for parse_enums.
    parse_flags_def(f)

with open("../enums.odin", 'w', encoding='utf-8') as f:
    f.write(BASE)
    f.write("\n")
    parse_enums(f)
    f.write("\n\n")

with open("../structs.odin", 'w', encoding='utf-8') as f:
    f.write(BASE)
    # Platform types referenced by the WSI structs: aliases of the win32
    # package on Windows, opaque stand-ins on other targets.
    f.write("""
import "core:c"
when ODIN_OS == "windows" {
\timport win32 "core:sys/windows"
\tHINSTANCE :: win32.HINSTANCE
\tHWND :: win32.HWND
\tHMONITOR :: win32.HMONITOR
\tHANDLE :: win32.HANDLE
\tLPCWSTR :: win32.LPCWSTR
\tSECURITY_ATTRIBUTES :: win32.SECURITY_ATTRIBUTES
\tDWORD :: win32.DWORD
\tLONG :: win32.LONG
\tLUID :: win32.LUID
} else {
\tHINSTANCE :: distinct rawptr
\tHWND :: distinct rawptr
\tHMONITOR :: distinct rawptr
\tHANDLE :: distinct rawptr
\tLPCWSTR :: ^u16
\tSECURITY_ATTRIBUTES :: struct {}
\tDWORD :: u32
\tLONG :: c.long
\tLUID :: struct {
\t\tLowPart: DWORD,
\t\tHighPart: LONG,
\t}
}
CAMetalLayer :: struct {}
/********************************/
""")
    f.write("\n")
    parse_structs(f)
    f.write("\n\n")

with open("../procedures.odin", 'w', encoding='utf-8') as f:
    f.write(BASE)
    f.write("\n")
    # Must run before group_functions: it fills `procedure_map`.
    parse_procedures(f)
    f.write("\n")
    group_functions(f)
    f.write("\n\n")