hctdb_instrhelp.py 41 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060
  1. # Copyright (C) Microsoft Corporation. All rights reserved.
  2. # This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details.
  3. import argparse
  4. import functools
  5. import collections
  6. from hctdb import *
  7. # get db singletons
  8. g_db_dxil = None
  9. def get_db_dxil():
  10. global g_db_dxil
  11. if g_db_dxil is None:
  12. g_db_dxil = db_dxil()
  13. return g_db_dxil
  14. g_db_hlsl = None
  15. def get_db_hlsl():
  16. global g_db_hlsl
  17. if g_db_hlsl is None:
  18. thisdir = os.path.dirname(os.path.realpath(__file__))
  19. with open(os.path.join(thisdir, "gen_intrin_main.txt"), "r") as f:
  20. g_db_hlsl = db_hlsl(f)
  21. return g_db_hlsl
  22. def format_comment(prefix, val):
  23. "Formats a value with a line-comment prefix."
  24. result = ""
  25. line_width = 80
  26. content_width = line_width - len(prefix)
  27. l = len(val)
  28. while l:
  29. if l < content_width:
  30. result += prefix + val.strip()
  31. result += "\n"
  32. l = 0
  33. else:
  34. split_idx = val.rfind(" ", 0, content_width)
  35. result += prefix + val[:split_idx].strip()
  36. result += "\n"
  37. val = val[split_idx+1:]
  38. l = len(val)
  39. return result
  40. def format_rst_table(list_of_tuples):
  41. "Produces a reStructuredText simple table from the specified list of tuples."
  42. # Calculate widths.
  43. widths = None
  44. for t in list_of_tuples:
  45. if widths is None:
  46. widths = [0] * len(t)
  47. for i, v in enumerate(t):
  48. widths[i] = max(widths[i], len(str(v)))
  49. # Build banner line.
  50. banner = ""
  51. for i, w in enumerate(widths):
  52. if i > 0:
  53. banner += " "
  54. banner += "=" * w
  55. banner += "\n"
  56. # Build the result.
  57. result = banner
  58. for i, t in enumerate(list_of_tuples):
  59. for j, v in enumerate(t):
  60. if j > 0:
  61. result += " "
  62. result += str(v)
  63. result += " " * (widths[j] - len(str(v)))
  64. result = result.rstrip()
  65. result += "\n"
  66. if i == 0:
  67. result += banner
  68. result += banner
  69. return result
  70. def build_range_tuples(i):
  71. "Produces a list of tuples with contiguous ranges in the input list."
  72. i = sorted(i)
  73. low_bound = None
  74. high_bound = None
  75. for val in i:
  76. if low_bound is None:
  77. low_bound = val
  78. high_bound = val
  79. else:
  80. assert(not high_bound is None)
  81. if val == high_bound + 1:
  82. high_bound = val
  83. else:
  84. yield (low_bound, high_bound)
  85. low_bound = val
  86. high_bound = val
  87. if not low_bound is None:
  88. yield (low_bound, high_bound)
  89. def build_range_code(var, i):
  90. "Produces a fragment of code that tests whether the variable name matches values in the given range."
  91. ranges = build_range_tuples(i)
  92. result = ""
  93. for r in ranges:
  94. if r[0] == r[1]:
  95. cond = var + " == " + str(r[0])
  96. else:
  97. cond = "%d <= %s && %s <= %d" % (r[0], var, var, r[1])
  98. if result == "":
  99. result = cond
  100. else:
  101. result = result + " || " + cond
  102. return result
  103. class db_docsref_gen:
  104. "A generator of reference documentation."
  105. def __init__(self, db):
  106. self.db = db
  107. instrs = [i for i in self.db.instr if i.is_dxil_op]
  108. instrs = sorted(instrs, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
  109. self.instrs = instrs
  110. val_rules = sorted(db.val_rules, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
  111. self.val_rules = val_rules
  112. def print_content(self):
  113. self.print_header()
  114. self.print_body()
  115. self.print_footer()
  116. def print_header(self):
  117. print("<!DOCTYPE html>")
  118. print("<html><head><title>DXIL Reference</title>")
  119. print("<style>body { font-family: Verdana; font-size: small; }</style>")
  120. print("</head><body><h1>DXIL Reference</h1>")
  121. self.print_toc("Instructions", "i", self.instrs)
  122. self.print_toc("Rules", "r", self.val_rules)
  123. def print_body(self):
  124. self.print_instruction_details()
  125. self.print_valrule_details()
  126. def print_instruction_details(self):
  127. print("<h2>Instruction Details</h2>")
  128. for i in self.instrs:
  129. print("<h3><a name='i%s'>%s</a></h3>" % (i.name, i.name))
  130. print("<div>Opcode: %d. This instruction %s.</div>" % (i.dxil_opid, i.doc))
  131. if i.remarks:
  132. # This is likely a .rst fragment, but this will do for now.
  133. print("<div> " + i.remarks + "</div>")
  134. print("<div>Operands:</div>")
  135. print("<ul>")
  136. for o in i.ops:
  137. if o.pos == 0:
  138. print("<li>result: %s - %s</li>" % (o.llvm_type, o.doc))
  139. else:
  140. enum_desc = "" if o.enum_name == "" else " one of %s: %s" % (o.enum_name, ",".join(db.enum_idx[o.enum_name].value_names()))
  141. print("<li>%d - %s: %s%s%s</li>" % (o.pos - 1, o.name, o.llvm_type, "" if o.doc == "" else " - " + o.doc, enum_desc))
  142. print("</ul>")
  143. print("<div><a href='#Instructions'>(top)</a></div>")
  144. def print_valrule_details(self):
  145. print("<h2>Rule Details</h2>")
  146. for i in self.val_rules:
  147. print("<h3><a name='r%s'>%s</a></h3>" % (i.name, i.name))
  148. print("<div>" + i.doc + "</div>")
  149. print("<div><a href='#Rules'>(top)</a></div>")
  150. def print_toc(self, name, aprefix, values):
  151. print("<h2><a name='" + name + "'>" + name + "</a></h2>")
  152. last_category = ""
  153. for i in values:
  154. if i.category != last_category:
  155. if last_category != None:
  156. print("</ul>")
  157. print("<div><b>%s</b></div><ul>" % i.category)
  158. last_category = i.category
  159. print("<li><a href='#" + aprefix + "%s'>%s</a></li>" % (i.name, i.name))
  160. print("</ul>")
  161. def print_footer(self):
  162. print("</body></html>")
  163. class db_instrhelp_gen:
  164. "A generator of instruction helper classes."
  165. def __init__(self, db):
  166. self.db = db
  167. TypeInfo = collections.namedtuple("TypeInfo", "name bits")
  168. self.llvm_type_map = {
  169. "i1": TypeInfo("bool", 1),
  170. "i8": TypeInfo("int8_t", 8),
  171. "u8": TypeInfo("uint8_t", 8),
  172. "i32": TypeInfo("int32_t", 32),
  173. "u32": TypeInfo("uint32_t", 32)
  174. }
  175. self.IsDxilOpFuncCallInst = "hlsl::OP::IsDxilOpFuncCallInst"
  176. def print_content(self):
  177. self.print_header()
  178. self.print_body()
  179. self.print_footer()
  180. def print_header(self):
  181. print("///////////////////////////////////////////////////////////////////////////////")
  182. print("// //")
  183. print("// Copyright (C) Microsoft Corporation. All rights reserved. //")
  184. print("// DxilInstructions.h //")
  185. print("// //")
  186. print("// This file provides a library of instruction helper classes. //")
  187. print("// //")
  188. print("// MUCH WORK YET TO BE DONE - EXPECT THIS WILL CHANGE - GENERATED FILE //")
  189. print("// //")
  190. print("///////////////////////////////////////////////////////////////////////////////")
  191. print("")
  192. print("// TODO: add correct include directives")
  193. print("// TODO: add accessors with values")
  194. print("// TODO: add validation support code, including calling into right fn")
  195. print("// TODO: add type hierarchy")
  196. print("namespace hlsl {")
  197. def bool_lit(self, val):
  198. return "true" if val else "false";
  199. def op_type(self, o):
  200. if o.llvm_type in self.llvm_type_map:
  201. return self.llvm_type_map[o.llvm_type].name
  202. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  203. def op_size(self, o):
  204. if o.llvm_type in self.llvm_type_map:
  205. return self.llvm_type_map[o.llvm_type].bits
  206. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  207. def op_const_expr(self, o):
  208. return "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)
  209. def op_set_const_expr(self, o):
  210. type_size = self.op_size(o)
  211. return "llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), %d), llvm::APInt(%d, (uint64_t)val))" % (type_size, type_size)
  212. def print_body(self):
  213. for i in self.db.instr:
  214. if i.is_reserved: continue
  215. if i.inst_helper_prefix:
  216. struct_name = "%s_%s" % (i.inst_helper_prefix, i.name)
  217. elif i.is_dxil_op:
  218. struct_name = "DxilInst_%s" % i.name
  219. else:
  220. struct_name = "LlvmInst_%s" % i.name
  221. if i.doc:
  222. print("/// This instruction %s" % i.doc)
  223. print("struct %s {" % struct_name)
  224. print(" llvm::Instruction *Instr;")
  225. print(" // Construction and identification")
  226. print(" %s(llvm::Instruction *pInstr) : Instr(pInstr) {}" % struct_name)
  227. print(" operator bool() const {")
  228. if i.is_dxil_op:
  229. op_name = i.fully_qualified_name()
  230. print(" return %s(Instr, %s);" % (self.IsDxilOpFuncCallInst, op_name))
  231. else:
  232. print(" return Instr->getOpcode() == llvm::Instruction::%s;" % i.name)
  233. print(" }")
  234. print(" // Validation support")
  235. print(" bool isAllowed() const { return %s; }" % self.bool_lit(i.is_allowed))
  236. if i.is_dxil_op:
  237. print(" bool isArgumentListValid() const {")
  238. print(" if (%d != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;" % (len(i.ops) - 1))
  239. print(" return true;")
  240. # TODO - check operand types
  241. print(" }")
  242. EnumWritten = False
  243. for o in i.ops:
  244. if o.pos > 1: # 0 is return type, 1 is DXIL OP id
  245. if not EnumWritten:
  246. print(" // Operand indexes")
  247. print(" enum OperandIdx {")
  248. EnumWritten = True
  249. print(" arg_%s = %d," % (o.name, o.pos - 1))
  250. if EnumWritten:
  251. print(" };")
  252. AccessorsWritten = False
  253. for o in i.ops:
  254. if o.pos > 1: # 0 is return type, 1 is DXIL OP id
  255. if not AccessorsWritten:
  256. print(" // Accessors")
  257. AccessorsWritten = True
  258. print(" llvm::Value *get_%s() const { return Instr->getOperand(%d); }" % (o.name, o.pos - 1))
  259. print(" void set_%s(llvm::Value *val) { Instr->setOperand(%d, val); }" % (o.name, o.pos - 1))
  260. if o.is_const:
  261. print(" %s get_%s_val() const { return %s; }" % (self.op_type(o), o.name, self.op_const_expr(o)))
  262. print(" void set_%s_val(%s val) { Instr->setOperand(%d, %s); }" % (o.name, self.op_type(o), o.pos - 1, self.op_set_const_expr(o)))
  263. print("};")
  264. print("")
  265. def print_footer(self):
  266. print("} // namespace hlsl")
  267. class db_enumhelp_gen:
  268. "A generator of enumeration declarations."
  269. def __init__(self, db):
  270. self.db = db
  271. # Some enums should get a last enum marker.
  272. self.lastEnumNames = {
  273. "OpCode": "NumOpCodes",
  274. "OpCodeClass": "NumOpClasses"
  275. }
  276. def print_enum(self, e, **kwargs):
  277. print("// %s" % e.doc)
  278. print("enum class %s : unsigned {" % e.name)
  279. hide_val = kwargs.get("hide_val", False)
  280. sorted_values = e.values
  281. if kwargs.get("sort_val", True):
  282. sorted_values = sorted(e.values, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
  283. last_category = None
  284. for v in sorted_values:
  285. if v.category != last_category:
  286. if last_category != None:
  287. print("")
  288. print(" // %s" % v.category)
  289. last_category = v.category
  290. line_format = " {name}"
  291. if not e.is_internal and not hide_val:
  292. line_format += " = {value}"
  293. line_format += ","
  294. if v.doc:
  295. line_format += " // {doc}"
  296. print(line_format.format(name=v.name, value=v.value, doc=v.doc))
  297. if e.name in self.lastEnumNames:
  298. lastName = self.lastEnumNames[e.name]
  299. versioned = ["%s_Dxil_%d_%d = %d," % (lastName, major, minor, info[lastName])
  300. for (major, minor), info in sorted(self.db.dxil_version_info.items())
  301. if lastName in info]
  302. if versioned:
  303. print("")
  304. for val in versioned:
  305. print(" " + val)
  306. print("")
  307. print(" " + lastName + " = " + str(len(sorted_values)) + " // exclusive last value of enumeration")
  308. print("};")
  309. def print_content(self):
  310. for e in sorted(self.db.enums, key=lambda e : e.name):
  311. self.print_enum(e)
  312. class db_oload_gen:
  313. "A generator of overload tables."
  314. def __init__(self, db):
  315. self.db = db
  316. instrs = [i for i in self.db.instr if i.is_dxil_op]
  317. self.instrs = sorted(instrs, key=lambda i : i.dxil_opid)
  318. def print_content(self):
  319. self.print_opfunc_props()
  320. print("...")
  321. self.print_opfunc_table()
  322. def print_opfunc_props(self):
  323. print("const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {")
  324. print("// OpCode OpCode name, OpCodeClass OpCodeClass name, void, h, f, d, i1, i8, i16, i32, i64 function attribute")
  325. # Example formatted string:
  326. # { OC::TempRegLoad, "TempRegLoad", OCC::TempRegLoad, "tempRegLoad", false, true, true, false, true, false, true, true, false, Attribute::ReadOnly, },
  327. # 012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789
  328. # 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0
  329. last_category = None
  330. # overload types are a string of (v)oid, (h)alf, (f)loat, (d)ouble, (1)-bit, (8)-bit, (w)ord, (i)nt, (l)ong
  331. f = lambda i,c : "true," if i.oload_types.find(c) >= 0 else "false,"
  332. lower_exceptions = { "CBufferLoad" : "cbufferLoad", "CBufferLoadLegacy" : "cbufferLoadLegacy", "GSInstanceID" : "gsInstanceID" }
  333. lower_fn = lambda t: lower_exceptions[t] if t in lower_exceptions else t[:1].lower() + t[1:]
  334. attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate" }
  335. attr_fn = lambda i : "Attribute::" + attr_dict[i.fn_attr] + ","
  336. for i in self.instrs:
  337. if last_category != i.category:
  338. if last_category != None:
  339. print("")
  340. print(" // {category:118} void, h, f, d, i1, i8, i16, i32, i64 function attribute".format(category=i.category))
  341. last_category = i.category
  342. print(" {{ OC::{name:24} {quotName:27} OCC::{className:25} {classNameQuot:28} {v:>7}{h:>7}{f:>7}{d:>7}{b:>7}{e:>7}{w:>7}{i:>7}{l:>7} {attr:20} }},".format(
  343. name=i.name+",", quotName='"'+i.name+'",', className=i.dxil_class+",", classNameQuot='"'+lower_fn(i.dxil_class)+'",',
  344. v=f(i,"v"), h=f(i,"h"), f=f(i,"f"), d=f(i,"d"), b=f(i,"1"), e=f(i,"8"), w=f(i,"w"), i=f(i,"i"), l=f(i,"l"), attr=attr_fn(i)))
  345. print("};")
  346. def print_opfunc_table(self):
  347. # Print the table for OP::GetOpFunc
  348. op_type_texts = {
  349. "$cb": "CBRT(pETy);",
  350. "$o": "A(pETy);",
  351. "$r": "RRT(pETy);",
  352. "d": "A(pF64);",
  353. "dims": "A(pDim);",
  354. "f": "A(pF32);",
  355. "h": "A(pF16);",
  356. "i1": "A(pI1);",
  357. "i16": "A(pI16);",
  358. "i32": "A(pI32);",
  359. "i32c": "A(pI32C);",
  360. "i64": "A(pI64);",
  361. "i8": "A(pI8);",
  362. "$u4": "A(pI4S);",
  363. "pf32": "A(pPF32);",
  364. "res": "A(pRes);",
  365. "splitdouble": "A(pSDT);",
  366. "twoi32": "A(p2I32);",
  367. "twof32": "A(p2F32);",
  368. "fouri32": "A(p4I32);",
  369. "fourf32": "A(p4F32);",
  370. "u32": "A(pI32);",
  371. "u64": "A(pI64);",
  372. "u8": "A(pI8);",
  373. "v": "A(pV);",
  374. "w": "A(pWav);",
  375. "SamplePos": "A(pPos);",
  376. }
  377. last_category = None
  378. for i in self.instrs:
  379. if last_category != i.category:
  380. if last_category != None:
  381. print("")
  382. print(" // %s" % i.category)
  383. last_category = i.category
  384. line = " case OpCode::{name:24}".format(name = i.name + ":")
  385. for index, o in enumerate(i.ops):
  386. assert o.llvm_type in op_type_texts, "llvm type %s in instruction %s is unknown" % (o.llvm_type, i.name)
  387. op_type_text = op_type_texts[o.llvm_type]
  388. if index == 0:
  389. line = line + "{val:13}".format(val=op_type_text)
  390. else:
  391. line = line + "{val:9}".format(val=op_type_text)
  392. line = line + "break;"
  393. print(line)
  394. def print_opfunc_oload_type(self):
  395. # Print the function for OP::GetOverloadType
  396. elt_ty = "$o"
  397. res_ret_ty = "$r"
  398. cb_ret_ty = "$cb"
  399. last_category = None
  400. index_dict = collections.OrderedDict()
  401. single_dict = collections.OrderedDict()
  402. struct_list = []
  403. for instr in self.instrs:
  404. ret_ty = instr.ops[0].llvm_type
  405. # Skip case return type is overload type
  406. if (ret_ty == elt_ty):
  407. continue
  408. if ret_ty == res_ret_ty:
  409. struct_list.append(instr.name)
  410. continue
  411. if ret_ty == cb_ret_ty:
  412. struct_list.append(instr.name)
  413. continue
  414. in_param_ty = False
  415. # Try to find elt_ty in parameter types.
  416. for index, op in enumerate(instr.ops):
  417. # Skip return type.
  418. if (op.pos == 0):
  419. continue
  420. # Skip dxil opcode.
  421. if (op.pos == 1):
  422. continue
  423. op_type = op.llvm_type
  424. if (op_type == elt_ty):
  425. # Skip return op
  426. index = index - 1
  427. if index not in index_dict:
  428. index_dict[index] = [instr.name]
  429. else:
  430. index_dict[index].append(instr.name)
  431. in_param_ty = True
  432. break
  433. if in_param_ty:
  434. continue
  435. # No overload, just return the single oload_type.
  436. assert len(instr.oload_types)==1, "overload no elt_ty %s" % (instr.name)
  437. ty = instr.oload_types[0]
  438. type_code_texts = {
  439. "d": "Type::getDoubleTy(m_Ctx)",
  440. "f": "Type::getFloatTy(m_Ctx)",
  441. "h": "Type::getHalfTy",
  442. "1": "IntegerType::get(m_Ctx, 1)",
  443. "8": "IntegerType::get(m_Ctx, 8)",
  444. "w": "IntegerType::get(m_Ctx, 16)",
  445. "i": "IntegerType::get(m_Ctx, 32)",
  446. "l": "IntegerType::get(m_Ctx, 64)",
  447. "v": "Type::getVoidTy(m_Ctx)",
  448. }
  449. assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
  450. ty_code = type_code_texts[ty]
  451. if ty_code not in single_dict:
  452. single_dict[ty_code] = [instr.name]
  453. else:
  454. single_dict[ty_code].append(instr.name)
  455. for index, opcodes in index_dict.items():
  456. line = ""
  457. for opcode in opcodes:
  458. line = line + "case OpCode::{name}".format(name = opcode + ":\n")
  459. line = line + " DXASSERT_NOMSG(FT->getNumParams() > " + str(index) + ");\n"
  460. line = line + " return FT->getParamType(" + str(index) + ");"
  461. print(line)
  462. for code, opcodes in single_dict.items():
  463. line = ""
  464. for opcode in opcodes:
  465. line = line + "case OpCode::{name}".format(name = opcode + ":\n")
  466. line = line + " return " + code + ";"
  467. print(line)
  468. line = ""
  469. for opcode in struct_list:
  470. line = line + "case OpCode::{name}".format(name = opcode + ":\n")
  471. line = line + "{\n"
  472. line = line + " StructType *ST = cast<StructType>(Ty);\n"
  473. line = line + " return ST->getElementType(0);\n"
  474. line = line + "}"
  475. print(line)
  476. class db_valfns_gen:
  477. "A generator of validation functions."
  478. def __init__(self, db):
  479. self.db = db
  480. def print_content(self):
  481. self.print_header()
  482. self.print_body()
  483. def print_header(self):
  484. print("///////////////////////////////////////////////////////////////////////////////")
  485. print("// Instruction validation functions. //")
  486. def bool_lit(self, val):
  487. return "true" if val else "false";
  488. def op_type(self, o):
  489. if o.llvm_type == "i8":
  490. return "int8_t"
  491. if o.llvm_type == "u8":
  492. return "uint8_t"
  493. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  494. def op_const_expr(self, o):
  495. if o.llvm_type == "i8" or o.llvm_type == "u8":
  496. return "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)
  497. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  498. def print_body(self):
  499. llvm_instrs = [i for i in self.db.instr if i.is_allowed and not i.is_dxil_op]
  500. print("static bool IsLLVMInstructionAllowed(llvm::Instruction &I) {")
  501. self.print_comment(" // ", "Allow: %s" % ", ".join([i.name + "=" + str(i.llvm_id) for i in llvm_instrs]))
  502. print(" unsigned op = I.getOpcode();")
  503. print(" return %s;" % build_range_code("op", [i.llvm_id for i in llvm_instrs]))
  504. print("}")
  505. print("")
  506. def print_comment(self, prefix, val):
  507. print(format_comment(prefix, val))
  508. class macro_table_gen:
  509. "A generator for macro tables."
  510. def format_row(self, row, widths, sep=', '):
  511. frow = [str(item) + sep + (' ' * (width - len(item)))
  512. for item, width in list(zip(row, widths))[:-1]] + [str(row[-1])]
  513. return ''.join(frow)
  514. def format_table(self, table, *args, **kwargs):
  515. widths = [ functools.reduce(max, [ len(row[i])
  516. for row in table], 1)
  517. for i in range(len(table[0]))]
  518. formatted = []
  519. for row in table:
  520. formatted.append(self.format_row(row, widths, *args, **kwargs))
  521. return formatted
  522. def print_table(self, table, macro_name):
  523. formatted = self.format_table(table)
  524. print( '// %s\n' % formatted[0] +
  525. '#define %s(DO) \\\n' % macro_name +
  526. ' \\\n'.join([' DO(%s)' % frow for frow in formatted[1:]]))
  527. class db_sigpoint_gen(macro_table_gen):
  528. "A generator for SigPoint tables."
  529. def __init__(self, db):
  530. self.db = db
  531. def print_sigpoint_table(self):
  532. self.print_table(self.db.sigpoint_table, 'DO_SIGPOINTS')
  533. def print_interpretation_table(self):
  534. self.print_table(self.db.interpretation_table, 'DO_INTERPRETATION_TABLE')
  535. def print_content(self):
  536. self.print_sigpoint_table()
  537. self.print_interpretation_table()
  538. class string_output:
  539. def __init__(self):
  540. self.val = ""
  541. def write(self, text):
  542. self.val = self.val + str(text)
  543. def __str__(self):
  544. return self.val
  545. def run_with_stdout(fn):
  546. import sys
  547. _stdout_saved = sys.stdout
  548. so = string_output()
  549. try:
  550. sys.stdout = so
  551. fn()
  552. finally:
  553. sys.stdout = _stdout_saved
  554. return str(so)
  555. def get_hlsl_intrinsic_stats():
  556. db = get_db_hlsl()
  557. longest_fn = db.intrinsics[0]
  558. longest_param = None
  559. longest_arglist_fn = db.intrinsics[0]
  560. for i in sorted(db.intrinsics, key=lambda x: x.key):
  561. # Get some values for maximum lengths.
  562. if len(i.name) > len(longest_fn.name):
  563. longest_fn = i
  564. for p_idx, p in enumerate(i.params):
  565. if p_idx > 0 and (longest_param is None or len(p.name) > len(longest_param.name)):
  566. longest_param = p
  567. if len(i.params) > len(longest_arglist_fn.params):
  568. longest_arglist_fn = i
  569. result = ""
  570. for k in sorted(db.namespaces.keys()):
  571. v = db.namespaces[k]
  572. result += "static const UINT g_u%sCount = %d;\n" % (k, len(v.intrinsics))
  573. result += "\n"
  574. result += "static const int g_MaxIntrinsicName = %d; // Count of characters for longest intrinsic name - '%s'\n" % (len(longest_fn.name), longest_fn.name)
  575. result += "static const int g_MaxIntrinsicParamName = %d; // Count of characters for longest intrinsic parameter name - '%s'\n" % (len(longest_param.name), longest_param.name)
  576. result += "static const int g_MaxIntrinsicParamCount = %d; // Count of parameters (without return) for longest intrinsic argument list - '%s'\n" % (len(longest_arglist_fn.params) - 1, longest_arglist_fn.name)
  577. return result
  578. def get_hlsl_intrinsics():
  579. db = get_db_hlsl()
  580. result = ""
  581. last_ns = ""
  582. ns_table = ""
  583. is_vk_table = False # SPIRV Change
  584. id_prefix = ""
  585. arg_idx = 0
  586. opcode_namespace = db.opcode_namespace
  587. for i in sorted(db.intrinsics, key=lambda x: x.key):
  588. if last_ns != i.ns:
  589. last_ns = i.ns
  590. id_prefix = "IOP" if last_ns == "Intrinsics" else "MOP"
  591. if (len(ns_table)):
  592. result += ns_table + "};\n"
  593. # SPIRV Change Starts
  594. if is_vk_table:
  595. result += "\n#endif // ENABLE_SPIRV_CODEGEN\n"
  596. is_vk_table = False
  597. # SPIRV Change Ends
  598. result += "\n//\n// Start of %s\n//\n\n" % (last_ns)
  599. # This used to be qualified as __declspec(selectany), but that's no longer necessary.
  600. ns_table = "static const HLSL_INTRINSIC g_%s[] =\n{\n" % (last_ns)
  601. # SPIRV Change Starts
  602. if (i.vulkanSpecific):
  603. is_vk_table = True
  604. result += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
  605. # SPIRV Change Ends
  606. arg_idx = 0
  607. ns_table += " {(UINT)%s::%s_%s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
  608. result += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (last_ns, arg_idx)
  609. for p in i.params:
  610. result += " {\"%s\", %s, %s, %s, %s, %s, %s, %s},\n" % (
  611. p.name, p.param_qual, p.template_id, p.template_list,
  612. p.component_id, p.component_list, p.rows, p.cols)
  613. result += "};\n\n"
  614. arg_idx += 1
  615. result += ns_table + "};\n"
  616. result += "\n#endif // ENABLE_SPIRV_CODEGEN\n" if is_vk_table else "" # SPIRV Change
  617. return result
  618. # SPIRV Change Starts
  619. def wrap_with_ifdef_if_vulkan_specific(intrinsic, text):
  620. if intrinsic.vulkanSpecific:
  621. return "#ifdef ENABLE_SPIRV_CODEGEN\n" + text + "#endif // ENABLE_SPIRV_CODEGEN\n"
  622. return text
  623. # SPIRV Change Ends
  624. def enum_hlsl_intrinsics():
  625. db = get_db_hlsl()
  626. result = ""
  627. enumed = []
  628. for i in sorted(db.intrinsics, key=lambda x: x.key):
  629. if (i.enum_name not in enumed):
  630. enumerant = " %s,\n" % (i.enum_name)
  631. result += wrap_with_ifdef_if_vulkan_specific(i, enumerant) # SPIRV Change
  632. enumed.append(i.enum_name)
  633. # unsigned
  634. result += " // unsigned\n"
  635. for i in sorted(db.intrinsics, key=lambda x: x.key):
  636. if (i.unsigned_op != ""):
  637. if (i.unsigned_op not in enumed):
  638. result += " %s,\n" % (i.unsigned_op)
  639. enumed.append(i.unsigned_op)
  640. result += " Num_Intrinsics,\n"
  641. return result
  642. def has_unsigned_hlsl_intrinsics():
  643. db = get_db_hlsl()
  644. result = ""
  645. enumed = []
  646. # unsigned
  647. for i in sorted(db.intrinsics, key=lambda x: x.key):
  648. if (i.unsigned_op != ""):
  649. if (i.enum_name not in enumed):
  650. result += " case IntrinsicOp::%s:\n" % (i.enum_name)
  651. enumed.append(i.enum_name)
  652. return result
  653. def get_unsigned_hlsl_intrinsics():
  654. db = get_db_hlsl()
  655. result = ""
  656. enumed = []
  657. # unsigned
  658. for i in sorted(db.intrinsics, key=lambda x: x.key):
  659. if (i.unsigned_op != ""):
  660. if (i.enum_name not in enumed):
  661. enumed.append(i.enum_name)
  662. result += " case IntrinsicOp::%s:\n" % (i.enum_name)
  663. result += " return static_cast<unsigned>(IntrinsicOp::%s);\n" % (i.unsigned_op)
  664. return result
  665. def get_oloads_props():
  666. db = get_db_dxil()
  667. gen = db_oload_gen(db)
  668. return run_with_stdout(lambda: gen.print_opfunc_props())
  669. def get_oloads_funcs():
  670. db = get_db_dxil()
  671. gen = db_oload_gen(db)
  672. return run_with_stdout(lambda: gen.print_opfunc_table())
  673. def get_funcs_oload_type():
  674. db = get_db_dxil()
  675. gen = db_oload_gen(db)
  676. return run_with_stdout(lambda: gen.print_opfunc_oload_type())
  677. def get_enum_decl(name, **kwargs):
  678. db = get_db_dxil()
  679. gen = db_enumhelp_gen(db)
  680. return run_with_stdout(lambda: gen.print_enum(db.enum_idx[name], **kwargs))
  681. def get_valrule_enum():
  682. return get_enum_decl("ValidationRule", hide_val=True)
  683. def get_valrule_text():
  684. db = get_db_dxil()
  685. result = "switch(value) {\n"
  686. for v in db.enum_idx["ValidationRule"].values:
  687. result += " case hlsl::ValidationRule::" + v.name + ": return \"" + v.err_msg + "\";\n"
  688. result += "}\n"
  689. return result
  690. def get_instrhelper():
  691. db = get_db_dxil()
  692. gen = db_instrhelp_gen(db)
  693. return run_with_stdout(lambda: gen.print_body())
  694. def get_instrs_pred(varname, pred, attr_name="dxil_opid"):
  695. db = get_db_dxil()
  696. if type(pred) == str:
  697. pred_fn = lambda i: getattr(i, pred)
  698. else:
  699. pred_fn = pred
  700. llvm_instrs = [i for i in db.instr if pred_fn(i)]
  701. result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(getattr(i, attr_name)) for i in llvm_instrs]))
  702. result += "return %s;" % build_range_code(varname, [getattr(i, attr_name) for i in llvm_instrs])
  703. result += "\n"
  704. return result
  705. def get_instrs_rst():
  706. "Create an rst table of allowed LLVM instructions."
  707. db = get_db_dxil()
  708. instrs = [i for i in db.instr if i.is_allowed and not i.is_dxil_op]
  709. instrs = sorted(instrs, key=lambda v : v.llvm_id)
  710. rows = []
  711. rows.append(["Instruction", "Action", "Operand overloads"])
  712. for i in instrs:
  713. rows.append([i.name, i.doc, i.oload_types])
  714. result = "\n\n" + format_rst_table(rows) + "\n\n"
  715. # Add detailed instruction information where available.
  716. for i in instrs:
  717. if i.remarks:
  718. result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
  719. return result + "\n"
  720. def get_init_passes():
  721. "Create a series of statements to initialize passes in a registry."
  722. db = get_db_dxil()
  723. result = ""
  724. for p in sorted(db.passes, key=lambda p : p.type_name):
  725. result += "initialize%sPass(Registry);\n" % p.type_name
  726. return result
  727. def get_pass_arg_names():
  728. "Return an ArrayRef of argument names based on passName"
  729. db = get_db_dxil()
  730. decl_result = ""
  731. check_result = ""
  732. for p in sorted(db.passes, key=lambda p : p.type_name):
  733. if len(p.args):
  734. decl_result += "static const LPCSTR %sArgs[] = { " % p.type_name
  735. check_result += "if (strcmp(passName, \"%s\") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n" % (p.name, p.type_name, p.type_name)
  736. sep = ""
  737. for a in p.args:
  738. decl_result += sep + "\"%s\"" % a.name
  739. sep = ", "
  740. decl_result += " };\n"
  741. return decl_result + check_result
  742. def get_pass_arg_descs():
  743. "Return an ArrayRef of argument descriptions based on passName"
  744. db = get_db_dxil()
  745. decl_result = ""
  746. check_result = ""
  747. for p in sorted(db.passes, key=lambda p : p.type_name):
  748. if len(p.args):
  749. decl_result += "static const LPCSTR %sArgs[] = { " % p.type_name
  750. check_result += "if (strcmp(passName, \"%s\") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n" % (p.name, p.type_name, p.type_name)
  751. sep = ""
  752. for a in p.args:
  753. decl_result += sep + "\"%s\"" % a.doc
  754. sep = ", "
  755. decl_result += " };\n"
  756. return decl_result + check_result
  757. def get_is_pass_option_name():
  758. "Create a return expression to check whether a value 'S' is a pass option name."
  759. db = get_db_dxil()
  760. prefix = ""
  761. result = "return "
  762. for k in sorted(db.pass_idx_args):
  763. result += prefix + "S.equals(\"%s\")" % k
  764. prefix = "\n || "
  765. return result + ";"
  766. def get_opcodes_rst():
  767. "Create an rst table of opcodes"
  768. db = get_db_dxil()
  769. instrs = [i for i in db.instr if i.is_allowed and i.is_dxil_op]
  770. instrs = sorted(instrs, key=lambda v : v.dxil_opid)
  771. rows = []
  772. rows.append(["ID", "Name", "Description"])
  773. for i in instrs:
  774. op_name = i.dxil_op
  775. if i.remarks:
  776. op_name = op_name + "_" # append _ to enable internal hyperlink on rst files
  777. rows.append([i.dxil_opid, op_name, i.doc])
  778. result = "\n\n" + format_rst_table(rows) + "\n\n"
  779. # Add detailed instruction information where available.
  780. instrs = sorted(instrs, key=lambda v : v.name)
  781. for i in instrs:
  782. if i.remarks:
  783. result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
  784. return result + "\n"
  785. def get_valrules_rst():
  786. "Create an rst table of validation rules instructions."
  787. db = get_db_dxil()
  788. rules = [i for i in db.val_rules if not i.is_disabled]
  789. rules = sorted(rules, key=lambda v : v.name)
  790. rows = []
  791. rows.append(["Rule Code", "Description"])
  792. for i in rules:
  793. rows.append([i.name, i.doc])
  794. return "\n\n" + format_rst_table(rows) + "\n\n"
  795. def get_opsigs():
  796. # Create a list of DXIL operation signatures, sorted by ID.
  797. db = get_db_dxil()
  798. instrs = [i for i in db.instr if i.is_dxil_op]
  799. instrs = sorted(instrs, key=lambda v : v.dxil_opid)
  800. # db_dxil already asserts that the numbering is dense.
  801. # Create the code to write out.
  802. code = "static const char *OpCodeSignatures[] = {\n"
  803. for inst_idx,i in enumerate(instrs):
  804. code += " \"("
  805. for operand in i.ops:
  806. if operand.pos > 1: # skip 0 (the return value) and 1 (the opcode itself)
  807. code += operand.name
  808. if operand.pos < len(i.ops) - 1:
  809. code += ","
  810. code += ")\""
  811. if inst_idx < len(instrs) - 1:
  812. code += ","
  813. code += " // " + i.name
  814. code += "\n"
  815. code += "};\n"
  816. return code
  817. def get_valopcode_sm_text():
  818. db = get_db_dxil()
  819. instrs = [i for i in db.instr if i.is_dxil_op]
  820. instrs = sorted(instrs, key=lambda v : (v.shader_model, v.shader_stages, v.dxil_opid))
  821. last_model = None
  822. last_stage = None
  823. grouped_instrs = []
  824. code = ""
  825. def flush_instrs(grouped_instrs, last_model, last_stage):
  826. if len(grouped_instrs) == 0:
  827. return ""
  828. result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(i.dxil_opid) for i in grouped_instrs]))
  829. result += "if (" + build_range_code("op", [i.dxil_opid for i in grouped_instrs]) + ")\n"
  830. result += " return "
  831. model_cond = stage_cond = None
  832. if last_model != (6,0):
  833. model_cond = "pSM->GetMajor() > %d || (pSM->GetMajor() == %d && pSM->GetMinor() >= %d)" % (
  834. last_model[0], last_model[0], last_model[1])
  835. if last_stage != "*":
  836. stage_cond = ' || '.join(["pSM->Is%sS()" % c.upper() for c in last_stage])
  837. if model_cond or stage_cond:
  838. result += '\n && '.join(
  839. ["(%s)" % expr for expr in (model_cond, stage_cond) if expr] )
  840. return result + ";\n"
  841. else:
  842. # don't write these out, instead fall through
  843. return ""
  844. for i in instrs:
  845. if (i.shader_model, i.shader_stages) != (last_model, last_stage):
  846. code += flush_instrs(grouped_instrs, last_model, last_stage)
  847. grouped_instrs = []
  848. last_model = i.shader_model
  849. last_stage = i.shader_stages
  850. grouped_instrs.append(i)
  851. code += flush_instrs(grouped_instrs, last_model, last_stage)
  852. code += "return true;\n"
  853. return code
  854. def get_sigpoint_table():
  855. db = get_db_dxil()
  856. gen = db_sigpoint_gen(db)
  857. return run_with_stdout(lambda: gen.print_sigpoint_table())
  858. def get_sigpoint_rst():
  859. "Create an rst table for SigPointKind."
  860. db = get_db_dxil()
  861. rows = [row[:] for row in db.sigpoint_table[:-1]] # Copy table
  862. e = dict([(v.name, v) for v in db.enum_idx['SigPointKind'].values])
  863. rows[0] = ['ID'] + rows[0] + ['Description']
  864. for i in range(1, len(rows)):
  865. row = rows[i]
  866. v = e[row[0]]
  867. rows[i] = [v.value] + row + [v.doc]
  868. return "\n\n" + format_rst_table(rows) + "\n\n"
  869. def get_sem_interpretation_enum_rst():
  870. db = get_db_dxil()
  871. rows = ([['ID', 'Name', 'Description']] +
  872. [[v.value, v.name, v.doc]
  873. for v in db.enum_idx['SemanticInterpretationKind'].values[:-1]])
  874. return "\n\n" + format_rst_table(rows) + "\n\n"
  875. def get_sem_interpretation_table_rst():
  876. db = get_db_dxil()
  877. return "\n\n" + format_rst_table(db.interpretation_table) + "\n\n"
  878. def get_interpretation_table():
  879. db = get_db_dxil()
  880. gen = db_sigpoint_gen(db)
  881. return run_with_stdout(lambda: gen.print_interpretation_table())
  882. def RunCodeTagUpdate(file_path):
  883. import os
  884. import CodeTags
  885. print(" ... updating " + file_path)
  886. args = [file_path, file_path + ".tmp"]
  887. result = CodeTags.main(args)
  888. if result != 0:
  889. print(" ... error: %d" % result)
  890. else:
  891. with open(file_path, 'rt') as f:
  892. before = f.read()
  893. with open(file_path + ".tmp", 'rt') as f:
  894. after = f.read()
  895. if before == after:
  896. print(" --- no changes found")
  897. else:
  898. print(" +++ changes found, updating file")
  899. with open(file_path, 'wt') as f:
  900. f.write(after)
  901. os.remove(file_path + ".tmp")
  902. if __name__ == "__main__":
  903. parser = argparse.ArgumentParser(description="Generate code to handle instructions.")
  904. parser.add_argument("-gen", choices=["docs-ref", "docs-spec", "inst-header", "enums", "oloads", "valfns"], help="Output type to generate.")
  905. parser.add_argument("-update-files", action="store_const", const=True)
  906. args = parser.parse_args()
  907. db = get_db_dxil() # used by all generators, also handy to have it run validation
  908. if args.gen == "docs-ref":
  909. gen = db_docsref_gen(db)
  910. gen.print_content()
  911. if args.gen == "docs-spec":
  912. import os, docutils.core
  913. assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
  914. hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
  915. spec_file = os.path.abspath(os.path.join(hlsl_src_dir, "docs/DXIL.rst"))
  916. with open(spec_file) as f:
  917. s = docutils.core.publish_file(f, writer_name="html")
  918. if args.gen == "inst-header":
  919. gen = db_instrhelp_gen(db)
  920. gen.print_content()
  921. if args.gen == "enums":
  922. gen = db_enumhelp_gen(db)
  923. gen.print_content()
  924. if args.gen == "oloads":
  925. gen = db_oload_gen(db)
  926. gen.print_content()
  927. if args.gen == "valfns":
  928. gen = db_valfns_gen(db)
  929. gen.print_content()
  930. if args.update_files:
  931. print("Updating files ...")
  932. import CodeTags
  933. import os
  934. assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
  935. hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
  936. pj = lambda *parts: os.path.abspath(os.path.join(*parts))
  937. files = [
  938. 'docs/DXIL.rst',
  939. 'lib/HLSL/DXILOperations.cpp',
  940. 'include/dxc/HLSL/DXILConstants.h',
  941. 'include/dxc/HLSL/DxilValidation.h',
  942. 'include/dxc/HLSL/DxilInstructions.h',
  943. 'lib/HLSL/DxcOptimizer.cpp',
  944. 'lib/HLSL/DxilValidation.cpp',
  945. 'tools/clang/lib/Sema/gen_intrin_main_tables_15.h',
  946. 'include/dxc/HlslIntrinsicOp.h',
  947. 'tools/clang/tools/dxcompiler/dxcdisassembler.cpp',
  948. 'include/dxc/HLSL/DxilSigPoint.inl',
  949. ]
  950. for relative_file_path in files:
  951. RunCodeTagUpdate(pj(hlsl_src_dir, relative_file_path))