hctdb_instrhelp.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460
  1. # Copyright (C) Microsoft Corporation. All rights reserved.
  2. # This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details.
  3. import argparse
  4. import functools
  5. import collections
  6. from hctdb import *
  7. # get db singletons
  8. g_db_dxil = None
  9. def get_db_dxil():
  10. global g_db_dxil
  11. if g_db_dxil is None:
  12. g_db_dxil = db_dxil()
  13. return g_db_dxil
  14. g_db_hlsl = None
  15. def get_db_hlsl():
  16. global g_db_hlsl
  17. if g_db_hlsl is None:
  18. thisdir = os.path.dirname(os.path.realpath(__file__))
  19. with open(os.path.join(thisdir, "gen_intrin_main.txt"), "r") as f:
  20. g_db_hlsl = db_hlsl(f)
  21. return g_db_hlsl
  22. def format_comment(prefix, val):
  23. "Formats a value with a line-comment prefix."
  24. result = ""
  25. line_width = 80
  26. content_width = line_width - len(prefix)
  27. l = len(val)
  28. while l:
  29. if l < content_width:
  30. result += prefix + val.strip()
  31. result += "\n"
  32. l = 0
  33. else:
  34. split_idx = val.rfind(" ", 0, content_width)
  35. result += prefix + val[:split_idx].strip()
  36. result += "\n"
  37. val = val[split_idx+1:]
  38. l = len(val)
  39. return result
  40. def format_rst_table(list_of_tuples):
  41. "Produces a reStructuredText simple table from the specified list of tuples."
  42. # Calculate widths.
  43. widths = None
  44. for t in list_of_tuples:
  45. if widths is None:
  46. widths = [0] * len(t)
  47. for i, v in enumerate(t):
  48. widths[i] = max(widths[i], len(str(v)))
  49. # Build banner line.
  50. banner = ""
  51. for i, w in enumerate(widths):
  52. if i > 0:
  53. banner += " "
  54. banner += "=" * w
  55. banner += "\n"
  56. # Build the result.
  57. result = banner
  58. for i, t in enumerate(list_of_tuples):
  59. for j, v in enumerate(t):
  60. if j > 0:
  61. result += " "
  62. result += str(v)
  63. result += " " * (widths[j] - len(str(v)))
  64. result = result.rstrip()
  65. result += "\n"
  66. if i == 0:
  67. result += banner
  68. result += banner
  69. return result
  70. def build_range_tuples(i):
  71. "Produces a list of tuples with contiguous ranges in the input list."
  72. i = sorted(i)
  73. low_bound = None
  74. high_bound = None
  75. for val in i:
  76. if low_bound is None:
  77. low_bound = val
  78. high_bound = val
  79. else:
  80. assert(not high_bound is None)
  81. if val == high_bound + 1:
  82. high_bound = val
  83. else:
  84. yield (low_bound, high_bound)
  85. low_bound = val
  86. high_bound = val
  87. if not low_bound is None:
  88. yield (low_bound, high_bound)
  89. def build_range_code(var, i):
  90. "Produces a fragment of code that tests whether the variable name matches values in the given range."
  91. ranges = build_range_tuples(i)
  92. result = ""
  93. for r in ranges:
  94. if r[0] == r[1]:
  95. cond = var + " == " + str(r[0])
  96. else:
  97. cond = "(%d <= %s && %s <= %d)" % (r[0], var, var, r[1])
  98. if result == "":
  99. result = cond
  100. else:
  101. result = result + " || " + cond
  102. return result
  103. class db_docsref_gen:
  104. "A generator of reference documentation."
  105. def __init__(self, db):
  106. self.db = db
  107. instrs = [i for i in self.db.instr if i.is_dxil_op]
  108. instrs = sorted(instrs, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
  109. self.instrs = instrs
  110. val_rules = sorted(db.val_rules, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
  111. self.val_rules = val_rules
  112. def print_content(self):
  113. self.print_header()
  114. self.print_body()
  115. self.print_footer()
  116. def print_header(self):
  117. print("<!DOCTYPE html>")
  118. print("<html><head><title>DXIL Reference</title>")
  119. print("<style>body { font-family: Verdana; font-size: small; }</style>")
  120. print("</head><body><h1>DXIL Reference</h1>")
  121. self.print_toc("Instructions", "i", self.instrs)
  122. self.print_toc("Rules", "r", self.val_rules)
  123. def print_body(self):
  124. self.print_instruction_details()
  125. self.print_valrule_details()
  126. def print_instruction_details(self):
  127. print("<h2>Instruction Details</h2>")
  128. for i in self.instrs:
  129. print("<h3><a name='i%s'>%s</a></h3>" % (i.name, i.name))
  130. print("<div>Opcode: %d. This instruction %s.</div>" % (i.dxil_opid, i.doc))
  131. if i.remarks:
  132. # This is likely a .rst fragment, but this will do for now.
  133. print("<div> " + i.remarks + "</div>")
  134. print("<div>Operands:</div>")
  135. print("<ul>")
  136. for o in i.ops:
  137. if o.pos == 0:
  138. print("<li>result: %s - %s</li>" % (o.llvm_type, o.doc))
  139. else:
  140. enum_desc = "" if o.enum_name == "" else " one of %s: %s" % (o.enum_name, ",".join(db.enum_idx[o.enum_name].value_names()))
  141. print("<li>%d - %s: %s%s%s</li>" % (o.pos - 1, o.name, o.llvm_type, "" if o.doc == "" else " - " + o.doc, enum_desc))
  142. print("</ul>")
  143. print("<div><a href='#Instructions'>(top)</a></div>")
  144. def print_valrule_details(self):
  145. print("<h2>Rule Details</h2>")
  146. for i in self.val_rules:
  147. print("<h3><a name='r%s'>%s</a></h3>" % (i.name, i.name))
  148. print("<div>" + i.doc + "</div>")
  149. print("<div><a href='#Rules'>(top)</a></div>")
  150. def print_toc(self, name, aprefix, values):
  151. print("<h2><a name='" + name + "'>" + name + "</a></h2>")
  152. last_category = ""
  153. for i in values:
  154. if i.category != last_category:
  155. if last_category != None:
  156. print("</ul>")
  157. print("<div><b>%s</b></div><ul>" % i.category)
  158. last_category = i.category
  159. print("<li><a href='#" + aprefix + "%s'>%s</a></li>" % (i.name, i.name))
  160. print("</ul>")
  161. def print_footer(self):
  162. print("</body></html>")
  163. class db_instrhelp_gen:
  164. "A generator of instruction helper classes."
  165. def __init__(self, db):
  166. self.db = db
  167. TypeInfo = collections.namedtuple("TypeInfo", "name bits")
  168. self.llvm_type_map = {
  169. "i1": TypeInfo("bool", 1),
  170. "i8": TypeInfo("int8_t", 8),
  171. "u8": TypeInfo("uint8_t", 8),
  172. "i32": TypeInfo("int32_t", 32),
  173. "u32": TypeInfo("uint32_t", 32)
  174. }
  175. self.IsDxilOpFuncCallInst = "hlsl::OP::IsDxilOpFuncCallInst"
  176. def print_content(self):
  177. self.print_header()
  178. self.print_body()
  179. self.print_footer()
  180. def print_header(self):
  181. print("///////////////////////////////////////////////////////////////////////////////")
  182. print("// //")
  183. print("// Copyright (C) Microsoft Corporation. All rights reserved. //")
  184. print("// DxilInstructions.h //")
  185. print("// //")
  186. print("// This file provides a library of instruction helper classes. //")
  187. print("// //")
  188. print("// MUCH WORK YET TO BE DONE - EXPECT THIS WILL CHANGE - GENERATED FILE //")
  189. print("// //")
  190. print("///////////////////////////////////////////////////////////////////////////////")
  191. print("")
  192. print("// TODO: add correct include directives")
  193. print("// TODO: add accessors with values")
  194. print("// TODO: add validation support code, including calling into right fn")
  195. print("// TODO: add type hierarchy")
  196. print("namespace hlsl {")
  197. def bool_lit(self, val):
  198. return "true" if val else "false";
  199. def op_type(self, o):
  200. if o.llvm_type in self.llvm_type_map:
  201. return self.llvm_type_map[o.llvm_type].name
  202. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  203. def op_size(self, o):
  204. if o.llvm_type in self.llvm_type_map:
  205. return self.llvm_type_map[o.llvm_type].bits
  206. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  207. def op_const_expr(self, o):
  208. return "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)
  209. def op_set_const_expr(self, o):
  210. type_size = self.op_size(o)
  211. return "llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), %d), llvm::APInt(%d, (uint64_t)val))" % (type_size, type_size)
  212. def print_body(self):
  213. for i in self.db.instr:
  214. if i.is_reserved: continue
  215. if i.inst_helper_prefix:
  216. struct_name = "%s_%s" % (i.inst_helper_prefix, i.name)
  217. elif i.is_dxil_op:
  218. struct_name = "DxilInst_%s" % i.name
  219. else:
  220. struct_name = "LlvmInst_%s" % i.name
  221. if i.doc:
  222. print("/// This instruction %s" % i.doc)
  223. print("struct %s {" % struct_name)
  224. print(" llvm::Instruction *Instr;")
  225. print(" // Construction and identification")
  226. print(" %s(llvm::Instruction *pInstr) : Instr(pInstr) {}" % struct_name)
  227. print(" operator bool() const {")
  228. if i.is_dxil_op:
  229. op_name = i.fully_qualified_name()
  230. print(" return %s(Instr, %s);" % (self.IsDxilOpFuncCallInst, op_name))
  231. else:
  232. print(" return Instr->getOpcode() == llvm::Instruction::%s;" % i.name)
  233. print(" }")
  234. print(" // Validation support")
  235. print(" bool isAllowed() const { return %s; }" % self.bool_lit(i.is_allowed))
  236. if i.is_dxil_op:
  237. print(" bool isArgumentListValid() const {")
  238. print(" if (%d != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;" % (len(i.ops) - 1))
  239. print(" return true;")
  240. # TODO - check operand types
  241. print(" }")
  242. print(" // Metadata")
  243. print(" bool requiresUniformInputs() const { return %s; }" % self.bool_lit(i.requires_uniform_inputs))
  244. EnumWritten = False
  245. for o in i.ops:
  246. if o.pos > 1: # 0 is return type, 1 is DXIL OP id
  247. if not EnumWritten:
  248. print(" // Operand indexes")
  249. print(" enum OperandIdx {")
  250. EnumWritten = True
  251. print(" arg_%s = %d," % (o.name, o.pos - 1))
  252. if EnumWritten:
  253. print(" };")
  254. AccessorsWritten = False
  255. for o in i.ops:
  256. if o.pos > 1: # 0 is return type, 1 is DXIL OP id
  257. if not AccessorsWritten:
  258. print(" // Accessors")
  259. AccessorsWritten = True
  260. print(" llvm::Value *get_%s() const { return Instr->getOperand(%d); }" % (o.name, o.pos - 1))
  261. print(" void set_%s(llvm::Value *val) { Instr->setOperand(%d, val); }" % (o.name, o.pos - 1))
  262. if o.is_const:
  263. if o.llvm_type in self.llvm_type_map:
  264. print(" %s get_%s_val() const { return %s; }" % (self.op_type(o), o.name, self.op_const_expr(o)))
  265. print(" void set_%s_val(%s val) { Instr->setOperand(%d, %s); }" % (o.name, self.op_type(o), o.pos - 1, self.op_set_const_expr(o)))
  266. print("};")
  267. print("")
  268. def print_footer(self):
  269. print("} // namespace hlsl")
  270. class db_enumhelp_gen:
  271. "A generator of enumeration declarations."
  272. def __init__(self, db):
  273. self.db = db
  274. # Some enums should get a last enum marker.
  275. self.lastEnumNames = {
  276. "OpCode": "NumOpCodes",
  277. "OpCodeClass": "NumOpClasses"
  278. }
  279. def print_enum(self, e, **kwargs):
  280. print("// %s" % e.doc)
  281. print("enum class %s : unsigned {" % e.name)
  282. hide_val = kwargs.get("hide_val", False)
  283. sorted_values = e.values
  284. if kwargs.get("sort_val", True):
  285. sorted_values = sorted(e.values, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
  286. last_category = None
  287. for v in sorted_values:
  288. if v.category != last_category:
  289. if last_category != None:
  290. print("")
  291. print(" // %s" % v.category)
  292. last_category = v.category
  293. line_format = " {name}"
  294. if not e.is_internal and not hide_val:
  295. line_format += " = {value}"
  296. line_format += ","
  297. if v.doc:
  298. line_format += " // {doc}"
  299. print(line_format.format(name=v.name, value=v.value, doc=v.doc))
  300. if e.name in self.lastEnumNames:
  301. lastName = self.lastEnumNames[e.name]
  302. versioned = ["%s_Dxil_%d_%d = %d," % (lastName, major, minor, info[lastName])
  303. for (major, minor), info in sorted(self.db.dxil_version_info.items())
  304. if lastName in info]
  305. if versioned:
  306. print("")
  307. for val in versioned:
  308. print(" " + val)
  309. print("")
  310. print(" " + lastName + " = " + str(len(sorted_values)) + " // exclusive last value of enumeration")
  311. print("};")
  312. def print_content(self):
  313. for e in sorted(self.db.enums, key=lambda e : e.name):
  314. self.print_enum(e)
  315. class db_oload_gen:
  316. "A generator of overload tables."
  317. def __init__(self, db):
  318. self.db = db
  319. instrs = [i for i in self.db.instr if i.is_dxil_op]
  320. self.instrs = sorted(instrs, key=lambda i : i.dxil_opid)
  321. # Allow these to be overridden by external scripts.
  322. self.OP = "OP"
  323. self.OC = "OC"
  324. self.OCC = "OCC"
  325. def print_content(self):
  326. self.print_opfunc_props()
  327. print("...")
  328. self.print_opfunc_table()
  329. def print_opfunc_props(self):
  330. print("const {OP}::OpCodeProperty {OP}::m_OpCodeProps[(unsigned){OP}::OpCode::NumOpCodes] = {{".format(OP=self.OP))
  331. print("// OpCode OpCode name, OpCodeClass OpCodeClass name, void, h, f, d, i1, i8, i16, i32, i64, udt, obj, function attribute")
  332. # Example formatted string:
  333. # { OC::TempRegLoad, "TempRegLoad", OCC::TempRegLoad, "tempRegLoad", false, true, true, false, true, false, true, true, false, Attribute::ReadOnly, },
  334. # 012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789
  335. # 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0
  336. last_category = None
  337. # overload types are a string of (v)oid, (h)alf, (f)loat, (d)ouble, (1)-bit, (8)-bit, (w)ord, (i)nt, (l)ong, u(dt)
  338. f = lambda i,c : "true" if i.oload_types.find(c) >= 0 else "false"
  339. lower_exceptions = { "CBufferLoad" : "cbufferLoad", "CBufferLoadLegacy" : "cbufferLoadLegacy", "GSInstanceID" : "gsInstanceID" }
  340. lower_fn = lambda t: lower_exceptions[t] if t in lower_exceptions else t[:1].lower() + t[1:]
  341. attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate", "nr": "NoReturn", "wv" : "None" }
  342. attr_fn = lambda i : "Attribute::" + attr_dict[i.fn_attr] + ","
  343. for i in self.instrs:
  344. if last_category != i.category:
  345. if last_category != None:
  346. print("")
  347. print(" // {category:118} void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute".format(category=i.category))
  348. last_category = i.category
  349. print(" {{ {OC}::{name:24} {quotName:27} {OCC}::{className:25} {classNameQuot:28} {{{v:>6},{h:>6},{f:>6},{d:>6},{b:>6},{e:>6},{w:>6},{i:>6},{l:>6},{u:>6},{o:>6}}}, {attr:20} }},".format(
  350. name=i.name+",", quotName='"'+i.name+'",', className=i.dxil_class+",", classNameQuot='"'+lower_fn(i.dxil_class)+'",',
  351. v=f(i,"v"), h=f(i,"h"), f=f(i,"f"), d=f(i,"d"), b=f(i,"1"), e=f(i,"8"), w=f(i,"w"), i=f(i,"i"), l=f(i,"l"), u=f(i,"u"), o=f(i,"o"), attr=attr_fn(i),
  352. OC=self.OC, OCC=self.OCC))
  353. print("};")
  354. def print_opfunc_table(self):
  355. # Print the table for OP::GetOpFunc
  356. op_type_texts = {
  357. "$cb": "CBRT(pETy);",
  358. "$o": "A(pETy);",
  359. "$r": "RRT(pETy);",
  360. "d": "A(pF64);",
  361. "dims": "A(pDim);",
  362. "f": "A(pF32);",
  363. "h": "A(pF16);",
  364. "i1": "A(pI1);",
  365. "i16": "A(pI16);",
  366. "i32": "A(pI32);",
  367. "i32c": "A(pI32C);",
  368. "i64": "A(pI64);",
  369. "i8": "A(pI8);",
  370. "$u4": "A(pI4S);",
  371. "pf32": "A(pPF32);",
  372. "res": "A(pRes);",
  373. "splitdouble": "A(pSDT);",
  374. "twoi32": "A(p2I32);",
  375. "twof32": "A(p2F32);",
  376. "twof16": "A(p2F16);",
  377. "twoi16": "A(p2I16);",
  378. "threei32": "A(p3I32);",
  379. "threef32": "A(p3F32);",
  380. "fouri32": "A(p4I32);",
  381. "fourf32": "A(p4F32);",
  382. "fouri16": "A(p4I16);",
  383. "u32": "A(pI32);",
  384. "u64": "A(pI64);",
  385. "u8": "A(pI8);",
  386. "v": "A(pV);",
  387. "$vec4" : "VEC4(pETy);",
  388. "w": "A(pWav);",
  389. "SamplePos": "A(pPos);",
  390. "udt": "A(udt);",
  391. "obj": "A(obj);",
  392. "resproperty": "A(resProperty);",
  393. "resbind": "A(resBind);",
  394. }
  395. last_category = None
  396. for i in self.instrs:
  397. if last_category != i.category:
  398. if last_category != None:
  399. print("")
  400. print(" // %s" % i.category)
  401. last_category = i.category
  402. line = " case OpCode::{name:24}".format(name = i.name + ":")
  403. for index, o in enumerate(i.ops):
  404. assert o.llvm_type in op_type_texts, "llvm type %s in instruction %s is unknown" % (o.llvm_type, i.name)
  405. op_type_text = op_type_texts[o.llvm_type]
  406. if index == 0:
  407. line = line + "{val:13}".format(val=op_type_text)
  408. else:
  409. line = line + "{val:9}".format(val=op_type_text)
  410. line = line + "break;"
  411. print(line)
  412. def print_opfunc_oload_type(self):
  413. # Print the function for OP::GetOverloadType
  414. elt_ty = "$o"
  415. res_ret_ty = "$r"
  416. cb_ret_ty = "$cb"
  417. udt_ty = "udt"
  418. obj_ty = "obj"
  419. vec_ty = "$vec"
  420. last_category = None
  421. index_dict = collections.OrderedDict()
  422. single_dict = collections.OrderedDict()
  423. struct_list = []
  424. for instr in self.instrs:
  425. ret_ty = instr.ops[0].llvm_type
  426. # Skip case return type is overload type
  427. if (ret_ty == elt_ty):
  428. continue
  429. if ret_ty == res_ret_ty:
  430. struct_list.append(instr.name)
  431. continue
  432. if ret_ty == cb_ret_ty:
  433. struct_list.append(instr.name)
  434. continue
  435. if ret_ty.startswith(vec_ty):
  436. struct_list.append(instr.name);
  437. continue
  438. in_param_ty = False
  439. # Try to find elt_ty in parameter types.
  440. for index, op in enumerate(instr.ops):
  441. # Skip return type.
  442. if (op.pos == 0):
  443. continue
  444. # Skip dxil opcode.
  445. if (op.pos == 1):
  446. continue
  447. op_type = op.llvm_type
  448. if (op_type == elt_ty):
  449. # Skip return op
  450. index = index - 1
  451. if index not in index_dict:
  452. index_dict[index] = [instr.name]
  453. else:
  454. index_dict[index].append(instr.name)
  455. in_param_ty = True
  456. break
  457. if (op_type == udt_ty or op_type == obj_ty):
  458. # Skip return op
  459. index = index - 1
  460. if index not in index_dict:
  461. index_dict[index] = [instr.name]
  462. else:
  463. index_dict[index].append(instr.name)
  464. in_param_ty = True
  465. if in_param_ty:
  466. continue
  467. # No overload, just return the single oload_type.
  468. assert len(instr.oload_types)==1, "overload no elt_ty %s" % (instr.name)
  469. ty = instr.oload_types[0]
  470. type_code_texts = {
  471. "d": "Type::getDoubleTy(Ctx)",
  472. "f": "Type::getFloatTy(Ctx)",
  473. "h": "Type::getHalfTy",
  474. "1": "IntegerType::get(Ctx, 1)",
  475. "8": "IntegerType::get(Ctx, 8)",
  476. "w": "IntegerType::get(Ctx, 16)",
  477. "i": "IntegerType::get(Ctx, 32)",
  478. "l": "IntegerType::get(Ctx, 64)",
  479. "v": "Type::getVoidTy(Ctx)",
  480. "u": "Type::getInt32PtrTy(Ctx)",
  481. "o": "Type::getInt32PtrTy(Ctx)",
  482. }
  483. assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
  484. ty_code = type_code_texts[ty]
  485. if ty_code not in single_dict:
  486. single_dict[ty_code] = [instr.name]
  487. else:
  488. single_dict[ty_code].append(instr.name)
  489. for index, opcodes in index_dict.items():
  490. line = ""
  491. for opcode in opcodes:
  492. line = line + "case OpCode::{name}".format(name = opcode + ":\n")
  493. line = line + " DXASSERT_NOMSG(FT->getNumParams() > " + str(index) + ");\n"
  494. line = line + " return FT->getParamType(" + str(index) + ");"
  495. print(line)
  496. for code, opcodes in single_dict.items():
  497. line = ""
  498. for opcode in opcodes:
  499. line = line + "case OpCode::{name}".format(name = opcode + ":\n")
  500. line = line + " return " + code + ";"
  501. print(line)
  502. line = ""
  503. for opcode in struct_list:
  504. line = line + "case OpCode::{name}".format(name = opcode + ":\n")
  505. line = line + "{\n"
  506. line = line + " StructType *ST = cast<StructType>(Ty);\n"
  507. line = line + " return ST->getElementType(0);\n"
  508. line = line + "}"
  509. print(line)
  510. class db_valfns_gen:
  511. "A generator of validation functions."
  512. def __init__(self, db):
  513. self.db = db
  514. def print_content(self):
  515. self.print_header()
  516. self.print_body()
  517. def print_header(self):
  518. print("///////////////////////////////////////////////////////////////////////////////")
  519. print("// Instruction validation functions. //")
  520. def bool_lit(self, val):
  521. return "true" if val else "false";
  522. def op_type(self, o):
  523. if o.llvm_type == "i8":
  524. return "int8_t"
  525. if o.llvm_type == "u8":
  526. return "uint8_t"
  527. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  528. def op_const_expr(self, o):
  529. if o.llvm_type == "i8" or o.llvm_type == "u8":
  530. return "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)
  531. raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
  532. def print_body(self):
  533. llvm_instrs = [i for i in self.db.instr if i.is_allowed and not i.is_dxil_op]
  534. print("static bool IsLLVMInstructionAllowed(llvm::Instruction &I) {")
  535. self.print_comment(" // ", "Allow: %s" % ", ".join([i.name + "=" + str(i.llvm_id) for i in llvm_instrs]))
  536. print(" unsigned op = I.getOpcode();")
  537. print(" return %s;" % build_range_code("op", [i.llvm_id for i in llvm_instrs]))
  538. print("}")
  539. print("")
  540. def print_comment(self, prefix, val):
  541. print(format_comment(prefix, val))
  542. class macro_table_gen:
  543. "A generator for macro tables."
  544. def format_row(self, row, widths, sep=', '):
  545. frow = [str(item) + sep + (' ' * (width - len(item)))
  546. for item, width in list(zip(row, widths))[:-1]] + [str(row[-1])]
  547. return ''.join(frow)
  548. def format_table(self, table, *args, **kwargs):
  549. widths = [ functools.reduce(max, [ len(row[i])
  550. for row in table], 1)
  551. for i in range(len(table[0]))]
  552. formatted = []
  553. for row in table:
  554. formatted.append(self.format_row(row, widths, *args, **kwargs))
  555. return formatted
  556. def print_table(self, table, macro_name):
  557. formatted = self.format_table(table)
  558. print( '// %s\n' % formatted[0] +
  559. '#define %s(ROW) \\\n' % macro_name +
  560. ' \\\n'.join([' ROW(%s)' % frow for frow in formatted[1:]]))
  561. class db_sigpoint_gen(macro_table_gen):
  562. "A generator for SigPoint tables."
  563. def __init__(self, db):
  564. self.db = db
  565. def print_sigpoint_table(self):
  566. self.print_table(self.db.sigpoint_table, 'DO_SIGPOINTS')
  567. def print_interpretation_table(self):
  568. self.print_table(self.db.interpretation_table, 'DO_INTERPRETATION_TABLE')
  569. def print_content(self):
  570. self.print_sigpoint_table()
  571. self.print_interpretation_table()
  572. class string_output:
  573. def __init__(self):
  574. self.val = ""
  575. def write(self, text):
  576. self.val = self.val + str(text)
  577. def __str__(self):
  578. return self.val
  579. def run_with_stdout(fn):
  580. import sys
  581. _stdout_saved = sys.stdout
  582. so = string_output()
  583. try:
  584. sys.stdout = so
  585. fn()
  586. finally:
  587. sys.stdout = _stdout_saved
  588. return str(so)
  589. def get_hlsl_intrinsic_stats():
  590. db = get_db_hlsl()
  591. longest_fn = db.intrinsics[0]
  592. longest_param = None
  593. longest_arglist_fn = db.intrinsics[0]
  594. for i in sorted(db.intrinsics, key=lambda x: x.key):
  595. # Get some values for maximum lengths.
  596. if len(i.name) > len(longest_fn.name):
  597. longest_fn = i
  598. for p_idx, p in enumerate(i.params):
  599. if p_idx > 0 and (longest_param is None or len(p.name) > len(longest_param.name)):
  600. longest_param = p
  601. if len(i.params) > len(longest_arglist_fn.params):
  602. longest_arglist_fn = i
  603. result = ""
  604. for k in sorted(db.namespaces.keys()):
  605. v = db.namespaces[k]
  606. result += "static const UINT g_u%sCount = %d;\n" % (k, len(v.intrinsics))
  607. result += "\n"
  608. result += "static const int g_MaxIntrinsicName = %d; // Count of characters for longest intrinsic name - '%s'\n" % (len(longest_fn.name), longest_fn.name)
  609. result += "static const int g_MaxIntrinsicParamName = %d; // Count of characters for longest intrinsic parameter name - '%s'\n" % (len(longest_param.name), longest_param.name)
  610. result += "static const int g_MaxIntrinsicParamCount = %d; // Count of parameters (without return) for longest intrinsic argument list - '%s'\n" % (len(longest_arglist_fn.params) - 1, longest_arglist_fn.name)
  611. return result
def get_hlsl_intrinsics():
    """Return C++ source for the HLSL intrinsic tables, grouped by namespace.

    For each intrinsic (sorted by ``key``) this appends a
    ``g_<ns>_Args<N>`` HLSL_INTRINSIC_ARGUMENT array to the output. It also
    adds one HLSL_INTRINSIC row referencing that array to the pending
    ``g_<ns>`` table. The table is flushed when the namespace changes and
    once more at the end.

    When the first intrinsic (in ``key`` order) of a namespace is
    ``vulkanSpecific``, that namespace's argument arrays and table are
    emitted between ``#ifdef ENABLE_SPIRV_CODEGEN`` and
    ``#endif // ENABLE_SPIRV_CODEGEN``. The ``// Start of <ns>`` banner comes
    before the ``#ifdef``.
    """
    db = get_db_hlsl()
    result = ""
    last_ns = ""
    ns_table = ""  # pending g_<ns> table text, flushed on namespace change
    is_vk_table = False # SPIRV Change
    id_prefix = ""
    arg_idx = 0  # per-namespace counter used in the g_<ns>_Args<N> names
    opcode_namespace = db.opcode_namespace
    for i in sorted(db.intrinsics, key=lambda x: x.key):
        if last_ns != i.ns:
            # New namespace: flush the previous table (and close its #ifdef if
            # one was opened), then start this namespace's table.
            last_ns = i.ns
            id_prefix = "IOP" if last_ns == "Intrinsics" or last_ns == "VkIntrinsics" else "MOP" # SPIRV Change
            if (len(ns_table)):
                result += ns_table + "};\n"
                # SPIRV Change Starts
                if is_vk_table:
                    result += "\n#endif // ENABLE_SPIRV_CODEGEN\n"
                    is_vk_table = False
                # SPIRV Change Ends
            result += "\n//\n// Start of %s\n//\n\n" % (last_ns)
            # This used to be qualified as __declspec(selectany), but that's no longer necessary.
            ns_table = "static const HLSL_INTRINSIC g_%s[] =\n{\n" % (last_ns)
            # SPIRV Change Starts
            # Only the namespace's first intrinsic is checked; the guard covers
            # everything emitted until the namespace's table is flushed.
            if (i.vulkanSpecific):
                is_vk_table = True
                result += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
            # SPIRV Change Ends
            arg_idx = 0
        ns_table += " {(UINT)%s::%s_%s, %s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), str(i.wave).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
        result += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (last_ns, arg_idx)
        for p in i.params:
            name = p.name
            if name == i.name and i.hidden:
                # First parameter defines intrinsic name for parsing in HLSL.
                # Prepend '$hidden$' for hidden intrinsic so it can't be used in HLSL.
                name = "$hidden$" + name
            result += " {\"%s\", %s, %s, %s, %s, %s, %s, %s},\n" % (
                name, p.param_qual, p.template_id, p.template_list,
                p.component_id, p.component_list, p.rows, p.cols)
        result += "};\n\n"
        arg_idx += 1
    # Flush the last namespace's table (and close its guard, if any).
    result += ns_table + "};\n"
    result += "\n#endif // ENABLE_SPIRV_CODEGEN\n" if is_vk_table else "" # SPIRV Change
    return result
  657. # SPIRV Change Starts
  658. def wrap_with_ifdef_if_vulkan_specific(intrinsic, text):
  659. if intrinsic.vulkanSpecific:
  660. return "#ifdef ENABLE_SPIRV_CODEGEN\n" + text + "#endif // ENABLE_SPIRV_CODEGEN\n"
  661. return text
  662. # SPIRV Change Ends
  663. def enum_hlsl_intrinsics():
  664. db = get_db_hlsl()
  665. result = ""
  666. enumed = []
  667. for i in sorted(db.intrinsics, key=lambda x: x.key):
  668. if (i.enum_name not in enumed):
  669. enumerant = " %s,\n" % (i.enum_name)
  670. result += wrap_with_ifdef_if_vulkan_specific(i, enumerant) # SPIRV Change
  671. enumed.append(i.enum_name)
  672. # unsigned
  673. result += " // unsigned\n"
  674. for i in sorted(db.intrinsics, key=lambda x: x.key):
  675. if (i.unsigned_op != ""):
  676. if (i.unsigned_op not in enumed):
  677. result += " %s,\n" % (i.unsigned_op)
  678. enumed.append(i.unsigned_op)
  679. result += " Num_Intrinsics,\n"
  680. return result
  681. def has_unsigned_hlsl_intrinsics():
  682. db = get_db_hlsl()
  683. result = ""
  684. enumed = []
  685. # unsigned
  686. for i in sorted(db.intrinsics, key=lambda x: x.key):
  687. if (i.unsigned_op != ""):
  688. if (i.enum_name not in enumed):
  689. result += " case IntrinsicOp::%s:\n" % (i.enum_name)
  690. enumed.append(i.enum_name)
  691. return result
  692. def get_unsigned_hlsl_intrinsics():
  693. db = get_db_hlsl()
  694. result = ""
  695. enumed = []
  696. # unsigned
  697. for i in sorted(db.intrinsics, key=lambda x: x.key):
  698. if (i.unsigned_op != ""):
  699. if (i.enum_name not in enumed):
  700. enumed.append(i.enum_name)
  701. result += " case IntrinsicOp::%s:\n" % (i.enum_name)
  702. result += " return static_cast<unsigned>(IntrinsicOp::%s);\n" % (i.unsigned_op)
  703. return result
  704. def get_oloads_props():
  705. db = get_db_dxil()
  706. gen = db_oload_gen(db)
  707. return run_with_stdout(lambda: gen.print_opfunc_props())
  708. def get_oloads_funcs():
  709. db = get_db_dxil()
  710. gen = db_oload_gen(db)
  711. return run_with_stdout(lambda: gen.print_opfunc_table())
  712. def get_funcs_oload_type():
  713. db = get_db_dxil()
  714. gen = db_oload_gen(db)
  715. return run_with_stdout(lambda: gen.print_opfunc_oload_type())
  716. def get_enum_decl(name, **kwargs):
  717. db = get_db_dxil()
  718. gen = db_enumhelp_gen(db)
  719. return run_with_stdout(lambda: gen.print_enum(db.enum_idx[name], **kwargs))
  720. def get_valrule_enum():
  721. return get_enum_decl("ValidationRule", hide_val=True)
  722. def get_valrule_text():
  723. db = get_db_dxil()
  724. result = "switch(value) {\n"
  725. for v in db.enum_idx["ValidationRule"].values:
  726. result += " case hlsl::ValidationRule::" + v.name + ": return \"" + v.err_msg + "\";\n"
  727. result += "}\n"
  728. return result
  729. def get_instrhelper():
  730. db = get_db_dxil()
  731. gen = db_instrhelp_gen(db)
  732. return run_with_stdout(lambda: gen.print_body())
  733. def get_instrs_pred(varname, pred, attr_name="dxil_opid"):
  734. db = get_db_dxil()
  735. if type(pred) == str:
  736. pred_fn = lambda i: getattr(i, pred)
  737. else:
  738. pred_fn = pred
  739. llvm_instrs = [i for i in db.instr if pred_fn(i)]
  740. result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(getattr(i, attr_name)) for i in llvm_instrs]))
  741. result += "return %s;" % build_range_code(varname, [getattr(i, attr_name) for i in llvm_instrs])
  742. result += "\n"
  743. return result
  744. def counter_pred(name, dxil_op=True):
  745. def pred(i):
  746. return (dxil_op == i.is_dxil_op) and getattr(i, 'props') and 'counters' in i.props and name in i.props['counters']
  747. return pred
  748. def get_counters():
  749. db = get_db_dxil()
  750. return db.counters
  751. def get_llvm_op_counters():
  752. db = get_db_dxil()
  753. return [c for c in db.counters if c in db.llvm_op_counters]
  754. def get_dxil_op_counters():
  755. db = get_db_dxil()
  756. return [c for c in db.counters if c in db.dxil_op_counters]
  757. def get_instrs_rst():
  758. "Create an rst table of allowed LLVM instructions."
  759. db = get_db_dxil()
  760. instrs = [i for i in db.instr if i.is_allowed and not i.is_dxil_op]
  761. instrs = sorted(instrs, key=lambda v : v.llvm_id)
  762. rows = []
  763. rows.append(["Instruction", "Action", "Operand overloads"])
  764. for i in instrs:
  765. rows.append([i.name, i.doc, i.oload_types])
  766. result = "\n\n" + format_rst_table(rows) + "\n\n"
  767. # Add detailed instruction information where available.
  768. for i in instrs:
  769. if i.remarks:
  770. result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
  771. return result + "\n"
  772. def get_init_passes(category_libs):
  773. "Create a series of statements to initialize passes in a registry."
  774. db = get_db_dxil()
  775. result = ""
  776. for p in sorted(db.passes, key=lambda p : p.type_name):
  777. # Skip if not in target category.
  778. if (p.category_lib not in category_libs):
  779. continue
  780. result += "initialize%sPass(Registry);\n" % p.type_name
  781. return result
  782. def get_pass_arg_names():
  783. "Return an ArrayRef of argument names based on passName"
  784. db = get_db_dxil()
  785. decl_result = ""
  786. check_result = ""
  787. for p in sorted(db.passes, key=lambda p : p.type_name):
  788. if len(p.args):
  789. decl_result += "static const LPCSTR %sArgs[] = { " % p.type_name
  790. check_result += "if (strcmp(passName, \"%s\") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n" % (p.name, p.type_name, p.type_name)
  791. sep = ""
  792. for a in p.args:
  793. decl_result += sep + "\"%s\"" % a.name
  794. sep = ", "
  795. decl_result += " };\n"
  796. return decl_result + check_result
  797. def get_pass_arg_descs():
  798. "Return an ArrayRef of argument descriptions based on passName"
  799. db = get_db_dxil()
  800. decl_result = ""
  801. check_result = ""
  802. for p in sorted(db.passes, key=lambda p : p.type_name):
  803. if len(p.args):
  804. decl_result += "static const LPCSTR %sArgs[] = { " % p.type_name
  805. check_result += "if (strcmp(passName, \"%s\") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n" % (p.name, p.type_name, p.type_name)
  806. sep = ""
  807. for a in p.args:
  808. decl_result += sep + "\"%s\"" % a.doc
  809. sep = ", "
  810. decl_result += " };\n"
  811. return decl_result + check_result
  812. def get_is_pass_option_name():
  813. "Create a return expression to check whether a value 'S' is a pass option name."
  814. db = get_db_dxil()
  815. prefix = ""
  816. result = "return "
  817. for k in sorted(db.pass_idx_args):
  818. result += prefix + "S.equals(\"%s\")" % k
  819. prefix = "\n || "
  820. return result + ";"
  821. def get_opcodes_rst():
  822. "Create an rst table of opcodes"
  823. db = get_db_dxil()
  824. instrs = [i for i in db.instr if i.is_allowed and i.is_dxil_op]
  825. instrs = sorted(instrs, key=lambda v : v.dxil_opid)
  826. rows = []
  827. rows.append(["ID", "Name", "Description"])
  828. for i in instrs:
  829. op_name = i.dxil_op
  830. if i.remarks:
  831. op_name = op_name + "_" # append _ to enable internal hyperlink on rst files
  832. rows.append([i.dxil_opid, op_name, i.doc])
  833. result = "\n\n" + format_rst_table(rows) + "\n\n"
  834. # Add detailed instruction information where available.
  835. instrs = sorted(instrs, key=lambda v : v.name)
  836. for i in instrs:
  837. if i.remarks:
  838. result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
  839. return result + "\n"
  840. def get_valrules_rst():
  841. "Create an rst table of validation rules instructions."
  842. db = get_db_dxil()
  843. rules = [i for i in db.val_rules if not i.is_disabled]
  844. rules = sorted(rules, key=lambda v : v.name)
  845. rows = []
  846. rows.append(["Rule Code", "Description"])
  847. for i in rules:
  848. rows.append([i.name, i.doc])
  849. return "\n\n" + format_rst_table(rows) + "\n\n"
  850. def get_opsigs():
  851. # Create a list of DXIL operation signatures, sorted by ID.
  852. db = get_db_dxil()
  853. instrs = [i for i in db.instr if i.is_dxil_op]
  854. instrs = sorted(instrs, key=lambda v : v.dxil_opid)
  855. # db_dxil already asserts that the numbering is dense.
  856. # Create the code to write out.
  857. code = "static const char *OpCodeSignatures[] = {\n"
  858. for inst_idx,i in enumerate(instrs):
  859. code += " \"("
  860. for operand in i.ops:
  861. if operand.pos > 1: # skip 0 (the return value) and 1 (the opcode itself)
  862. code += operand.name
  863. if operand.pos < len(i.ops) - 1:
  864. code += ","
  865. code += ")\""
  866. if inst_idx < len(instrs) - 1:
  867. code += ","
  868. code += " // " + i.name
  869. code += "\n"
  870. code += "};\n"
  871. return code
  872. shader_stage_to_ShaderKind = {
  873. 'vertex': 'Vertex',
  874. 'pixel': 'Pixel',
  875. 'geometry': 'Geometry',
  876. 'compute': 'Compute',
  877. 'hull': 'Hull',
  878. 'domain': 'Domain',
  879. 'library': 'Library',
  880. 'raygeneration': 'RayGeneration',
  881. 'intersection': 'Intersection',
  882. 'anyhit': 'AnyHit',
  883. 'closesthit': 'ClosestHit',
  884. 'miss': 'Miss',
  885. 'callable': 'Callable',
  886. 'mesh' : 'Mesh',
  887. 'amplification' : 'Amplification',
  888. }
