| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428 |
- # Copyright (C) Microsoft Corporation. All rights reserved.
- # This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details.
- import argparse
- import functools
- import collections
- from hctdb import *
# Lazily created database singletons.
g_db_dxil = None


def get_db_dxil():
    """Return the shared DXIL database, creating it on the first call."""
    global g_db_dxil
    if g_db_dxil is not None:
        return g_db_dxil
    g_db_dxil = db_dxil()
    return g_db_dxil
g_db_hlsl = None


def get_db_hlsl():
    """Return the shared HLSL intrinsic database, creating it on the first call.

    The database is built by handing db_hlsl the open text file
    "gen_intrin_main.txt" that sits in the same directory as this script.
    """
    global g_db_hlsl
    # This file has no explicit `import os`; importing it here keeps the
    # function from depending on the wildcard hctdb import re-exporting it.
    import os

    if g_db_hlsl is None:
        thisdir = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(thisdir, "gen_intrin_main.txt"), "r") as f:
            g_db_hlsl = db_hlsl(f)
    return g_db_hlsl
def format_comment(prefix, val, line_width=80):
    """Format a value as line comments, wrapping at word boundaries.

    Args:
        prefix: Text placed at the start of every emitted line (e.g. "// ").
        val: The text to wrap.
        line_width: Target width of a full line, prefix included.

    Returns:
        The wrapped text. Every line starts with prefix and ends with a
        newline. An empty val produces an empty string.
    """
    content_width = line_width - len(prefix)
    lines = []
    while val:
        if len(val) < content_width:
            lines.append(prefix + val.strip())
            break
        split_idx = val.rfind(" ", 0, content_width)
        if split_idx == -1:
            # No space inside the allowed width: a single word is too long.
            # Slicing at -1 would never shorten val (an endless loop), so
            # keep the word whole and break at the first space after it.
            split_idx = val.find(" ", max(content_width, 0))
            if split_idx == -1:
                lines.append(prefix + val.strip())
                break
        lines.append(prefix + val[:split_idx].strip())
        val = val[split_idx + 1:]
    return "".join(line + "\n" for line in lines)
def format_rst_table(list_of_tuples):
    """Produce a reStructuredText simple table from an iterable of rows.

    The first row is treated as the header: a banner of '=' characters is
    written above it, below it, and once more after the last row. Each cell
    is converted with str() and padded to its column's widest entry.

    Args:
        list_of_tuples: Iterable of rows; each row is a sequence of cells.
            The column count is taken from the first row.

    Returns:
        The table text, or an empty string when there are no rows.
    """
    # Materialize once so one-shot iterables survive the two passes below.
    rows = list(list_of_tuples)
    if not rows:
        return ""

    # Calculate widths.
    widths = [0] * len(rows[0])
    for row in rows:
        for col, cell in enumerate(row):
            widths[col] = max(widths[col], len(str(cell)))

    # Build banner line.
    banner = " ".join("=" * w for w in widths) + "\n"

    # Build the result.
    result = banner
    for row_idx, row in enumerate(rows):
        result += " ".join(str(cell).ljust(widths[col]) for col, cell in enumerate(row))
        # Drop the padding that follows the last cell of the row.
        result = result.rstrip()
        result += "\n"
        if row_idx == 0:
            result += banner
    result += banner
    return result
def build_range_tuples(i):
    """Yield (low, high) tuples for each run of contiguous values in the input.

    The input is sorted first, so it may be given in any order. Repeated
    values are folded into the run they belong to instead of starting a new,
    overlapping range.

    Args:
        i: Iterable of integers.

    Yields:
        Inclusive (low_bound, high_bound) tuples in ascending order.
    """
    low_bound = None
    high_bound = None
    for val in sorted(i):
        if low_bound is None:
            low_bound = high_bound = val
        elif val == high_bound:
            # Duplicate of the current upper bound; already covered.
            continue
        elif val == high_bound + 1:
            high_bound = val
        else:
            yield (low_bound, high_bound)
            low_bound = high_bound = val
    if low_bound is not None:
        yield (low_bound, high_bound)
def build_range_code(var, i):
    """Build a C-style boolean expression testing var against the values in i.

    Contiguous values become "(low <= var && var <= high)", isolated values
    become "var == value", and the pieces are combined with " || ".
    """
    def condition(low, high):
        if low == high:
            return var + " == " + str(low)
        return "(%d <= %s && %s <= %d)" % (low, var, var, high)

    return " || ".join(condition(r[0], r[1]) for r in build_range_tuples(i))
class db_docsref_gen:
    "A generator of reference documentation."

    def __init__(self, db):
        """Keep the database and pre-sort what gets documented.

        DXIL-op instructions and validation rules are both ordered by
        "<category>.<name>", with a missing category treated as empty.
        """
        self.db = db
        self.instrs = sorted(
            (i for i in self.db.instr if i.is_dxil_op),
            key=self._category_name_key)
        self.val_rules = sorted(db.val_rules, key=self._category_name_key)

    @staticmethod
    def _category_name_key(v):
        "Sort key grouping entries by category, then by name."
        return ("" if v.category is None else v.category) + "." + v.name

    def print_content(self):
        "Print the complete HTML document to standard output."
        self.print_header()
        self.print_body()
        self.print_footer()

    def print_header(self):
        "Print the document preamble and both tables of contents."
        print("<!DOCTYPE html>")
        print("<html><head><title>DXIL Reference</title>")
        print("<style>body { font-family: Verdana; font-size: small; }</style>")
        print("</head><body><h1>DXIL Reference</h1>")
        self.print_toc("Instructions", "i", self.instrs)
        self.print_toc("Rules", "r", self.val_rules)

    def print_body(self):
        "Print the detail sections for instructions and validation rules."
        self.print_instruction_details()
        self.print_valrule_details()

    def print_instruction_details(self):
        "Print one section per instruction: opcode, remarks and operands."
        print("<h2>Instruction Details</h2>")
        for i in self.instrs:
            print("<h3><a name='i%s'>%s</a></h3>" % (i.name, i.name))
            print("<div>Opcode: %d. This instruction %s.</div>" % (i.dxil_opid, i.doc))
            if i.remarks:
                # This is likely a .rst fragment, but this will do for now.
                print("<div> " + i.remarks + "</div>")
            print("<div>Operands:</div>")
            print("<ul>")
            for o in i.ops:
                if o.pos == 0:
                    print("<li>result: %s - %s</li>" % (o.llvm_type, o.doc))
                else:
                    # The enum index lives on the database held by this
                    # generator; there is no module-level `db` to rely on.
                    enum_desc = "" if o.enum_name == "" else " one of %s: %s" % (
                        o.enum_name, ",".join(self.db.enum_idx[o.enum_name].value_names()))
                    print("<li>%d - %s: %s%s%s</li>" % (
                        o.pos - 1, o.name, o.llvm_type,
                        "" if o.doc == "" else " - " + o.doc, enum_desc))
            print("</ul>")
            print("<div><a href='#Instructions'>(top)</a></div>")

    def print_valrule_details(self):
        "Print one section per validation rule."
        print("<h2>Rule Details</h2>")
        for i in self.val_rules:
            print("<h3><a name='r%s'>%s</a></h3>" % (i.name, i.name))
            print("<div>" + i.doc + "</div>")
            print("<div><a href='#Rules'>(top)</a></div>")

    def print_toc(self, name, aprefix, values):
        """Print a table of contents grouped by category.

        name is both the heading text and its anchor; aprefix is prepended to
        each entry name to form the link target. values must already be
        ordered so that entries of one category are adjacent.
        """
        print("<h2><a name='" + name + "'>" + name + "</a></h2>")
        # Track the open list explicitly so every <ul> gets exactly one
        # </ul>, including for the first category and for empty input.
        list_open = False
        last_category = None
        for i in values:
            if not list_open or i.category != last_category:
                if list_open:
                    print("</ul>")
                print("<div><b>%s</b></div><ul>" % i.category)
                last_category = i.category
                list_open = True
            print("<li><a href='#" + aprefix + "%s'>%s</a></li>" % (i.name, i.name))
        if list_open:
            print("</ul>")

    def print_footer(self):
        "Print the closing tags of the document."
        print("</body></html>")