def get_min_sm_and_mask_text():
    """Generate C++ assigning the minimum shader model and stage mask for each DXIL op.

    DXIL ops are sorted and grouped by (shader_model, shader_model_translated,
    shader_stages). Each group becomes an ``if (<opcode range on op>) { ... return; }``
    block. The block assigns ``major``/``minor`` when the model is not 6.0,
    using the translated model under ``bWithTranslation`` when one exists,
    and/or assigns ``mask`` when the group is stage-restricted. Groups at
    SM 6.0 with no stage restriction emit nothing, so the generated code falls
    through past them.
    """
    db = get_db_dxil()
    instrs = [i for i in db.instr if i.is_dxil_op]
    # Sorting by the grouping key makes equal keys adjacent, so one pass can collect each group.
    # NOTE(review): if shader_model_translated (or shader_stages) mixes None and
    # tuples for the same shader_model, this key is not orderable on Python 3 — confirm the db keeps them consistent.
    instrs = sorted(instrs, key=lambda v : (v.shader_model, v.shader_model_translated, v.shader_stages, v.dxil_opid))
    last_model = None
    last_model_translated = None
    last_stage = None
    grouped_instrs = []
    code = ""
    def flush_instrs(grouped_instrs, last_model, last_model_translated, last_stage):
        # Emit the C++ block for one group; returns "" for an empty group or
        # when the group needs neither a model nor a mask.
        if len(grouped_instrs) == 0:
            return ""
        result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(i.dxil_opid) for i in grouped_instrs]))
        result += "if (" + build_range_code("op", [i.dxil_opid for i in grouped_instrs]) + ") {\n"
        # Stays True only if nothing beyond SM 6.0 / all stages is required.
        default = True
        if last_model != (6,0):
            default = False
            if last_model_translated:
                # With translation the translated model applies; the plain model goes in the else branch.
                result += " if (bWithTranslation) {\n"
                result += " major = %d; minor = %d;\n } else {\n " % last_model_translated
            result += " major = %d; minor = %d;\n" % last_model
            if last_model_translated:
                result += " }\n"
        if last_stage:
            default = False
            # OR together one SFLAG per allowed stage; unknown stage names raise KeyError.
            result += " mask = %s;\n" % ' | '.join([ 'SFLAG(%s)' % shader_stage_to_ShaderKind[c]
                                                     for c in last_stage
                                                     ])
        if default:
            # don't write these out, instead fall through
            return ""
        return result + " return;\n}\n"
    for i in instrs:
        # A change in the grouping key closes the current group and starts a new one.
        if ((i.shader_model, i.shader_model_translated, i.shader_stages) !=
                (last_model, last_model_translated, last_stage)):
            code += flush_instrs(grouped_instrs, last_model, last_model_translated, last_stage)
            grouped_instrs = []
            last_model = i.shader_model
            last_model_translated = i.shader_model_translated
            last_stage = i.shader_stages
        grouped_instrs.append(i)
    # Flush the final group.
    code += flush_instrs(grouped_instrs, last_model, last_model_translated, last_stage)
    return code
  932. check_pSM_for_shader_stage = {
  933. 'vertex': 'SK == DXIL::ShaderKind::Vertex',
  934. 'pixel': 'SK == DXIL::ShaderKind::Pixel',
  935. 'geometry': 'SK == DXIL::ShaderKind::Geometry',
  936. 'compute': 'SK == DXIL::ShaderKind::Compute',
  937. 'hull': 'SK == DXIL::ShaderKind::Hull',
  938. 'domain': 'SK == DXIL::ShaderKind::Domain',
  939. 'library': 'SK == DXIL::ShaderKind::Library',
  940. 'raygeneration': 'SK == DXIL::ShaderKind::RayGeneration',
  941. 'intersection': 'SK == DXIL::ShaderKind::Intersection',
  942. 'anyhit': 'SK == DXIL::ShaderKind::AnyHit',
  943. 'closesthit': 'SK == DXIL::ShaderKind::ClosestHit',
  944. 'miss': 'SK == DXIL::ShaderKind::Miss',
  945. 'callable': 'SK == DXIL::ShaderKind::Callable',
  946. 'mesh': 'SK == DXIL::ShaderKind::Mesh',
  947. 'amplification': 'SK == DXIL::ShaderKind::Amplification',
  948. }
def get_valopcode_sm_text():
    """Generate the body of a C++ check that a DXIL op is valid for the current shader.

    Ops are sorted and grouped by (shader_model, shader_stages). Each group
    emits ``if (<opcode range on op>) return (<model check>) && (<stage check>);``
    in terms of the C++ locals ``major``, ``minor`` and ``SK``. Groups at
    SM 6.0 with no stage restriction emit nothing. The body ends with
    ``return true;``, so any op not matched by a group is accepted.
    """
    db = get_db_dxil()
    instrs = [i for i in db.instr if i.is_dxil_op]
    # Sorting by the grouping key makes equal keys adjacent, so one pass can collect each group.
    instrs = sorted(instrs, key=lambda v : (v.shader_model, v.shader_stages, v.dxil_opid))
    last_model = None
    last_stage = None
    grouped_instrs = []
    code = ""
    def flush_instrs(grouped_instrs, last_model, last_stage):
        # Emit the guarded return for one group; "" if the group is empty or unrestricted.
        if len(grouped_instrs) == 0:
            return ""
        result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(i.dxil_opid) for i in grouped_instrs]))
        result += "if (" + build_range_code("op", [i.dxil_opid for i in grouped_instrs]) + ")\n"
        result += " return "
        model_cond = stage_cond = None
        if last_model != (6,0):
            # Require (major, minor) >= last_model.
            model_cond = "major > %d || (major == %d && minor >= %d)" % (
                last_model[0], last_model[0], last_model[1])
        if last_stage:
            # Any one of the listed stages is acceptable.
            stage_cond = ' || '.join([check_pSM_for_shader_stage[c] for c in last_stage])
        if model_cond or stage_cond:
            # AND together whichever conditions apply.
            result += '\n && '.join(
                ["(%s)" % expr for expr in (model_cond, stage_cond) if expr] )
            return result + ";\n"
        else:
            # don't write these out, instead fall through
            return ""
    for i in instrs:
        # A change in the grouping key closes the current group and starts a new one.
        if (i.shader_model, i.shader_stages) != (last_model, last_stage):
            code += flush_instrs(grouped_instrs, last_model, last_stage)
            grouped_instrs = []
            last_model = i.shader_model
            last_stage = i.shader_stages
        grouped_instrs.append(i)
    # Flush the final group, then accept everything not matched above.
    code += flush_instrs(grouped_instrs, last_model, last_stage)
    code += "return true;\n"
    return code
  986. def get_sigpoint_table():
  987. db = get_db_dxil()
  988. gen = db_sigpoint_gen(db)
  989. return run_with_stdout(lambda: gen.print_sigpoint_table())
  990. def get_sigpoint_rst():
  991. "Create an rst table for SigPointKind."
  992. db = get_db_dxil()
  993. rows = [row[:] for row in db.sigpoint_table[:-1]] # Copy table
  994. e = dict([(v.name, v) for v in db.enum_idx['SigPointKind'].values])
  995. rows[0] = ['ID'] + rows[0] + ['Description']
  996. for i in range(1, len(rows)):
  997. row = rows[i]
  998. v = e[row[0]]
  999. rows[i] = [v.value] + row + [v.doc]
  1000. return "\n\n" + format_rst_table(rows) + "\n\n"
  1001. def get_sem_interpretation_enum_rst():
  1002. db = get_db_dxil()
  1003. rows = ([['ID', 'Name', 'Description']] +
  1004. [[v.value, v.name, v.doc]
  1005. for v in db.enum_idx['SemanticInterpretationKind'].values[:-1]])
  1006. return "\n\n" + format_rst_table(rows) + "\n\n"
  1007. def get_sem_interpretation_table_rst():
  1008. db = get_db_dxil()
  1009. return "\n\n" + format_rst_table(db.interpretation_table) + "\n\n"
  1010. def get_interpretation_table():
  1011. db = get_db_dxil()
  1012. gen = db_sigpoint_gen(db)
  1013. return run_with_stdout(lambda: gen.print_interpretation_table())
  1014. highest_major = 6
  1015. highest_minor = 7
  1016. highest_shader_models = {4:1, 5:1, 6:highest_minor}
  1017. def getShaderModels():
  1018. shader_models = []
  1019. for major, minor in highest_shader_models.items():
  1020. for i in range(0, minor+1):
  1021. shader_models.append(str(major) + "_" + str(i))
  1022. return shader_models;
  1023. def get_highest_shader_model():
  1024. result = """static const unsigned kHighestMajor = %d;
  1025. static const unsigned kHighestMinor = %d;"""%(highest_major, highest_minor)
  1026. return result
  1027. def get_dxil_version_minor():
  1028. return "const unsigned kDxilMinor = %d;"%highest_minor
  1029. def get_is_shader_model_plus():
  1030. result = ""
  1031. for i in range(0, highest_minor+1):
  1032. result += "bool IsSM%d%dPlus() const { return IsSMAtLeast(%d, %d); }\n"%(highest_major, i,highest_major, i)
  1033. return result
  1034. profile_to_kind = {"ps":"Kind::Pixel", "vs":"Kind::Vertex", "gs":"Kind::Geometry", "hs":"5_0", "ds":"5_0", "cs":"4_0", "lib":"6_1", "ms":"6_5", "as":"6_5"}
  1035. class shader_profile(object):
  1036. "The profile description for a DXIL instruction"
  1037. def __init__(self, kind, kind_name, enum_name, start_sm, input_size, output_size):
  1038. self.kind = kind # position in parameter list
  1039. self.kind_name = kind_name
  1040. self.enum_name = enum_name
  1041. self.start_sm = start_sm
  1042. self.input_size = input_size
  1043. self.output_size = output_size