class db_instrhelp_gen:
    """Generator of C++ instruction helper structs, one per database instruction."""

    def __init__(self, db):
        """Keep the database and set up the LLVM-to-C++ integer type table."""
        self.db = db
        TypeInfo = collections.namedtuple("TypeInfo", "name bits")
        known_types = (
            ("i1", "bool", 1),
            ("i8", "int8_t", 8),
            ("u8", "uint8_t", 8),
            ("i32", "int32_t", 32),
            ("u32", "uint32_t", 32),
        )
        self.llvm_type_map = {
            llvm_name: TypeInfo(c_name, bits) for llvm_name, c_name, bits in known_types
        }
        self.IsDxilOpFuncCallInst = "hlsl::OP::IsDxilOpFuncCallInst"

    def print_content(self):
        """Print the header, every helper struct, and the footer."""
        self.print_header()
        self.print_body()
        self.print_footer()

    def print_header(self):
        """Print the banner comment, the TODO notes and the namespace opener."""
        for text in (
            "///////////////////////////////////////////////////////////////////////////////",
            "// //",
            "// Copyright (C) Microsoft Corporation. All rights reserved. //",
            "// DxilInstructions.h //",
            "// //",
            "// This file provides a library of instruction helper classes. //",
            "// //",
            "// MUCH WORK YET TO BE DONE - EXPECT THIS WILL CHANGE - GENERATED FILE //",
            "// //",
            "///////////////////////////////////////////////////////////////////////////////",
            "",
            "// TODO: add correct include directives",
            "// TODO: add accessors with values",
            "// TODO: add validation support code, including calling into right fn",
            "// TODO: add type hierarchy",
            "namespace hlsl {",
        ):
            print(text)

    def bool_lit(self, val):
        """Return the C++ boolean literal matching the truthiness of val."""
        if val:
            return "true"
        return "false"

    def op_type(self, o):
        """Return the C++ type name for operand o; ValueError if its type is unknown."""
        if o.llvm_type not in self.llvm_type_map:
            raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
        return self.llvm_type_map[o.llvm_type].name

    def op_size(self, o):
        """Return the bit width for operand o; ValueError if its type is unknown."""
        if o.llvm_type not in self.llvm_type_map:
            raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
        return self.llvm_type_map[o.llvm_type].bits

    def op_const_expr(self, o):
        """Return a C++ expression reading operand o as a constant integer."""
        cpp_type = self.op_type(o)
        return "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())" % (cpp_type, o.pos - 1)

    def op_set_const_expr(self, o):
        """Return a C++ expression building a constant of operand o's width from 'val'."""
        bits = self.op_size(o)
        return "llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), %d), llvm::APInt(%d, (uint64_t)val))" % (bits, bits)

    def print_body(self):
        """Print one helper struct for every non-reserved instruction."""
        for instr in self.db.instr:
            if instr.is_reserved:
                continue

            if instr.inst_helper_prefix:
                struct_name = "%s_%s" % (instr.inst_helper_prefix, instr.name)
            else:
                struct_name = ("DxilInst_%s" if instr.is_dxil_op else "LlvmInst_%s") % instr.name

            if instr.doc:
                print("/// This instruction %s" % instr.doc)
            print("struct %s {" % struct_name)
            print(" llvm::Instruction *Instr;")
            print(" // Construction and identification")
            print(" %s(llvm::Instruction *pInstr) : Instr(pInstr) {}" % struct_name)
            print(" operator bool() const {")
            if instr.is_dxil_op:
                print(" return %s(Instr, %s);" % (self.IsDxilOpFuncCallInst, instr.fully_qualified_name()))
            else:
                print(" return Instr->getOpcode() == llvm::Instruction::%s;" % instr.name)
            print(" }")
            print(" // Validation support")
            print(" bool isAllowed() const { return %s; }" % self.bool_lit(instr.is_allowed))

            if instr.is_dxil_op:
                print(" bool isArgumentListValid() const {")
                print(" if (%d != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;" % (len(instr.ops) - 1))
                print(" return true;")
                # TODO - check operand types
                print(" }")
                print(" // Metadata")
                print(" bool requiresUniformInputs() const { return %s; }" % self.bool_lit(instr.requires_uniform_inputs))

                # Position 0 is the return type and position 1 is the DXIL OP id,
                # so only later positions get an index and accessors.
                operands = [o for o in instr.ops if o.pos > 1]
                if operands:
                    print(" // Operand indexes")
                    print(" enum OperandIdx {")
                    for o in operands:
                        print(" arg_%s = %d," % (o.name, o.pos - 1))
                    print(" };")
                    print(" // Accessors")
                for o in operands:
                    slot = o.pos - 1
                    print(" llvm::Value *get_%s() const { return Instr->getOperand(%d); }" % (o.name, slot))
                    print(" void set_%s(llvm::Value *val) { Instr->setOperand(%d, val); }" % (o.name, slot))
                    if o.is_const:
                        print(" %s get_%s_val() const { return %s; }" % (self.op_type(o), o.name, self.op_const_expr(o)))
                        print(" void set_%s_val(%s val) { Instr->setOperand(%d, %s); }" % (o.name, self.op_type(o), slot, self.op_set_const_expr(o)))
            print("};")
            print("")

    def print_footer(self):
        """Print the namespace closer."""
        print("} // namespace hlsl")
class db_enumhelp_gen:
    """Generator of C++ 'enum class' declarations for database enumerations."""

    def __init__(self, db):
        """Keep the database and the table of enums that get a trailing count."""
        self.db = db
        # Some enums should get a last enum marker.
        self.lastEnumNames = dict(OpCode="NumOpCodes", OpCodeClass="NumOpClasses")

    def print_enum(self, e, **kwargs):
        """Print the declaration of enumeration e.

        Keyword options:
            hide_val: when true, omit the explicit "= value" of each enumerant.
            sort_val: when true (the default), order enumerants by
                "<category>.<name>"; otherwise keep the order of e.values.
        """
        def category_name_key(v):
            return ("" if v.category is None else v.category) + "." + v.name

        print("// %s" % e.doc)
        print("enum class %s : unsigned {" % e.name)
        hide_val = kwargs.get("hide_val", False)
        if kwargs.get("sort_val", True):
            ordered = sorted(e.values, key=category_name_key)
        else:
            ordered = e.values

        previous_category = None
        for v in ordered:
            if v.category != previous_category:
                # Blank line between category groups, but not before the first.
                if previous_category is not None:
                    print("")
                print(" // %s" % v.category)
                previous_category = v.category
            pieces = [" {name}"]
            if not e.is_internal and not hide_val:
                pieces.append(" = {value}")
            pieces.append(",")
            if v.doc:
                pieces.append(" // {doc}")
            print("".join(pieces).format(name=v.name, value=v.value, doc=v.doc))

        if e.name in self.lastEnumNames:
            lastName = self.lastEnumNames[e.name]
            # One "<last>_Dxil_<major>_<minor>" marker per DXIL version that
            # records a value under this name.
            versioned = []
            for (major, minor), info in sorted(self.db.dxil_version_info.items()):
                if lastName in info:
                    versioned.append("%s_Dxil_%d_%d = %d," % (lastName, major, minor, info[lastName]))
            if versioned:
                print("")
                for entry in versioned:
                    print(" " + entry)
            print("")
            print(" " + lastName + " = " + str(len(ordered)) + " // exclusive last value of enumeration")
        print("};")

    def print_content(self):
        """Print every enumeration in the database, ordered by name."""
        for e in sorted(self.db.enums, key=lambda item: item.name):
            self.print_enum(e)
class db_oload_gen:
    "A generator of overload tables."

    def __init__(self, db):
        "Keep the database and the DXIL-op instructions ordered by opcode id."
        self.db = db
        instrs = [i for i in self.db.instr if i.is_dxil_op]
        self.instrs = sorted(instrs, key=lambda i: i.dxil_opid)

    def print_content(self):
        "Print the property table, a '...' separator, then the OP::GetOpFunc table."
        self.print_opfunc_props()
        print("...")
        self.print_opfunc_table()

    def print_opfunc_props(self):
        "Print the C++ initializer of OP::m_OpCodeProps, one row per opcode."
        print("const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {")
        print("// OpCode OpCode name, OpCodeClass OpCodeClass name, void, h, f, d, i1, i8, i16, i32, i64, udt, obj, function attribute")
        # Example formatted string:
        # { OC::TempRegLoad, "TempRegLoad", OCC::TempRegLoad, "tempRegLoad", false, true, true, false, true, false, true, true, false, Attribute::ReadOnly, },

        # overload types are a string of (v)oid, (h)alf, (f)loat, (d)ouble, (1)-bit, (8)-bit, (w)ord, (i)nt, (l)ong, u(dt)
        def has_oload(instr, code):
            return "true" if code in instr.oload_types else "false"

        # Class names whose lower-camel form is not just "first letter lowered".
        lower_exceptions = {"CBufferLoad": "cbufferLoad", "CBufferLoadLegacy": "cbufferLoadLegacy", "GSInstanceID": "gsInstanceID"}

        def lower_fn(text):
            if text in lower_exceptions:
                return lower_exceptions[text]
            return text[:1].lower() + text[1:]

        attr_dict = {"": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate", "nr": "NoReturn", "wv": "None"}

        last_category = None
        for i in self.instrs:
            if last_category != i.category:
                if last_category is not None:
                    print("")
                print(" // {category:118} void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute".format(category=i.category))
                last_category = i.category
            print(" {{ OC::{name:24} {quotName:27} OCC::{className:25} {classNameQuot:28} {{{v:>6},{h:>6},{f:>6},{d:>6},{b:>6},{e:>6},{w:>6},{i:>6},{l:>6},{u:>6},{o:>6}}}, {attr:20} }},".format(
                name=i.name + ",", quotName='"' + i.name + '",', className=i.dxil_class + ",", classNameQuot='"' + lower_fn(i.dxil_class) + '",',
                v=has_oload(i, "v"), h=has_oload(i, "h"), f=has_oload(i, "f"), d=has_oload(i, "d"),
                b=has_oload(i, "1"), e=has_oload(i, "8"), w=has_oload(i, "w"), i=has_oload(i, "i"),
                l=has_oload(i, "l"), u=has_oload(i, "u"), o=has_oload(i, "o"),
                attr="Attribute::" + attr_dict[i.fn_attr] + ","))
        print("};")

    def print_opfunc_table(self):
        "Print the switch cases for OP::GetOpFunc, one line per opcode."
        # Print the table for OP::GetOpFunc
        op_type_texts = {
            "$cb": "CBRT(pETy);",
            "$o": "A(pETy);",
            "$r": "RRT(pETy);",
            "d": "A(pF64);",
            "dims": "A(pDim);",
            "f": "A(pF32);",
            "h": "A(pF16);",
            "i1": "A(pI1);",
            "i16": "A(pI16);",
            "i32": "A(pI32);",
            "i32c": "A(pI32C);",
            "i64": "A(pI64);",
            "i8": "A(pI8);",
            "$u4": "A(pI4S);",
            "pf32": "A(pPF32);",
            "res": "A(pRes);",
            "splitdouble": "A(pSDT);",
            "twoi32": "A(p2I32);",
            "twof32": "A(p2F32);",
            "twof16": "A(p2F16);",
            "twoi16": "A(p2I16);",
            "threei32": "A(p3I32);",
            "threef32": "A(p3F32);",
            "fouri32": "A(p4I32);",
            "fourf32": "A(p4F32);",
            "u32": "A(pI32);",
            "u64": "A(pI64);",
            "u8": "A(pI8);",
            "v": "A(pV);",
            "w": "A(pWav);",
            "SamplePos": "A(pPos);",
            "udt": "A(udt);",
            "obj": "A(obj);",
            "resproperty": "A(resProperty);",
        }
        last_category = None
        for i in self.instrs:
            if last_category != i.category:
                if last_category is not None:
                    print("")
                print(" // %s" % i.category)
                last_category = i.category
            line = " case OpCode::{name:24}".format(name=i.name + ":")
            for index, o in enumerate(i.ops):
                assert o.llvm_type in op_type_texts, "llvm type %s in instruction %s is unknown" % (o.llvm_type, i.name)
                op_type_text = op_type_texts[o.llvm_type]
                # The first entry (the return type) gets a wider column.
                if index == 0:
                    line = line + "{val:13}".format(val=op_type_text)
                else:
                    line = line + "{val:9}".format(val=op_type_text)
            line = line + "break;"
            print(line)

    def print_opfunc_oload_type(self):
        """Print the switch cases for OP::GetOverloadType.

        Opcodes are sorted into three groups: those whose overload type is
        read from a parameter (grouped by parameter index), those with one
        fixed overload type (grouped by the expression that produces it), and
        those returning a resource or cbuffer struct. Opcodes whose return
        type is the overload type itself are skipped.
        """
        # Print the function for OP::GetOverloadType
        elt_ty = "$o"
        res_ret_ty = "$r"
        cb_ret_ty = "$cb"
        udt_ty = "udt"
        obj_ty = "obj"

        # Overload type code -> C++ expression producing the llvm::Type.
        # Every entry needs the context, Type::getHalfTy included.
        type_code_texts = {
            "d": "Type::getDoubleTy(m_Ctx)",
            "f": "Type::getFloatTy(m_Ctx)",
            "h": "Type::getHalfTy(m_Ctx)",
            "1": "IntegerType::get(m_Ctx, 1)",
            "8": "IntegerType::get(m_Ctx, 8)",
            "w": "IntegerType::get(m_Ctx, 16)",
            "i": "IntegerType::get(m_Ctx, 32)",
            "l": "IntegerType::get(m_Ctx, 64)",
            "v": "Type::getVoidTy(m_Ctx)",
            "u": "Type::getInt32PtrTy(m_Ctx)",
            "o": "Type::getInt32PtrTy(m_Ctx)",
        }

        index_dict = collections.OrderedDict()
        single_dict = collections.OrderedDict()
        struct_list = []

        for instr in self.instrs:
            ret_ty = instr.ops[0].llvm_type
            # Skip case return type is overload type
            if ret_ty == elt_ty:
                continue

            if ret_ty == res_ret_ty or ret_ty == cb_ret_ty:
                struct_list.append(instr.name)
                continue

            in_param_ty = False
            # Try to find elt_ty in parameter types.
            for index, op in enumerate(instr.ops):
                # Skip return type (pos 0) and dxil opcode (pos 1).
                if op.pos == 0 or op.pos == 1:
                    continue

                op_type = op.llvm_type
                if op_type == elt_ty:
                    # index - 1: parameter index with the return op skipped.
                    index_dict.setdefault(index - 1, []).append(instr.name)
                    in_param_ty = True
                    break
                if op_type == udt_ty or op_type == obj_ty:
                    # Same index adjustment, but keep scanning later operands.
                    index_dict.setdefault(index - 1, []).append(instr.name)
                    in_param_ty = True

            if in_param_ty:
                continue

            # No overload, just return the single oload_type.
            assert len(instr.oload_types) == 1, "overload no elt_ty %s" % (instr.name)
            ty = instr.oload_types[0]
            assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
            single_dict.setdefault(type_code_texts[ty], []).append(instr.name)

        for index, opcodes in index_dict.items():
            line = ""
            for opcode in opcodes:
                line = line + "case OpCode::{name}".format(name=opcode + ":\n")
            line = line + " DXASSERT_NOMSG(FT->getNumParams() > " + str(index) + ");\n"
            line = line + " return FT->getParamType(" + str(index) + ");"
            print(line)

        for code, opcodes in single_dict.items():
            line = ""
            for opcode in opcodes:
                line = line + "case OpCode::{name}".format(name=opcode + ":\n")
            line = line + " return " + code + ";"
            print(line)

        line = ""
        for opcode in struct_list:
            line = line + "case OpCode::{name}".format(name=opcode + ":\n")
        line = line + "{\n"
        line = line + " StructType *ST = cast<StructType>(Ty);\n"
        line = line + " return ST->getElementType(0);\n"
        line = line + "}"
        print(line)
class db_valfns_gen:
    """Generator of C++ instruction validation functions."""

    def __init__(self, db):
        """Keep the database the functions are generated from."""
        self.db = db

    def print_content(self):
        """Print the banner followed by the validation functions."""
        self.print_header()
        self.print_body()

    def print_header(self):
        """Print the banner comment."""
        for text in (
            "///////////////////////////////////////////////////////////////////////////////",
            "// Instruction validation functions. //",
        ):
            print(text)

    def bool_lit(self, val):
        """Return the C++ boolean literal matching the truthiness of val."""
        if val:
            return "true"
        return "false"

    def op_type(self, o):
        """Return the C++ type for an i8/u8 operand; ValueError for any other type."""
        llvm_type = o.llvm_type
        if llvm_type == "i8":
            return "int8_t"
        if llvm_type == "u8":
            return "uint8_t"
        raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))

    def op_const_expr(self, o):
        """Return a C++ expression reading an i8/u8 operand as a constant integer."""
        if not (o.llvm_type == "i8" or o.llvm_type == "u8"):
            raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
        return "(%s)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)

    def print_body(self):
        """Print IsLLVMInstructionAllowed, which accepts the allowed non-DXIL-op opcodes."""
        allowed = [i for i in self.db.instr if i.is_allowed and not i.is_dxil_op]
        print("static bool IsLLVMInstructionAllowed(llvm::Instruction &I) {")
        summary = ", ".join([i.name + "=" + str(i.llvm_id) for i in allowed])
        self.print_comment(" // ", "Allow: %s" % summary)
        print(" unsigned op = I.getOpcode();")
        print(" return %s;" % build_range_code("op", [i.llvm_id for i in allowed]))
        print("}")
        print("")

    def print_comment(self, prefix, val):
        """Print val as wrapped line comments, each starting with prefix."""
        print(format_comment(prefix, val))