# kind is from DXIL::ShaderKind.
# Each entry: (kind, profile prefix, Kind:: enumerator, first shader model,
# input size, output size). The shader model table generator walks this list
# in order, and its output must stay sorted by Kind, so keep entries in
# ascending kind order.
shader_profiles = [ shader_profile(0, "ps", "Kind::Pixel", "4_0", 32, 8),
                    shader_profile(1, "vs", "Kind::Vertex", "4_0", 32, 32),
                    shader_profile(2, "gs", "Kind::Geometry", "4_0", 32, 32),
                    shader_profile(3, "hs", "Kind::Hull", "5_0", 32, 32),
                    shader_profile(4, "ds", "Kind::Domain", "5_0", 32, 32),
                    shader_profile(5, "cs", "Kind::Compute", "4_0", 0,0),
                    shader_profile(6, "lib", "Kind::Library", "6_1", 32,32),
                    shader_profile(13, "ms", "Kind::Mesh", "6_5", 0,0),
                    shader_profile(14, "as", "Kind::Amplification", "6_5", 0,0),
                    ]
  1055. def getShaderProfiles():
  1056. # order match DXIL::ShaderKind.
  1057. profiles = {"ps":"4_0", "vs":"4_0", "gs":"4_0", "hs":"5_0", "ds":"5_0", "cs":"4_0", "lib":"6_1", "ms":"6_5", "as":"6_5"}
  1058. return profiles;
  1059. def get_shader_models():
  1060. result = ""
  1061. for profile in shader_profiles:
  1062. min_sm = profile.start_sm
  1063. input_size = profile.input_size
  1064. output_size = profile.output_size
  1065. kind = profile.kind
  1066. kind_name = profile.kind_name
  1067. enum_name = profile.enum_name
  1068. for major, minor in highest_shader_models.items():
  1069. UAV_info = "true, true, UINT_MAX"
  1070. if major > 5:
  1071. pass
  1072. elif major == 4:
  1073. UAV_info = "false, false, 0"
  1074. if kind == "cs":
  1075. UAV_info = "true, false, 1"
  1076. elif major == 5:
  1077. UAV_info = "true, true, 64"
  1078. for i in range(0, minor+1):
  1079. sm = "%d_%d"%(major, i)
  1080. if (min_sm > sm):
  1081. continue
  1082. input_size = profile.input_size
  1083. output_size = profile.output_size
  1084. if major == 4:
  1085. if i == 0:
  1086. if kind_name == "gs":
  1087. input_size = 16
  1088. elif kind_name == "vs":
  1089. input_size = 16
  1090. output_size = 16
  1091. sm_name = "%s_%s"%(kind_name,sm)
  1092. result += "SM(%s, %d, %d, \"%s\", %d, %d, %s),\n" % (enum_name, major, i, sm_name, input_size, output_size, UAV_info)
  1093. if kind_name == "lib":
  1094. result += "// lib_6_x is for offline linking only, and relaxes restrictions\n"
  1095. result += "SM(Kind::Library, 6, kOfflineMinor, \"lib_6_x\", 32, 32, true, true, UINT_MAX),\n"
  1096. result += "// Values before Invalid must remain sorted by Kind, then Major, then Minor.\n"
  1097. result += "SM(Kind::Invalid, 0, 0, \"invalid\", 0, 0, false, false, 0),\n"
  1098. return result
  1099. def get_num_shader_models():
  1100. count = 0
  1101. for profile in shader_profiles:
  1102. min_sm = profile.start_sm
  1103. input_size = profile.input_size
  1104. output_size = profile.output_size
  1105. kind = profile.kind
  1106. kind_name = profile.kind_name
  1107. enum_name = profile.enum_name
  1108. for major, minor in highest_shader_models.items():
  1109. for i in range(0, minor+1):
  1110. sm = "%d_%d"%(major, i)
  1111. if (min_sm > sm):
  1112. continue
  1113. count += 1
  1114. if kind_name == "lib":
  1115. # for lib_6_x
  1116. count += 1
  1117. # for invalid shader_model.
  1118. count += 1
  1119. return "static const unsigned kNumShaderModels = %d;"%count
  1120. def build_shader_model_hash_idx_map():
  1121. #must match get_shader_models.
  1122. result = "const static std::unordered_map<unsigned, unsigned> hashToIdxMap = {\n"
  1123. count = 0
  1124. for profile in shader_profiles:
  1125. min_sm = profile.start_sm
  1126. kind = profile.kind
  1127. kind_name = profile.kind_name
  1128. for major, minor in highest_shader_models.items():
  1129. for i in range(0, minor+1):
  1130. sm = "%d_%d"%(major, i)
  1131. if (min_sm > sm):
  1132. continue
  1133. sm_name = "%s_%s"%(kind_name,sm)
  1134. hash_v = kind << 16 | major << 8 | i;
  1135. result += "{%d,%d}, //%s\n" % (hash_v, count, sm_name)
  1136. count += 1
  1137. if kind_name == "lib":
  1138. result += "// lib_6_x is for offline linking only, and relaxes restrictions\n"
  1139. major = 6
  1140. #static const unsigned kOfflineMinor = 0xF;
  1141. i = 15
  1142. hash_v = kind << 16 | major << 8 | i;
  1143. result += "{%d,%d},//%s\n" % (hash_v, count, "lib_6_x")
  1144. count += 1
  1145. result += "};\n"
  1146. return result
  1147. def get_validation_version():
  1148. result = """// 1.0 is the first validator.
  1149. // 1.1 adds:
  1150. // - ILDN container part support
  1151. // 1.2 adds:
  1152. // - Metadata for floating point denorm mode
  1153. // 1.3 adds:
  1154. // - Library support
  1155. // - Raytracing support
  1156. // - i64/f64 overloads for rawBufferLoad/Store
  1157. // 1.4 adds:
  1158. // - packed u8x4/i8x4 dot with accumulate to i32
  1159. // - half dot2 with accumulate to float
  1160. // 1.5 adds:
  1161. // - WaveMatch, WaveMultiPrefixOp, WaveMultiPrefixBitCount
  1162. // - HASH container part support
  1163. // - Mesh and Amplification shaders
  1164. // - DXR 1.1 & RayQuery support
  1165. *pMajor = 1;
  1166. *pMinor = %d;
  1167. """ % highest_minor
  1168. return result
  1169. def get_target_profiles():
  1170. result = "HelpText<\"Set target profile. \\n"
  1171. result += "\\t<profile>: "
  1172. profiles = getShaderProfiles()
  1173. shader_models = getShaderModels()
  1174. base_sm = "%d_0"%highest_major
  1175. for profile, min_sm in profiles.items():
  1176. for shader_model in shader_models:
  1177. if (base_sm > shader_model):
  1178. continue
  1179. if (min_sm > shader_model):
  1180. continue
  1181. result += "%s_%s, "%(profile,shader_model)
  1182. result += "\\n\\t\\t "
  1183. result += "\">;"
  1184. return result
  1185. def get_min_validator_version():
  1186. result = ""
  1187. for i in range(0, highest_minor+1):
  1188. result += "case %d:\n"%i
  1189. result += " ValMinor = %d;\n"%i
  1190. result += " break;\n"
  1191. return result
  1192. def get_dxil_version():
  1193. result = ""
  1194. for i in range(0, highest_minor+1):
  1195. result += "case %d:\n"%i
  1196. result += " DxilMinor = %d;\n"%i
  1197. result += " break;\n"
  1198. result += "case kOfflineMinor: // Always update this to highest dxil version\n"
  1199. result += " DxilMinor = %d;\n"%highest_minor
  1200. result += " break;\n"
  1201. return result
  1202. def get_shader_model_get():
  1203. # const static std::unordered_map<unsigned, unsigned> hashToIdxMap = {};
  1204. result = build_shader_model_hash_idx_map()
  1205. result += "unsigned hash = (unsigned)Kind << 16 | Major << 8 | Minor;\n"
  1206. result += "auto it = hashToIdxMap.find(hash);\n"
  1207. result += "if (it == hashToIdxMap.end())\n"
  1208. result += " return GetInvalid();\n"
  1209. result += "return &ms_ShaderModels[it->second];"
  1210. return result
  1211. def get_shader_model_by_name():
  1212. result = ""
  1213. for i in range(2, highest_minor+1):
  1214. result += "case '%d':\n"%i
  1215. result += " if (Major == %d) {\n"%highest_major
  1216. result += " Minor = %d;\n"%i
  1217. result += " break;\n"
  1218. result += " }\n"
  1219. result += "else return GetInvalid();\n"
  1220. return result
  1221. def get_is_valid_for_dxil():
  1222. result = ""
  1223. for i in range(0, highest_minor+1):
  1224. result += "case %d:\n"%i
  1225. return result
  1226. def RunCodeTagUpdate(file_path):
  1227. import os
  1228. import CodeTags
  1229. print(" ... updating " + file_path)
  1230. args = [file_path, file_path + ".tmp"]
  1231. result = CodeTags.main(args)
  1232. if result != 0:
  1233. print(" ... error: %d" % result)
  1234. else:
  1235. with open(file_path, 'rt') as f:
  1236. before = f.read()
  1237. with open(file_path + ".tmp", 'rt') as f:
  1238. after = f.read()
  1239. if before == after:
  1240. print(" --- no changes found")
  1241. else:
  1242. print(" +++ changes found, updating file")
  1243. with open(file_path, 'wt') as f:
  1244. f.write(after)
  1245. os.remove(file_path + ".tmp")
  1246. if __name__ == "__main__":
  1247. parser = argparse.ArgumentParser(description="Generate code to handle instructions.")
  1248. parser.add_argument("-gen", choices=["docs-ref", "docs-spec", "inst-header", "enums", "oloads", "valfns"], help="Output type to generate.")
  1249. parser.add_argument("-update-files", action="store_const", const=True)
  1250. args = parser.parse_args()
  1251. db = get_db_dxil() # used by all generators, also handy to have it run validation
  1252. if args.gen == "docs-ref":
  1253. gen = db_docsref_gen(db)
  1254. gen.print_content()
  1255. if args.gen == "docs-spec":
  1256. import os, docutils.core
  1257. assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
  1258. hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
  1259. spec_file = os.path.abspath(os.path.join(hlsl_src_dir, "docs/DXIL.rst"))
  1260. with open(spec_file) as f:
  1261. s = docutils.core.publish_file(f, writer_name="html")
  1262. if args.gen == "inst-header":
  1263. gen = db_instrhelp_gen(db)
  1264. gen.print_content()
  1265. if args.gen == "enums":
  1266. gen = db_enumhelp_gen(db)
  1267. gen.print_content()
  1268. if args.gen == "oloads":
  1269. gen = db_oload_gen(db)
  1270. gen.print_content()
  1271. if args.gen == "valfns":
  1272. gen = db_valfns_gen(db)
  1273. gen.print_content()
  1274. if args.update_files:
  1275. print("Updating files ...")
  1276. import CodeTags
  1277. import os
  1278. assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
  1279. hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
  1280. pj = lambda *parts: os.path.abspath(os.path.join(*parts))
  1281. files = [
  1282. 'docs/DXIL.rst',
  1283. 'lib/DXIL/DXILOperations.cpp',
  1284. 'lib/DXIL/DXILShaderModel.cpp',
  1285. 'include/dxc/DXIL/DXILConstants.h',
  1286. 'include/dxc/DXIL/DXILShaderModel.h',
  1287. 'include/dxc/HLSL/DxilValidation.h',
  1288. 'include/dxc/Support/HLSLOptions.td',
  1289. 'include/dxc/DXIL/DxilInstructions.h',
  1290. 'lib/HLSL/DxcOptimizer.cpp',
  1291. 'lib/DxilPIXPasses/DxilPIXPasses.cpp',
  1292. 'lib/HLSL/DxilValidation.cpp',
  1293. 'tools/clang/lib/Sema/gen_intrin_main_tables_15.h',
  1294. 'include/dxc/HlslIntrinsicOp.h',
  1295. 'tools/clang/tools/dxcompiler/dxcdisassembler.cpp',
  1296. 'include/dxc/DXIL/DxilSigPoint.inl',
  1297. 'include/dxc/DXIL/DxilCounters.h',
  1298. 'lib/DXIL/DxilCounters.cpp',
  1299. 'lib/DXIL/DxilMetadataHelper.cpp',
  1300. ]
  1301. for relative_file_path in files:
  1302. RunCodeTagUpdate(pj(hlsl_src_dir, relative_file_path))