class macro_table_gen:
    """Generator for tables expressed as C preprocessor row macros."""

    def format_row(self, row, widths, sep=', '):
        """Join the cells of row, padding all but the last to the column widths."""
        cells = []
        for item, width in list(zip(row, widths))[:-1]:
            padding = ' ' * (width - len(item))
            cells.append(str(item) + sep + padding)
        cells.append(str(row[-1]))
        return ''.join(cells)

    def format_table(self, table, *args, **kwargs):
        """Return every row of table formatted with shared column widths.

        Extra arguments are forwarded to format_row. Each column is at least
        one character wide.
        """
        widths = []
        for col in range(len(table[0])):
            widest = 1
            for row in table:
                widest = max(widest, len(row[col]))
            widths.append(widest)
        return [self.format_row(row, widths, *args, **kwargs) for row in table]

    def print_table(self, table, macro_name):
        """Print table as a macro: the first row as a comment, the rest as ROW(...) lines."""
        header, *body = self.format_table(table)
        row_lines = [' ROW(%s)' % frow for frow in body]
        print('// %s\n' % header +
              '#define %s(ROW) \\\n' % macro_name +
              ' \\\n'.join(row_lines))
class db_sigpoint_gen(macro_table_gen):
    """Generator for the SigPoint and interpretation macro tables."""

    def __init__(self, db):
        """Keep the database that holds the two tables."""
        self.db = db

    def print_sigpoint_table(self):
        """Print the database's sigpoint_table as the DO_SIGPOINTS macro."""
        table = self.db.sigpoint_table
        self.print_table(table, 'DO_SIGPOINTS')

    def print_interpretation_table(self):
        """Print the database's interpretation_table as the DO_INTERPRETATION_TABLE macro."""
        table = self.db.interpretation_table
        self.print_table(table, 'DO_INTERPRETATION_TABLE')

    def print_content(self):
        """Print the SigPoint table followed by the interpretation table."""
        self.print_sigpoint_table()
        self.print_interpretation_table()
class string_output:
    """Minimal in-memory text sink that can stand in for sys.stdout.

    Everything written is converted with str() and appended to `val`;
    str(instance) returns the accumulated text.
    """

    def __init__(self):
        # Text accumulated so far.
        self.val = ""

    def write(self, text):
        """Append str(text) to the accumulated output."""
        self.val = self.val + str(text)

    def flush(self):
        """Do nothing; present so flush() calls on a stdout stand-in work.

        Without it, print(..., flush=True) or sys.stdout.flush() issued while
        this object replaces sys.stdout raises AttributeError.
        """

    def __str__(self):
        return self.val
def run_with_stdout(fn):
    """Call fn() and return everything it wrote to standard output.

    sys.stdout is redirected to an in-memory buffer for the duration of the
    call and restored afterwards, also when fn raises (the exception then
    propagates to the caller).

    Args:
        fn: Callable taking no arguments.

    Returns:
        The captured text as a string.
    """
    # Imported locally so this function brings its own dependencies.
    import contextlib
    import io

    # A real text buffer supports the whole file API (flush() included) and
    # avoids rebuilding one ever-growing string on every write.
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        fn()
    return captured.getvalue()
def get_hlsl_intrinsic_stats():
    """Build C++ constants describing the size of the HLSL intrinsic tables.

    Emits one g_u<namespace>Count per namespace, then the longest intrinsic
    name, the longest parameter name (the first parameter of each intrinsic
    is not considered) and the largest parameter count.
    """
    db = get_db_hlsl()
    first = db.intrinsics[0]
    # The unsorted first entry is placed in front so that, as with a running
    # "strictly greater" comparison, it wins any tie for the maximum.
    candidates = [first] + sorted(db.intrinsics, key=lambda x: x.key)

    longest_fn = max(candidates, key=lambda fn: len(fn.name))
    longest_param = max(
        (p for fn in candidates[1:] for p_idx, p in enumerate(fn.params) if p_idx > 0),
        key=lambda p: len(p.name),
        default=None)
    longest_arglist_fn = max(candidates, key=lambda fn: len(fn.params))

    pieces = []
    for ns_name in sorted(db.namespaces.keys()):
        namespace = db.namespaces[ns_name]
        pieces.append("static const UINT g_u%sCount = %d;\n" % (ns_name, len(namespace.intrinsics)))
    pieces.append("\n")
    pieces.append("static const int g_MaxIntrinsicName = %d; // Count of characters for longest intrinsic name - '%s'\n" % (len(longest_fn.name), longest_fn.name))
    pieces.append("static const int g_MaxIntrinsicParamName = %d; // Count of characters for longest intrinsic parameter name - '%s'\n" % (len(longest_param.name), longest_param.name))
    pieces.append("static const int g_MaxIntrinsicParamCount = %d; // Count of parameters (without return) for longest intrinsic argument list - '%s'\n" % (len(longest_arglist_fn.params) - 1, longest_arglist_fn.name))
    return "".join(pieces)
def get_hlsl_intrinsics():
    """Build the C++ source of the intrinsic tables, one table per namespace.

    Intrinsics are visited in key order. Each one contributes an argument
    array to the output right away and a row to the table of its namespace;
    a namespace's table is written out when the next namespace starts, and
    the last one after the loop. Tables of Vulkan-specific namespaces are
    wrapped in ENABLE_SPIRV_CODEGEN guards.
    """
    db = get_db_hlsl()
    opcode_namespace = db.opcode_namespace
    output = ""
    current_ns = ""
    table_text = ""
    in_vk_table = False  # SPIRV Change
    id_prefix = ""
    arg_idx = 0
    for intrinsic in sorted(db.intrinsics, key=lambda x: x.key):
        if intrinsic.ns != current_ns:
            current_ns = intrinsic.ns
            id_prefix = "IOP" if current_ns == "Intrinsics" else "MOP"
            if table_text:
                # Close the table of the namespace that just ended.
                output += table_text + "};\n"
                # SPIRV Change Starts
                if in_vk_table:
                    output += "\n#endif // ENABLE_SPIRV_CODEGEN\n"
                    in_vk_table = False
                # SPIRV Change Ends
            output += "\n//\n// Start of %s\n//\n\n" % (current_ns)
            # This used to be qualified as __declspec(selectany), but that's no longer necessary.
            table_text = "static const HLSL_INTRINSIC g_%s[] =\n{\n" % (current_ns)
            # SPIRV Change Starts
            if intrinsic.vulkanSpecific:
                in_vk_table = True
                output += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
            # SPIRV Change Ends
            arg_idx = 0

        table_text += " {(UINT)%s::%s_%s, %s, %s, %s, %d, %d, g_%s_Args%s},\n" % (
            opcode_namespace, id_prefix, intrinsic.name,
            str(intrinsic.readonly).lower(), str(intrinsic.readnone).lower(), str(intrinsic.wave).lower(),
            intrinsic.overload_param_index, len(intrinsic.params), current_ns, arg_idx)

        output += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (current_ns, arg_idx)
        for param in intrinsic.params:
            param_name = param.name
            if param_name == intrinsic.name and intrinsic.hidden:
                # First parameter defines intrinsic name for parsing in HLSL.
                # Prepend '$hidden$' for hidden intrinsic so it can't be used in HLSL.
                param_name = "$hidden$" + param_name
            output += " {\"%s\", %s, %s, %s, %s, %s, %s, %s},\n" % (
                param_name, param.param_qual, param.template_id, param.template_list,
                param.component_id, param.component_list, param.rows, param.cols)
        output += "};\n\n"
        arg_idx += 1

    output += table_text + "};\n"
    if in_vk_table:  # SPIRV Change
        output += "\n#endif // ENABLE_SPIRV_CODEGEN\n"
    return output
# SPIRV Change Starts
def wrap_with_ifdef_if_vulkan_specific(intrinsic, text):
    """Return text, wrapped in ENABLE_SPIRV_CODEGEN guards when intrinsic is Vulkan-specific."""
    if not intrinsic.vulkanSpecific:
        return text
    guarded = "#ifdef ENABLE_SPIRV_CODEGEN\n" + text
    return guarded + "#endif // ENABLE_SPIRV_CODEGEN\n"
# SPIRV Change Ends
def enum_hlsl_intrinsics():
    """Build the enumerant list of the intrinsic opcode enumeration.

    Each distinct enum name appears once, in key order (guarded when the
    intrinsic is Vulkan-specific), followed by the distinct unsigned
    opcodes not already listed and a final Num_Intrinsics entry.
    """
    db = get_db_hlsl()
    ordered = sorted(db.intrinsics, key=lambda x: x.key)
    seen = set()
    pieces = []
    for intrinsic in ordered:
        if intrinsic.enum_name in seen:
            continue
        enumerant = " %s,\n" % (intrinsic.enum_name)
        pieces.append(wrap_with_ifdef_if_vulkan_specific(intrinsic, enumerant))  # SPIRV Change
        seen.add(intrinsic.enum_name)
    # unsigned
    pieces.append(" // unsigned\n")
    for intrinsic in sorted(db.intrinsics, key=lambda x: x.key):
        if intrinsic.unsigned_op != "" and intrinsic.unsigned_op not in seen:
            pieces.append(" %s,\n" % (intrinsic.unsigned_op))
            seen.add(intrinsic.unsigned_op)
    pieces.append(" Num_Intrinsics,\n")
    return "".join(pieces)
def has_unsigned_hlsl_intrinsics():
    """Build the 'case' labels for every intrinsic opcode that has an unsigned variant."""
    db = get_db_hlsl()
    seen = set()
    labels = []
    # unsigned
    for intrinsic in sorted(db.intrinsics, key=lambda x: x.key):
        if intrinsic.unsigned_op == "" or intrinsic.enum_name in seen:
            continue
        labels.append(" case IntrinsicOp::%s:\n" % (intrinsic.enum_name))
        seen.add(intrinsic.enum_name)
    return "".join(labels)
def get_unsigned_hlsl_intrinsics():
    """Build the 'case' bodies mapping each intrinsic opcode to its unsigned variant."""
    db = get_db_hlsl()
    seen = set()
    cases = []
    # unsigned
    for intrinsic in sorted(db.intrinsics, key=lambda x: x.key):
        if intrinsic.unsigned_op == "" or intrinsic.enum_name in seen:
            continue
        seen.add(intrinsic.enum_name)
        cases.append(" case IntrinsicOp::%s:\n" % (intrinsic.enum_name))
        cases.append(" return static_cast<unsigned>(IntrinsicOp::%s);\n" % (intrinsic.unsigned_op))
    return "".join(cases)
def get_oloads_props():
    """Return the text printed by db_oload_gen.print_opfunc_props for the DXIL database."""
    generator = db_oload_gen(get_db_dxil())
    return run_with_stdout(generator.print_opfunc_props)
def get_oloads_funcs():
    """Return the text printed by db_oload_gen.print_opfunc_table for the DXIL database."""
    generator = db_oload_gen(get_db_dxil())
    return run_with_stdout(generator.print_opfunc_table)
def get_funcs_oload_type():
    """Return the text printed by db_oload_gen.print_opfunc_oload_type for the DXIL database."""
    generator = db_oload_gen(get_db_dxil())
    return run_with_stdout(generator.print_opfunc_oload_type)
def get_enum_decl(name, **kwargs):
    """Return the C++ declaration of the DXIL database enumeration called name.

    Keyword arguments are forwarded to db_enumhelp_gen.print_enum.
    """
    db = get_db_dxil()
    generator = db_enumhelp_gen(db)

    def emit():
        generator.print_enum(db.enum_idx[name], **kwargs)

    return run_with_stdout(emit)
def get_valrule_enum():
    """Return the declaration of the ValidationRule enumeration without explicit values."""
    options = {"hide_val": True}
    return get_enum_decl("ValidationRule", **options)
def get_valrule_text():
    """Build a C++ switch that returns the error message of each validation rule."""
    db = get_db_dxil()
    cases = [
        " case hlsl::ValidationRule::" + rule.name + ": return \"" + rule.err_msg + "\";\n"
        for rule in db.enum_idx["ValidationRule"].values
    ]
    return "switch(value) {\n" + "".join(cases) + "}\n"
def get_instrhelper():
    """Return the text printed by db_instrhelp_gen.print_body for the DXIL database."""
    generator = db_instrhelp_gen(get_db_dxil())
    return run_with_stdout(generator.print_body)
def get_instrs_pred(varname, pred, attr_name="dxil_opid"):
    """Build a C++ return statement testing varname against matching instructions.

    Args:
        varname: Name of the C++ variable to test.
        pred: Either the name of an instruction attribute that is read for
            its truth value, or a callable taking an instruction.
        attr_name: Instruction attribute that supplies the numeric values
            the variable is compared with.

    Returns:
        A wrapped "// Instructions: name=value, ..." comment followed by a
        "return <range test>;" line.
    """
    db = get_db_dxil()
    # isinstance also accepts str subclasses, which an exact type comparison
    # would have wrongly treated as callables.
    if isinstance(pred, str):
        pred_fn = lambda i: getattr(i, pred)
    else:
        pred_fn = pred
    llvm_instrs = [i for i in db.instr if pred_fn(i)]
    summary = ", ".join([i.name + "=" + str(getattr(i, attr_name)) for i in llvm_instrs])
    result = format_comment("// ", "Instructions: %s" % summary)
    result += "return %s;" % build_range_code(varname, [getattr(i, attr_name) for i in llvm_instrs])
    result += "\n"
    return result
def get_instrs_rst():
    """Create an rst table of allowed LLVM instructions.

    The table lists every allowed non-DXIL-op instruction ordered by LLVM
    id, and is followed by one titled section per instruction with remarks.
    """
    db = get_db_dxil()
    instrs = sorted(
        (i for i in db.instr if i.is_allowed and not i.is_dxil_op),
        key=lambda v: v.llvm_id)
    rows = [["Instruction", "Action", "Operand overloads"]]
    rows.extend([i.name, i.doc, i.oload_types] for i in instrs)
    sections = ["\n\n" + format_rst_table(rows) + "\n\n"]
    # Add detailed instruction information where available.
    for i in instrs:
        if i.remarks:
            sections.append(i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n")
    sections.append("\n")
    return "".join(sections)
def get_init_passes(category_libs):
    """Create a series of statements to initialize passes in a registry.

    Only passes whose category_lib is contained in `category_libs` are
    emitted, ordered by type_name.
    """
    ordered = sorted(get_db_dxil().passes, key=lambda entry: entry.type_name)
    statements = [
        "initialize%sPass(Registry);\n" % entry.type_name
        for entry in ordered
        if entry.category_lib in category_libs
    ]
    return "".join(statements)
def get_pass_arg_names():
    """Return an ArrayRef of argument names based on passName.

    For every pass that has arguments (ordered by type_name) this emits a
    static LPCSTR array of the argument names, and then - after all the
    arrays - one strcmp test per pass returning the matching array.
    """
    declarations = []
    lookups = []
    for entry in sorted(get_db_dxil().passes, key=lambda e: e.type_name):
        if len(entry.args) == 0:
            continue
        quoted = ", ".join("\"%s\"" % arg.name for arg in entry.args)
        declarations.append("static const LPCSTR %sArgs[] = { " % entry.type_name + quoted + " };\n")
        lookups.append(
            "if (strcmp(passName, \"%s\") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n"
            % (entry.name, entry.type_name, entry.type_name))
    return "".join(declarations) + "".join(lookups)
def get_pass_arg_descs():
    """Return an ArrayRef of argument descriptions based on passName.

    For every pass that has arguments (ordered by type_name) this emits a
    static LPCSTR array of the argument doc strings, and then - after all the
    arrays - one strcmp test per pass returning the matching array.
    """
    declarations = []
    lookups = []
    for entry in sorted(get_db_dxil().passes, key=lambda e: e.type_name):
        if len(entry.args) == 0:
            continue
        quoted = ", ".join("\"%s\"" % arg.doc for arg in entry.args)
        declarations.append("static const LPCSTR %sArgs[] = { " % entry.type_name + quoted + " };\n")
        lookups.append(
            "if (strcmp(passName, \"%s\") == 0) return ArrayRef<LPCSTR>(%sArgs, _countof(%sArgs));\n"
            % (entry.name, entry.type_name, entry.type_name))
    return "".join(declarations) + "".join(lookups)
def get_is_pass_option_name():
    """Create a return expression to check whether a value 'S' is a pass option name.

    One S.equals("...") test is produced per key of db.pass_idx_args, in
    sorted order, joined with '||'.
    """
    tests = ["S.equals(\"%s\")" % key for key in sorted(get_db_dxil().pass_idx_args)]
    return "return " + "\n || ".join(tests) + ";"
def get_opcodes_rst():
    """Create an rst table of opcodes.

    The table lists allowed DXIL ops ordered by dxil_opid. Afterwards, a
    titled section is appended for every op that has remarks, ordered by
    instruction name.
    """
    db = get_db_dxil()
    ops = sorted(
        (ins for ins in db.instr if ins.is_allowed and ins.is_dxil_op),
        key=lambda ins: ins.dxil_opid)
    table = [["ID", "Name", "Description"]]
    for ins in ops:
        # A trailing underscore enables an internal hyperlink in rst files.
        label = ins.dxil_op + "_" if ins.remarks else ins.dxil_op
        table.append([ins.dxil_opid, label, ins.doc])
    pieces = ["\n\n", format_rst_table(table), "\n\n"]
    # Detailed per-instruction sections, underlined with '~' to the title length.
    for ins in sorted(ops, key=lambda ins: ins.name):
        if ins.remarks:
            pieces.append(ins.name + "\n" + ("~" * len(ins.name)) + "\n\n" + ins.remarks + "\n\n")
    pieces.append("\n")
    return "".join(pieces)
def get_valrules_rst():
    """Create an rst table of validation rules.

    Disabled rules are left out; the rest are listed by name.
    """
    enabled = sorted(
        (rule for rule in get_db_dxil().val_rules if not rule.is_disabled),
        key=lambda rule: rule.name)
    table = [["Rule Code", "Description"]]
    table.extend([rule.name, rule.doc] for rule in enabled)
    return "\n\n" + format_rst_table(table) + "\n\n"
def get_opsigs():
    """Create a C array of DXIL operation signatures, sorted by opcode ID.

    Each entry is a string literal "(a,b,...)" listing the operand names of
    one DXIL op, followed by a comment naming the instruction.
    """
    ops = sorted(
        (ins for ins in get_db_dxil().instr if ins.is_dxil_op),
        key=lambda ins: ins.dxil_opid)
    # db_dxil already asserts that the numbering is dense.
    final_index = len(ops) - 1
    lines = ["static const char *OpCodeSignatures[] = {\n"]
    for index, ins in enumerate(ops):
        final_pos = len(ins.ops) - 1
        signature = ""
        for operand in ins.ops:
            # Positions 0 (the return value) and 1 (the opcode itself) are skipped.
            if operand.pos <= 1:
                continue
            signature += operand.name
            if operand.pos < final_pos:
                signature += ","
        entry = " \"(" + signature + ")\""
        if index < final_index:
            entry += ","
        lines.append(entry + " // " + ins.name + "\n")
    lines.append("};\n")
    return "".join(lines)
# Maps a lower-case shader stage name to the identifier passed to SFLAG(...)
# in the code produced by get_min_sm_and_mask_text.
shader_stage_to_ShaderKind = {
    'vertex': 'Vertex',
    'pixel': 'Pixel',
    'geometry': 'Geometry',
    'compute': 'Compute',
    'hull': 'Hull',
    'domain': 'Domain',
    'library': 'Library',
    'raygeneration': 'RayGeneration',
    'intersection': 'Intersection',
    'anyhit': 'AnyHit',
    'closesthit': 'ClosestHit',
    'miss': 'Miss',
    'callable': 'Callable',
    'mesh': 'Mesh',
    'amplification': 'Amplification',
}
def get_min_sm_and_mask_text():
    """Generate C++ that sets `major`/`minor` and `mask` for groups of opcodes.

    DXIL ops are ordered by (shader_model, shader_model_translated,
    shader_stages, dxil_opid); consecutive ops sharing the first three values
    form one group. Each group becomes an `if (<op range>) { ... return; }`
    block. Groups with shader model (6,0) and no stage restriction produce no
    text, so the generated code falls through for them.
    """
    db = get_db_dxil()
    ops = sorted(
        (ins for ins in db.instr if ins.is_dxil_op),
        key=lambda ins: (ins.shader_model, ins.shader_model_translated,
                         ins.shader_stages, ins.dxil_opid))

    def render_group(group, model, model_translated, stages):
        # Nothing accumulated yet (e.g. before the first instruction).
        if len(group) == 0:
            return ""
        text = format_comment("// ", "Instructions: %s" % ", ".join(
            [ins.name + "=" + str(ins.dxil_opid) for ins in group]))
        text += "if (" + build_range_code("op", [ins.dxil_opid for ins in group]) + ") {\n"
        is_default = True
        if model != (6,0):
            is_default = False
            if model_translated:
                text += " if (bWithTranslation) {\n"
                text += " major = %d; minor = %d;\n } else {\n " % model_translated
            text += " major = %d; minor = %d;\n" % model
            if model_translated:
                text += " }\n"
        if stages:
            is_default = False
            flags = ['SFLAG(%s)' % shader_stage_to_ShaderKind[stage] for stage in stages]
            text += " mask = %s;\n" % ' | '.join(flags)
        if is_default:
            # don't write these out, instead fall through
            return ""
        return text + " return;\n}\n"

    chunks = []
    group = []
    group_key = (None, None, None)
    for ins in ops:
        key = (ins.shader_model, ins.shader_model_translated, ins.shader_stages)
        if key != group_key:
            # Key changed: emit the finished group and start a new one.
            chunks.append(render_group(group, *group_key))
            group = []
            group_key = key
        group.append(ins)
    chunks.append(render_group(group, *group_key))
    return "".join(chunks)
# Maps a lower-case shader stage name to the C++ condition on `SK` that
# get_valopcode_sm_text emits to test for that stage.
check_pSM_for_shader_stage = {
    'vertex': 'SK == DXIL::ShaderKind::Vertex',
    'pixel': 'SK == DXIL::ShaderKind::Pixel',
    'geometry': 'SK == DXIL::ShaderKind::Geometry',
    'compute': 'SK == DXIL::ShaderKind::Compute',
    'hull': 'SK == DXIL::ShaderKind::Hull',
    'domain': 'SK == DXIL::ShaderKind::Domain',
    'library': 'SK == DXIL::ShaderKind::Library',
    'raygeneration': 'SK == DXIL::ShaderKind::RayGeneration',
    'intersection': 'SK == DXIL::ShaderKind::Intersection',
    'anyhit': 'SK == DXIL::ShaderKind::AnyHit',
    'closesthit': 'SK == DXIL::ShaderKind::ClosestHit',
    'miss': 'SK == DXIL::ShaderKind::Miss',
    'callable': 'SK == DXIL::ShaderKind::Callable',
    'mesh': 'SK == DXIL::ShaderKind::Mesh',
    'amplification': 'SK == DXIL::ShaderKind::Amplification',
}
def get_valopcode_sm_text():
    """Generate C++ that validates an opcode against shader model and stage.

    DXIL ops are ordered by (shader_model, shader_stages, dxil_opid);
    consecutive ops sharing the first two values form one group. Each group
    with a shader model other than (6,0) and/or a stage restriction becomes
    `if (<op range>) return <conditions>;`. Unrestricted groups produce no
    text, and the output always ends with `return true;`.
    """
    db = get_db_dxil()
    ops = sorted(
        (ins for ins in db.instr if ins.is_dxil_op),
        key=lambda ins: (ins.shader_model, ins.shader_stages, ins.dxil_opid))

    def render_group(group, model, stages):
        # Nothing accumulated yet (e.g. before the first instruction).
        if len(group) == 0:
            return ""
        text = format_comment("// ", "Instructions: %s" % ", ".join(
            [ins.name + "=" + str(ins.dxil_opid) for ins in group]))
        text += "if (" + build_range_code("op", [ins.dxil_opid for ins in group]) + ")\n"
        text += " return "
        model_cond = None
        stage_cond = None
        if model != (6,0):
            model_cond = "major > %d || (major == %d && minor >= %d)" % (
                model[0], model[0], model[1])
        if stages:
            stage_cond = ' || '.join([check_pSM_for_shader_stage[stage] for stage in stages])
        if not (model_cond or stage_cond):
            # don't write these out, instead fall through
            return ""
        conditions = ["(%s)" % expr for expr in (model_cond, stage_cond) if expr]
        return text + '\n && '.join(conditions) + ";\n"

    chunks = []
    group = []
    group_key = (None, None)
    for ins in ops:
        key = (ins.shader_model, ins.shader_stages)
        if key != group_key:
            # Key changed: emit the finished group and start a new one.
            chunks.append(render_group(group, *group_key))
            group = []
            group_key = key
        group.append(ins)
    chunks.append(render_group(group, *group_key))
    chunks.append("return true;\n")
    return "".join(chunks)
def get_sigpoint_table():
    """Return the run_with_stdout result of db_sigpoint_gen.print_sigpoint_table().

    The generator is built over the database returned by get_db_dxil().
    """
    generator = db_sigpoint_gen(get_db_dxil())
    return run_with_stdout(generator.print_sigpoint_table)
def get_sigpoint_rst():
    """Create an rst table for SigPointKind.

    Uses db.sigpoint_table without its last row. Every row gains a leading ID
    column and a trailing Description column, taken from the SigPointKind enum
    value whose name matches the row's first cell.
    """
    db = get_db_dxil()
    source_rows = db.sigpoint_table[:-1]
    kinds_by_name = dict((value.name, value) for value in db.enum_idx['SigPointKind'].values)
    # Rows are copied so the database table itself is never modified.
    table = [['ID'] + source_rows[0][:] + ['Description']]
    for source_row in source_rows[1:]:
        kind = kinds_by_name[source_row[0]]
        table.append([kind.value] + source_row[:] + [kind.doc])
    return "\n\n" + format_rst_table(table) + "\n\n"
def get_sem_interpretation_enum_rst():
    """Create an rst table of the SemanticInterpretationKind enum values.

    The last enum value is left out of the table.
    """
    kinds = get_db_dxil().enum_idx['SemanticInterpretationKind'].values[:-1]
    table = [['ID', 'Name', 'Description']]
    for kind in kinds:
        table.append([kind.value, kind.name, kind.doc])
    return "\n\n" + format_rst_table(table) + "\n\n"
def get_sem_interpretation_table_rst():
    """Render db.interpretation_table as an rst table padded with blank lines."""
    table_text = format_rst_table(get_db_dxil().interpretation_table)
    return "\n\n" + table_text + "\n\n"
def get_interpretation_table():
    """Return the run_with_stdout result of db_sigpoint_gen.print_interpretation_table().

    The generator is built over the database returned by get_db_dxil().
    """
    generator = db_sigpoint_gen(get_db_dxil())
    return run_with_stdout(generator.print_interpretation_table)
# Newest shader model version (major.minor) used by the generators below.
highest_major = 6
highest_minor = 6
# Highest minor version generated for each major version.
highest_shader_models = {
    4: 1,
    5: 1,
    6: highest_minor,
}
def getShaderModels():
    """Return every "major_minor" version string listed by highest_shader_models.

    For each major version, minors 0 through its highest minor (inclusive)
    are produced, in the iteration order of highest_shader_models.
    """
    names = []
    for major, top_minor in highest_shader_models.items():
        names.extend(str(major) + "_" + str(minor) for minor in range(top_minor + 1))
    return names
def get_highest_shader_model():
    """Return C++ declarations of kHighestMajor and kHighestMinor."""
    template = (
        "static const unsigned kHighestMajor = %d;\n"
        "static const unsigned kHighestMinor = %d;"
    )
    return template % (highest_major, highest_minor)
def get_dxil_version_minor():
    """Return the C++ declaration of kDxilMinor, set to highest_minor."""
    declaration = "const unsigned kDxilMinor = %d;" % (highest_minor,)
    return declaration
def get_is_shader_model_plus():
    """Return one IsSM<major><minor>Plus() accessor per minor version.

    Minors 0 through highest_minor are covered, all for highest_major.
    """
    template = "bool IsSM%d%dPlus() const { return IsSMAtLeast(%d, %d); }\n"
    return "".join(
        template % (highest_major, minor, highest_major, minor)
        for minor in range(highest_minor + 1))
# NOTE(review): the first three values are Kind enum names while the remaining
# ones look like shader model versions, and no use of this table is visible in
# this part of the file - confirm whether it is still needed or should be fixed.
profile_to_kind = {
    "ps": "Kind::Pixel",
    "vs": "Kind::Vertex",
    "gs": "Kind::Geometry",
    "hs": "5_0",
    "ds": "5_0",
    "cs": "4_0",
    "lib": "6_1",
    "ms": "6_5",
    "as": "6_5",
}
class shader_profile(object):
    """Description of one shader profile used to generate shader model tables.

    Attributes mirror the constructor arguments:
      kind        -- integer combined into the hash built by
                     build_shader_model_hash_idx_map
      kind_name   -- short profile prefix such as "ps" or "lib"
      enum_name   -- C++ enumerator text such as "Kind::Pixel"
      start_sm    -- first supported shader model as a "major_minor" string
      input_size  -- value emitted as the input size of the profile's SM(...) rows
      output_size -- value emitted as the output size of the profile's SM(...) rows
    """

    def __init__(self, kind, kind_name, enum_name, start_sm, input_size, output_size):
        self.kind, self.kind_name, self.enum_name = kind, kind_name, enum_name
        self.start_sm = start_sm
        self.input_size, self.output_size = input_size, output_size
# All profiles for which shader model rows are generated.
# The first argument (kind) is the value from DXIL::ShaderKind.
shader_profiles = [
    shader_profile(0, "ps", "Kind::Pixel", "4_0", 32, 8),
    shader_profile(1, "vs", "Kind::Vertex", "4_0", 32, 32),
    shader_profile(2, "gs", "Kind::Geometry", "4_0", 32, 32),
    shader_profile(3, "hs", "Kind::Hull", "5_0", 32, 32),
    shader_profile(4, "ds", "Kind::Domain", "5_0", 32, 32),
    shader_profile(5, "cs", "Kind::Compute", "4_0", 0, 0),
    shader_profile(6, "lib", "Kind::Library", "6_1", 32, 32),
    shader_profile(13, "ms", "Kind::Mesh", "6_5", 0, 0),
    shader_profile(14, "as", "Kind::Amplification", "6_5", 0, 0),
]
def getShaderProfiles():
    """Return a new dict mapping each profile prefix to its first shader model.

    Values are "major_minor" strings; the insertion order is meant to match
    DXIL::ShaderKind.
    """
    return {
        "ps": "4_0",
        "vs": "4_0",
        "gs": "4_0",
        "hs": "5_0",
        "ds": "5_0",
        "cs": "4_0",
        "lib": "6_1",
        "ms": "6_5",
        "as": "6_5",
    }
def get_shader_models():
    """Build the SM(...) initializer rows for every profile and shader model.

    For each entry of shader_profiles, one row is emitted per version listed
    in highest_shader_models that is not older than the profile's start_sm.
    The "lib" profile gets an extra lib_6_x row, and a final Kind::Invalid row
    terminates the table. The row order must stay in sync with
    build_shader_model_hash_idx_map and get_num_shader_models.

    Returns the generated C++ text.
    """
    result = ""
    for profile in shader_profiles:
        kind_name = profile.kind_name
        enum_name = profile.enum_name
        # Compare versions as integer pairs: comparing "major_minor" strings
        # would order "6_10" before "6_5".
        min_version = tuple(int(part) for part in profile.start_sm.split("_"))
        for major, minor in highest_shader_models.items():
            UAV_info = "true, true, UINT_MAX"
            if major == 4:
                UAV_info = "false, false, 0"
                # profile.kind is an integer, so the compute profile has to be
                # recognised by its short name; comparing kind with "cs" never
                # matched and left this branch dead.
                if kind_name == "cs":
                    UAV_info = "true, false, 1"
            elif major == 5:
                UAV_info = "true, true, 64"
            for i in range(0, minor + 1):
                if (major, i) < min_version:
                    continue
                input_size = profile.input_size
                output_size = profile.output_size
                # 4_0 uses smaller signature sizes for gs inputs and vs inputs/outputs.
                if major == 4 and i == 0:
                    if kind_name == "gs":
                        input_size = 16
                    elif kind_name == "vs":
                        input_size = 16
                        output_size = 16
                sm_name = "%s_%d_%d" % (kind_name, major, i)
                result += "SM(%s, %d, %d, \"%s\", %d, %d, %s),\n" % (enum_name, major, i, sm_name, input_size, output_size, UAV_info)
        if kind_name == "lib":
            result += "// lib_6_x is for offline linking only, and relaxes restrictions\n"
            result += "SM(Kind::Library, 6, kOfflineMinor, \"lib_6_x\", 32, 32, true, true, UINT_MAX),\n"
    result += "// Values before Invalid must remain sorted by Kind, then Major, then Minor.\n"
    result += "SM(Kind::Invalid, 0, 0, \"invalid\", 0, 0, false, false, 0),\n"
    return result
def get_num_shader_models():
    """Return the C++ declaration of kNumShaderModels.

    Counts one entry per (profile, version) pair that get_shader_models
    emits, plus one for lib_6_x on the "lib" profile and one for the
    trailing invalid shader model.
    """
    count = 0
    for profile in shader_profiles:
        # Compare versions as integer pairs: comparing "major_minor" strings
        # would order "6_10" before "6_5".
        min_version = tuple(int(part) for part in profile.start_sm.split("_"))
        for major, minor in highest_shader_models.items():
            for i in range(0, minor + 1):
                if (major, i) < min_version:
                    continue
                count += 1
        if profile.kind_name == "lib":
            # for lib_6_x
            count += 1
    # for invalid shader_model.
    count += 1
    return "static const unsigned kNumShaderModels = %d;" % count
def build_shader_model_hash_idx_map():
    """Generate the C++ hashToIdxMap initializer for shader model lookup.

    Keys are `kind << 16 | major << 8 | minor`; values are row indices in
    the table produced by get_shader_models, so the iteration order here must
    match that function exactly.

    Returns the generated C++ text.
    """
    #must match get_shader_models.
    result = "const static std::unordered_map<unsigned, unsigned> hashToIdxMap = {\n"
    count = 0
    for profile in shader_profiles:
        kind = profile.kind
        kind_name = profile.kind_name
        # Compare versions as integer pairs: comparing "major_minor" strings
        # would order "6_10" before "6_5".
        min_version = tuple(int(part) for part in profile.start_sm.split("_"))
        for major, minor in highest_shader_models.items():
            for i in range(0, minor + 1):
                if (major, i) < min_version:
                    continue
                sm_name = "%s_%d_%d" % (kind_name, major, i)
                hash_v = (kind << 16) | (major << 8) | i
                result += "{%d,%d}, //%s\n" % (hash_v, count, sm_name)
                count += 1
        if kind_name == "lib":
            result += "// lib_6_x is for offline linking only, and relaxes restrictions\n"
            #static const unsigned kOfflineMinor = 0xF;
            hash_v = (kind << 16) | (6 << 8) | 15
            result += "{%d,%d},//%s\n" % (hash_v, count, "lib_6_x")
            count += 1
    result += "};\n"
    return result
def get_validation_version():
    """Return C++ text assigning the validator version to *pMajor and *pMinor.

    The text starts with a comment block describing what each validator
    version added; the major version is fixed at 1 and the minor version is
    highest_minor.
    """
    template = """// 1.0 is the first validator.
// 1.1 adds:
// - ILDN container part support
// 1.2 adds:
// - Metadata for floating point denorm mode
// 1.3 adds:
// - Library support
// - Raytracing support
// - i64/f64 overloads for rawBufferLoad/Store
// 1.4 adds:
// - packed u8x4/i8x4 dot with accumulate to i32
// - half dot2 with accumulate to float
// 1.5 adds:
// - WaveMatch, WaveMultiPrefixOp, WaveMultiPrefixBitCount
// - HASH container part support
// - Mesh and Amplification shaders
// - DXR 1.1 & RayQuery support
*pMajor = 1;
*pMinor = %d;
"""
    return template % highest_minor
def get_target_profiles():
    """Build the HelpText<...> option text listing every valid target profile.

    For each profile from getShaderProfiles(), all shader models from
    getShaderModels() are listed that are at least <highest_major>_0 and at
    least the profile's own minimum shader model.

    Returns the generated text.
    """
    def version_of(shader_model):
        # "6_5" -> (6, 5). Comparing the strings themselves would order
        # "6_10" before "6_5".
        return tuple(int(part) for part in shader_model.split("_"))

    result = "HelpText<\"Set target profile. \\n"
    result += "\\t<profile>: "
    profiles = getShaderProfiles()
    shader_models = getShaderModels()
    base_version = (highest_major, 0)
    for profile, min_sm in profiles.items():
        # A model must satisfy both lower bounds, i.e. the larger of the two.
        lowest_allowed = max(base_version, version_of(min_sm))
        for shader_model in shader_models:
            if version_of(shader_model) < lowest_allowed:
                continue
            result += "%s_%s, " % (profile, shader_model)
        result += "\\n\\t\\t "
    result += "\">;"
    return result
def get_min_validator_version():
    """Return C++ switch cases assigning ValMinor for minors 0..highest_minor."""
    cases = []
    for minor in range(highest_minor + 1):
        cases.append("case %d:\n" % minor)
        cases.append(" ValMinor = %d;\n" % minor)
        cases.append(" break;\n")
    return "".join(cases)
def get_dxil_version():
    """Return C++ switch cases assigning DxilMinor for each shader model minor.

    Minors 0..highest_minor map to themselves; the extra kOfflineMinor case
    maps to highest_minor.
    """
    cases = []
    for minor in range(highest_minor + 1):
        cases.append("case %d:\n" % minor)
        cases.append(" DxilMinor = %d;\n" % minor)
        cases.append(" break;\n")
    cases.append("case kOfflineMinor: // Always update this to highest dxil version\n")
    cases.append(" DxilMinor = %d;\n" % highest_minor)
    cases.append(" break;\n")
    return "".join(cases)
def get_shader_model_get():
    """Return the C++ body that looks a shader model up by kind/major/minor.

    The text starts with the map from build_shader_model_hash_idx_map(),
    followed by the hash computation and the lookup that falls back to
    GetInvalid() when the hash is not found.
    """
    # const static std::unordered_map<unsigned, unsigned> hashToIdxMap = {};
    pieces = [
        build_shader_model_hash_idx_map(),
        "unsigned hash = (unsigned)Kind << 16 | Major << 8 | Minor;\n",
        "auto it = hashToIdxMap.find(hash);\n",
        "if (it == hashToIdxMap.end())\n",
        " return GetInvalid();\n",
        "return &ms_ShaderModels[it->second];",
    ]
    return "".join(pieces)
def get_shader_model_by_name():
    """Return C++ switch cases parsing the minor-version character of a name.

    One case is produced for each minor from 2 through highest_minor; it
    accepts the digit only when Major equals highest_major and otherwise
    returns GetInvalid().
    """
    cases = []
    for minor in range(2, highest_minor + 1):
        cases.extend([
            "case '%d':\n" % minor,
            " if (Major == %d) {\n" % highest_major,
            " Minor = %d;\n" % minor,
            " break;\n",
            " }\n",
            "else return GetInvalid();\n",
        ])
    return "".join(cases)
def get_is_valid_for_dxil():
    """Return one C++ `case N:` label per minor version 0..highest_minor."""
    return "".join("case %d:\n" % minor for minor in range(highest_minor + 1))
def RunCodeTagUpdate(file_path):
    """Regenerate the tagged sections of `file_path` in place.

    CodeTags.main is run with `file_path` as input and `file_path + ".tmp"`
    as output. On a non-zero result the error code is printed and the file is
    left untouched. Otherwise the file is overwritten only when the generated
    text differs from the current content.
    """
    import os
    import CodeTags
    print(" ... updating " + file_path)
    tmp_path = file_path + ".tmp"
    status = CodeTags.main([file_path, tmp_path])
    if status != 0:
        print(" ... error: %d" % status)
        return
    with open(file_path, 'rt') as current:
        before = current.read()
    with open(tmp_path, 'rt') as generated:
        after = generated.read()
    if before != after:
        print(" +++ changes found, updating file")
        with open(file_path, 'wt') as target:
            target.write(after)
    else:
        print(" --- no changes found")
    # NOTE(review): the temporary file is removed on the success path only -
    # confirm this matches the intended cleanup behavior.
    os.remove(tmp_path)
if __name__ == "__main__":
    # Command-line entry point: optionally print one kind of generated output
    # (-gen) and/or refresh the generated sections of the source files
    # (-update-files).
    parser = argparse.ArgumentParser(description="Generate code to handle instructions.")
    parser.add_argument(
        "-gen",
        choices=["docs-ref", "docs-spec", "inst-header", "enums", "oloads", "valfns"],
        help="Output type to generate.")
    parser.add_argument("-update-files", action="store_const", const=True)
    args = parser.parse_args()

    db = get_db_dxil() # used by all generators, also handy to have it run validation

    # Generator classes are looked up only in the branch that needs them.
    if args.gen == "docs-ref":
        db_docsref_gen(db).print_content()
    elif args.gen == "docs-spec":
        import os, docutils.core
        assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
        hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
        spec_file = os.path.abspath(os.path.join(hlsl_src_dir, "docs/DXIL.rst"))
        with open(spec_file) as f:
            s = docutils.core.publish_file(f, writer_name="html")
    elif args.gen == "inst-header":
        db_instrhelp_gen(db).print_content()
    elif args.gen == "enums":
        db_enumhelp_gen(db).print_content()
    elif args.gen == "oloads":
        db_oload_gen(db).print_content()
    elif args.gen == "valfns":
        db_valfns_gen(db).print_content()

    if args.update_files:
        print("Updating files ...")
        import CodeTags
        import os

        assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
        hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
        # Paths are relative to HLSL_SRC_DIR.
        files = [
            'docs/DXIL.rst',
            'lib/DXIL/DXILOperations.cpp',
            'lib/DXIL/DXILShaderModel.cpp',
            'include/dxc/DXIL/DXILConstants.h',
            'include/dxc/DXIL/DXILShaderModel.h',
            'include/dxc/HLSL/DxilValidation.h',
            'include/dxc/Support/HLSLOptions.td',
            'include/dxc/DXIL/DxilInstructions.h',
            'lib/HLSL/DxcOptimizer.cpp',
            'lib/DxilPIXPasses/DxilPIXPasses.cpp',
            'lib/HLSL/DxilValidation.cpp',
            'tools/clang/lib/Sema/gen_intrin_main_tables_15.h',
            'include/dxc/HlslIntrinsicOp.h',
            'tools/clang/tools/dxcompiler/dxcdisassembler.cpp',
            'include/dxc/DXIL/DxilSigPoint.inl',
        ]
        for relative_file_path in files:
            absolute_path = os.path.abspath(os.path.join(hlsl_src_dir, relative_file_path))
            RunCodeTagUpdate(absolute_path)
|