# Copyright (C) Microsoft Corporation. All rights reserved.
# This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details.
import argparse
import functools
import collections
from hctdb import *
# get db singletons
g_db_dxil = None
def get_db_dxil():
global g_db_dxil
if g_db_dxil is None:
g_db_dxil = db_dxil()
return g_db_dxil
g_db_hlsl = None
def get_db_hlsl():
global g_db_hlsl
if g_db_hlsl is None:
thisdir = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(thisdir, "gen_intrin_main.txt"), "r") as f:
g_db_hlsl = db_hlsl(f)
return g_db_hlsl
def format_comment(prefix, val):
"Formats a value with a line-comment prefix."
result = ""
line_width = 80
content_width = line_width - len(prefix)
l = len(val)
while l:
if l < content_width:
result += prefix + val.strip()
result += "\n"
l = 0
else:
split_idx = val.rfind(" ", 0, content_width)
result += prefix + val[:split_idx].strip()
result += "\n"
val = val[split_idx+1:]
l = len(val)
return result
def format_rst_table(list_of_tuples):
"Produces a reStructuredText simple table from the specified list of tuples."
# Calculate widths.
widths = None
for t in list_of_tuples:
if widths is None:
widths = [0] * len(t)
for i, v in enumerate(t):
widths[i] = max(widths[i], len(str(v)))
# Build banner line.
banner = ""
for i, w in enumerate(widths):
if i > 0:
banner += " "
banner += "=" * w
banner += "\n"
# Build the result.
result = banner
for i, t in enumerate(list_of_tuples):
for j, v in enumerate(t):
if j > 0:
result += " "
result += str(v)
result += " " * (widths[j] - len(str(v)))
result = result.rstrip()
result += "\n"
if i == 0:
result += banner
result += banner
return result
def build_range_tuples(i):
"Produces a list of tuples with contiguous ranges in the input list."
i = sorted(i)
low_bound = None
high_bound = None
for val in i:
if low_bound is None:
low_bound = val
high_bound = val
else:
assert(not high_bound is None)
if val == high_bound + 1:
high_bound = val
else:
yield (low_bound, high_bound)
low_bound = val
high_bound = val
if not low_bound is None:
yield (low_bound, high_bound)
def build_range_code(var, i):
"Produces a fragment of code that tests whether the variable name matches values in the given range."
ranges = build_range_tuples(i)
result = ""
for r in ranges:
if r[0] == r[1]:
cond = var + " == " + str(r[0])
else:
cond = "%d <= %s && %s <= %d" % (r[0], var, var, r[1])
if result == "":
result = cond
else:
result = result + " || " + cond
return result
class db_docsref_gen:
"A generator of reference documentation."
def __init__(self, db):
self.db = db
instrs = [i for i in self.db.instr if i.is_dxil_op]
instrs = sorted(instrs, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
self.instrs = instrs
val_rules = sorted(db.val_rules, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
self.val_rules = val_rules
def print_content(self):
self.print_header()
self.print_body()
self.print_footer()
def print_header(self):
print("")
print("
DXIL Reference")
print("")
print("DXIL Reference
")
self.print_toc("Instructions", "i", self.instrs)
self.print_toc("Rules", "r", self.val_rules)
def print_body(self):
self.print_instruction_details()
self.print_valrule_details()
def print_instruction_details(self):
print("Instruction Details
")
for i in self.instrs:
print("" % (i.name, i.name))
print("Opcode: %d. This instruction %s.
" % (i.dxil_opid, i.doc))
if i.remarks:
# This is likely a .rst fragment, but this will do for now.
print(" " + i.remarks + "
")
print("Operands:
")
print("")
for o in i.ops:
if o.pos == 0:
print("- result: %s - %s
" % (o.llvm_type, o.doc))
else:
enum_desc = "" if o.enum_name == "" else " one of %s: %s" % (o.enum_name, ",".join(db.enum_idx[o.enum_name].value_names()))
print("- %d - %s: %s%s%s
" % (o.pos - 1, o.name, o.llvm_type, "" if o.doc == "" else " - " + o.doc, enum_desc))
print("
")
print("")
def print_valrule_details(self):
print("Rule Details
")
for i in self.val_rules:
print("" % (i.name, i.name))
print("" + i.doc + "
")
print("")
def print_toc(self, name, aprefix, values):
print("")
last_category = ""
for i in values:
if i.category != last_category:
if last_category != None:
print("")
print("%s
" % i.category)
last_category = i.category
print("- %s
" % (i.name, i.name))
print("
")
def print_footer(self):
print("")
class db_instrhelp_gen:
"A generator of instruction helper classes."
def __init__(self, db):
self.db = db
TypeInfo = collections.namedtuple("TypeInfo", "name bits")
self.llvm_type_map = {
"i1": TypeInfo("bool", 1),
"i8": TypeInfo("int8_t", 8),
"u8": TypeInfo("uint8_t", 8),
"i32": TypeInfo("int32_t", 32),
"u32": TypeInfo("uint32_t", 32)
}
self.IsDxilOpFuncCallInst = "hlsl::OP::IsDxilOpFuncCallInst"
def print_content(self):
self.print_header()
self.print_body()
self.print_footer()
def print_header(self):
print("///////////////////////////////////////////////////////////////////////////////")
print("// //")
print("// Copyright (C) Microsoft Corporation. All rights reserved. //")
print("// DxilInstructions.h //")
print("// //")
print("// This file provides a library of instruction helper classes. //")
print("// //")
print("// MUCH WORK YET TO BE DONE - EXPECT THIS WILL CHANGE - GENERATED FILE //")
print("// //")
print("///////////////////////////////////////////////////////////////////////////////")
print("")
print("// TODO: add correct include directives")
print("// TODO: add accessors with values")
print("// TODO: add validation support code, including calling into right fn")
print("// TODO: add type hierarchy")
print("namespace hlsl {")
def bool_lit(self, val):
return "true" if val else "false";
def op_type(self, o):
if o.llvm_type in self.llvm_type_map:
return self.llvm_type_map[o.llvm_type].name
raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
def op_size(self, o):
if o.llvm_type in self.llvm_type_map:
return self.llvm_type_map[o.llvm_type].bits
raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
def op_const_expr(self, o):
return "(%s)(llvm::dyn_cast(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)
def op_set_const_expr(self, o):
type_size = self.op_size(o)
return "llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), %d), llvm::APInt(%d, (uint64_t)val))" % (type_size, type_size)
def print_body(self):
for i in self.db.instr:
if i.is_reserved: continue
if i.inst_helper_prefix:
struct_name = "%s_%s" % (i.inst_helper_prefix, i.name)
elif i.is_dxil_op:
struct_name = "DxilInst_%s" % i.name
else:
struct_name = "LlvmInst_%s" % i.name
if i.doc:
print("/// This instruction %s" % i.doc)
print("struct %s {" % struct_name)
print(" llvm::Instruction *Instr;")
print(" // Construction and identification")
print(" %s(llvm::Instruction *pInstr) : Instr(pInstr) {}" % struct_name)
print(" operator bool() const {")
if i.is_dxil_op:
op_name = i.fully_qualified_name()
print(" return %s(Instr, %s);" % (self.IsDxilOpFuncCallInst, op_name))
else:
print(" return Instr->getOpcode() == llvm::Instruction::%s;" % i.name)
print(" }")
print(" // Validation support")
print(" bool isAllowed() const { return %s; }" % self.bool_lit(i.is_allowed))
if i.is_dxil_op:
print(" bool isArgumentListValid() const {")
print(" if (%d != llvm::dyn_cast(Instr)->getNumArgOperands()) return false;" % (len(i.ops) - 1))
print(" return true;")
# TODO - check operand types
print(" }")
EnumWritten = False
for o in i.ops:
if o.pos > 1: # 0 is return type, 1 is DXIL OP id
if not EnumWritten:
print(" // Operand indexes")
print(" enum OperandIdx {")
EnumWritten = True
print(" arg_%s = %d," % (o.name, o.pos - 1))
if EnumWritten:
print(" };")
AccessorsWritten = False
for o in i.ops:
if o.pos > 1: # 0 is return type, 1 is DXIL OP id
if not AccessorsWritten:
print(" // Accessors")
AccessorsWritten = True
print(" llvm::Value *get_%s() const { return Instr->getOperand(%d); }" % (o.name, o.pos - 1))
print(" void set_%s(llvm::Value *val) { Instr->setOperand(%d, val); }" % (o.name, o.pos - 1))
if o.is_const:
print(" %s get_%s_val() const { return %s; }" % (self.op_type(o), o.name, self.op_const_expr(o)))
print(" void set_%s_val(%s val) { Instr->setOperand(%d, %s); }" % (o.name, self.op_type(o), o.pos - 1, self.op_set_const_expr(o)))
print("};")
print("")
def print_footer(self):
print("} // namespace hlsl")
class db_enumhelp_gen:
"A generator of enumeration declarations."
def __init__(self, db):
self.db = db
# Some enums should get a last enum marker.
self.lastEnumNames = {
"OpCode": "NumOpCodes",
"OpCodeClass": "NumOpClasses"
}
def print_enum(self, e, **kwargs):
print("// %s" % e.doc)
print("enum class %s : unsigned {" % e.name)
hide_val = kwargs.get("hide_val", False)
sorted_values = e.values
if kwargs.get("sort_val", True):
sorted_values = sorted(e.values, key=lambda v : ("" if v.category == None else v.category) + "." + v.name)
last_category = None
for v in sorted_values:
if v.category != last_category:
if last_category != None:
print("")
print(" // %s" % v.category)
last_category = v.category
line_format = " {name}"
if not e.is_internal and not hide_val:
line_format += " = {value}"
line_format += ","
if v.doc:
line_format += " // {doc}"
print(line_format.format(name=v.name, value=v.value, doc=v.doc))
if e.name in self.lastEnumNames:
lastName = self.lastEnumNames[e.name]
versioned = ["%s_Dxil_%d_%d = %d," % (lastName, major, minor, info[lastName])
for (major, minor), info in sorted(self.db.dxil_version_info.items())
if lastName in info]
if versioned:
print("")
for val in versioned:
print(" " + val)
print("")
print(" " + lastName + " = " + str(len(sorted_values)) + " // exclusive last value of enumeration")
print("};")
def print_content(self):
for e in sorted(self.db.enums, key=lambda e : e.name):
self.print_enum(e)
class db_oload_gen:
"A generator of overload tables."
def __init__(self, db):
self.db = db
instrs = [i for i in self.db.instr if i.is_dxil_op]
self.instrs = sorted(instrs, key=lambda i : i.dxil_opid)
def print_content(self):
self.print_opfunc_props()
print("...")
self.print_opfunc_table()
def print_opfunc_props(self):
print("const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {")
print("// OpCode OpCode name, OpCodeClass OpCodeClass name, void, h, f, d, i1, i8, i16, i32, i64 function attribute")
# Example formatted string:
# { OC::TempRegLoad, "TempRegLoad", OCC::TempRegLoad, "tempRegLoad", false, true, true, false, true, false, true, true, false, Attribute::ReadOnly, },
# 012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0
last_category = None
# overload types are a string of (v)oid, (h)alf, (f)loat, (d)ouble, (1)-bit, (8)-bit, (w)ord, (i)nt, (l)ong
f = lambda i,c : "true," if i.oload_types.find(c) >= 0 else "false,"
lower_exceptions = { "CBufferLoad" : "cbufferLoad", "CBufferLoadLegacy" : "cbufferLoadLegacy", "GSInstanceID" : "gsInstanceID" }
lower_fn = lambda t: lower_exceptions[t] if t in lower_exceptions else t[:1].lower() + t[1:]
attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate" }
attr_fn = lambda i : "Attribute::" + attr_dict[i.fn_attr] + ","
for i in self.instrs:
if last_category != i.category:
if last_category != None:
print("")
print(" // {category:118} void, h, f, d, i1, i8, i16, i32, i64 function attribute".format(category=i.category))
last_category = i.category
print(" {{ OC::{name:24} {quotName:27} OCC::{className:25} {classNameQuot:28} {v:>7}{h:>7}{f:>7}{d:>7}{b:>7}{e:>7}{w:>7}{i:>7}{l:>7} {attr:20} }},".format(
name=i.name+",", quotName='"'+i.name+'",', className=i.dxil_class+",", classNameQuot='"'+lower_fn(i.dxil_class)+'",',
v=f(i,"v"), h=f(i,"h"), f=f(i,"f"), d=f(i,"d"), b=f(i,"1"), e=f(i,"8"), w=f(i,"w"), i=f(i,"i"), l=f(i,"l"), attr=attr_fn(i)))
print("};")
def print_opfunc_table(self):
# Print the table for OP::GetOpFunc
op_type_texts = {
"$cb": "CBRT(pETy);",
"$o": "A(pETy);",
"$r": "RRT(pETy);",
"d": "A(pF64);",
"dims": "A(pDim);",
"f": "A(pF32);",
"h": "A(pF16);",
"i1": "A(pI1);",
"i16": "A(pI16);",
"i32": "A(pI32);",
"i32c": "A(pI32C);",
"i64": "A(pI64);",
"i8": "A(pI8);",
"$u4": "A(pI4S);",
"pf32": "A(pPF32);",
"res": "A(pRes);",
"splitdouble": "A(pSDT);",
"twoi32": "A(p2I32);",
"twof32": "A(p2F32);",
"fouri32": "A(p4I32);",
"fourf32": "A(p4F32);",
"u32": "A(pI32);",
"u64": "A(pI64);",
"u8": "A(pI8);",
"v": "A(pV);",
"w": "A(pWav);",
"SamplePos": "A(pPos);",
}
last_category = None
for i in self.instrs:
if last_category != i.category:
if last_category != None:
print("")
print(" // %s" % i.category)
last_category = i.category
line = " case OpCode::{name:24}".format(name = i.name + ":")
for index, o in enumerate(i.ops):
assert o.llvm_type in op_type_texts, "llvm type %s in instruction %s is unknown" % (o.llvm_type, i.name)
op_type_text = op_type_texts[o.llvm_type]
if index == 0:
line = line + "{val:13}".format(val=op_type_text)
else:
line = line + "{val:9}".format(val=op_type_text)
line = line + "break;"
print(line)
def print_opfunc_oload_type(self):
# Print the function for OP::GetOverloadType
elt_ty = "$o"
res_ret_ty = "$r"
cb_ret_ty = "$cb"
last_category = None
index_dict = collections.OrderedDict()
single_dict = collections.OrderedDict()
struct_list = []
for instr in self.instrs:
ret_ty = instr.ops[0].llvm_type
# Skip case return type is overload type
if (ret_ty == elt_ty):
continue
if ret_ty == res_ret_ty:
struct_list.append(instr.name)
continue
if ret_ty == cb_ret_ty:
struct_list.append(instr.name)
continue
in_param_ty = False
# Try to find elt_ty in parameter types.
for index, op in enumerate(instr.ops):
# Skip return type.
if (op.pos == 0):
continue
# Skip dxil opcode.
if (op.pos == 1):
continue
op_type = op.llvm_type
if (op_type == elt_ty):
# Skip return op
index = index - 1
if index not in index_dict:
index_dict[index] = [instr.name]
else:
index_dict[index].append(instr.name)
in_param_ty = True
break
if in_param_ty:
continue
# No overload, just return the single oload_type.
assert len(instr.oload_types)==1, "overload no elt_ty %s" % (instr.name)
ty = instr.oload_types[0]
type_code_texts = {
"d": "Type::getDoubleTy(m_Ctx)",
"f": "Type::getFloatTy(m_Ctx)",
"h": "Type::getHalfTy",
"1": "IntegerType::get(m_Ctx, 1)",
"8": "IntegerType::get(m_Ctx, 8)",
"w": "IntegerType::get(m_Ctx, 16)",
"i": "IntegerType::get(m_Ctx, 32)",
"l": "IntegerType::get(m_Ctx, 64)",
"v": "Type::getVoidTy(m_Ctx)",
}
assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
ty_code = type_code_texts[ty]
if ty_code not in single_dict:
single_dict[ty_code] = [instr.name]
else:
single_dict[ty_code].append(instr.name)
for index, opcodes in index_dict.items():
line = ""
for opcode in opcodes:
line = line + "case OpCode::{name}".format(name = opcode + ":\n")
line = line + " DXASSERT_NOMSG(FT->getNumParams() > " + str(index) + ");\n"
line = line + " return FT->getParamType(" + str(index) + ");"
print(line)
for code, opcodes in single_dict.items():
line = ""
for opcode in opcodes:
line = line + "case OpCode::{name}".format(name = opcode + ":\n")
line = line + " return " + code + ";"
print(line)
line = ""
for opcode in struct_list:
line = line + "case OpCode::{name}".format(name = opcode + ":\n")
line = line + "{\n"
line = line + " StructType *ST = cast(Ty);\n"
line = line + " return ST->getElementType(0);\n"
line = line + "}"
print(line)
class db_valfns_gen:
"A generator of validation functions."
def __init__(self, db):
self.db = db
def print_content(self):
self.print_header()
self.print_body()
def print_header(self):
print("///////////////////////////////////////////////////////////////////////////////")
print("// Instruction validation functions. //")
def bool_lit(self, val):
return "true" if val else "false";
def op_type(self, o):
if o.llvm_type == "i8":
return "int8_t"
if o.llvm_type == "u8":
return "uint8_t"
raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
def op_const_expr(self, o):
if o.llvm_type == "i8" or o.llvm_type == "u8":
return "(%s)(llvm::dyn_cast(Instr->getOperand(%d))->getZExtValue())" % (self.op_type(o), o.pos - 1)
raise ValueError("Don't know how to describe type %s for operand %s." % (o.llvm_type, o.name))
def print_body(self):
llvm_instrs = [i for i in self.db.instr if i.is_allowed and not i.is_dxil_op]
print("static bool IsLLVMInstructionAllowed(llvm::Instruction &I) {")
self.print_comment(" // ", "Allow: %s" % ", ".join([i.name + "=" + str(i.llvm_id) for i in llvm_instrs]))
print(" unsigned op = I.getOpcode();")
print(" return %s;" % build_range_code("op", [i.llvm_id for i in llvm_instrs]))
print("}")
print("")
def print_comment(self, prefix, val):
print(format_comment(prefix, val))
class macro_table_gen:
"A generator for macro tables."
def format_row(self, row, widths, sep=', '):
frow = [str(item) + sep + (' ' * (width - len(item)))
for item, width in list(zip(row, widths))[:-1]] + [str(row[-1])]
return ''.join(frow)
def format_table(self, table, *args, **kwargs):
widths = [ functools.reduce(max, [ len(row[i])
for row in table], 1)
for i in range(len(table[0]))]
formatted = []
for row in table:
formatted.append(self.format_row(row, widths, *args, **kwargs))
return formatted
def print_table(self, table, macro_name):
formatted = self.format_table(table)
print( '// %s\n' % formatted[0] +
'#define %s(DO) \\\n' % macro_name +
' \\\n'.join([' DO(%s)' % frow for frow in formatted[1:]]))
class db_sigpoint_gen(macro_table_gen):
"A generator for SigPoint tables."
def __init__(self, db):
self.db = db
def print_sigpoint_table(self):
self.print_table(self.db.sigpoint_table, 'DO_SIGPOINTS')
def print_interpretation_table(self):
self.print_table(self.db.interpretation_table, 'DO_INTERPRETATION_TABLE')
def print_content(self):
self.print_sigpoint_table()
self.print_interpretation_table()
class string_output:
def __init__(self):
self.val = ""
def write(self, text):
self.val = self.val + str(text)
def __str__(self):
return self.val
def run_with_stdout(fn):
import sys
_stdout_saved = sys.stdout
so = string_output()
try:
sys.stdout = so
fn()
finally:
sys.stdout = _stdout_saved
return str(so)
def get_hlsl_intrinsic_stats():
db = get_db_hlsl()
longest_fn = db.intrinsics[0]
longest_param = None
longest_arglist_fn = db.intrinsics[0]
for i in sorted(db.intrinsics, key=lambda x: x.key):
# Get some values for maximum lengths.
if len(i.name) > len(longest_fn.name):
longest_fn = i
for p_idx, p in enumerate(i.params):
if p_idx > 0 and (longest_param is None or len(p.name) > len(longest_param.name)):
longest_param = p
if len(i.params) > len(longest_arglist_fn.params):
longest_arglist_fn = i
result = ""
for k in sorted(db.namespaces.keys()):
v = db.namespaces[k]
result += "static const UINT g_u%sCount = %d;\n" % (k, len(v.intrinsics))
result += "\n"
result += "static const int g_MaxIntrinsicName = %d; // Count of characters for longest intrinsic name - '%s'\n" % (len(longest_fn.name), longest_fn.name)
result += "static const int g_MaxIntrinsicParamName = %d; // Count of characters for longest intrinsic parameter name - '%s'\n" % (len(longest_param.name), longest_param.name)
result += "static const int g_MaxIntrinsicParamCount = %d; // Count of parameters (without return) for longest intrinsic argument list - '%s'\n" % (len(longest_arglist_fn.params) - 1, longest_arglist_fn.name)
return result
def get_hlsl_intrinsics():
db = get_db_hlsl()
result = ""
last_ns = ""
ns_table = ""
is_vk_table = False # SPIRV Change
id_prefix = ""
arg_idx = 0
opcode_namespace = db.opcode_namespace
for i in sorted(db.intrinsics, key=lambda x: x.key):
if last_ns != i.ns:
last_ns = i.ns
id_prefix = "IOP" if last_ns == "Intrinsics" else "MOP"
if (len(ns_table)):
result += ns_table + "};\n"
# SPIRV Change Starts
if is_vk_table:
result += "\n#endif // ENABLE_SPIRV_CODEGEN\n"
is_vk_table = False
# SPIRV Change Ends
result += "\n//\n// Start of %s\n//\n\n" % (last_ns)
# This used to be qualified as __declspec(selectany), but that's no longer necessary.
ns_table = "static const HLSL_INTRINSIC g_%s[] =\n{\n" % (last_ns)
# SPIRV Change Starts
if (i.vulkanSpecific):
is_vk_table = True
result += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
# SPIRV Change Ends
arg_idx = 0
ns_table += " {(UINT)%s::%s_%s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
result += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (last_ns, arg_idx)
for p in i.params:
result += " {\"%s\", %s, %s, %s, %s, %s, %s, %s},\n" % (
p.name, p.param_qual, p.template_id, p.template_list,
p.component_id, p.component_list, p.rows, p.cols)
result += "};\n\n"
arg_idx += 1
result += ns_table + "};\n"
result += "\n#endif // ENABLE_SPIRV_CODEGEN\n" if is_vk_table else "" # SPIRV Change
return result
# SPIRV Change Starts
def wrap_with_ifdef_if_vulkan_specific(intrinsic, text):
if intrinsic.vulkanSpecific:
return "#ifdef ENABLE_SPIRV_CODEGEN\n" + text + "#endif // ENABLE_SPIRV_CODEGEN\n"
return text
# SPIRV Change Ends
def enum_hlsl_intrinsics():
db = get_db_hlsl()
result = ""
enumed = []
for i in sorted(db.intrinsics, key=lambda x: x.key):
if (i.enum_name not in enumed):
enumerant = " %s,\n" % (i.enum_name)
result += wrap_with_ifdef_if_vulkan_specific(i, enumerant) # SPIRV Change
enumed.append(i.enum_name)
# unsigned
result += " // unsigned\n"
for i in sorted(db.intrinsics, key=lambda x: x.key):
if (i.unsigned_op != ""):
if (i.unsigned_op not in enumed):
result += " %s,\n" % (i.unsigned_op)
enumed.append(i.unsigned_op)
result += " Num_Intrinsics,\n"
return result
def has_unsigned_hlsl_intrinsics():
db = get_db_hlsl()
result = ""
enumed = []
# unsigned
for i in sorted(db.intrinsics, key=lambda x: x.key):
if (i.unsigned_op != ""):
if (i.enum_name not in enumed):
result += " case IntrinsicOp::%s:\n" % (i.enum_name)
enumed.append(i.enum_name)
return result
def get_unsigned_hlsl_intrinsics():
db = get_db_hlsl()
result = ""
enumed = []
# unsigned
for i in sorted(db.intrinsics, key=lambda x: x.key):
if (i.unsigned_op != ""):
if (i.enum_name not in enumed):
enumed.append(i.enum_name)
result += " case IntrinsicOp::%s:\n" % (i.enum_name)
result += " return static_cast(IntrinsicOp::%s);\n" % (i.unsigned_op)
return result
def get_oloads_props():
db = get_db_dxil()
gen = db_oload_gen(db)
return run_with_stdout(lambda: gen.print_opfunc_props())
def get_oloads_funcs():
db = get_db_dxil()
gen = db_oload_gen(db)
return run_with_stdout(lambda: gen.print_opfunc_table())
def get_funcs_oload_type():
db = get_db_dxil()
gen = db_oload_gen(db)
return run_with_stdout(lambda: gen.print_opfunc_oload_type())
def get_enum_decl(name, **kwargs):
db = get_db_dxil()
gen = db_enumhelp_gen(db)
return run_with_stdout(lambda: gen.print_enum(db.enum_idx[name], **kwargs))
def get_valrule_enum():
return get_enum_decl("ValidationRule", hide_val=True)
def get_valrule_text():
db = get_db_dxil()
result = "switch(value) {\n"
for v in db.enum_idx["ValidationRule"].values:
result += " case hlsl::ValidationRule::" + v.name + ": return \"" + v.err_msg + "\";\n"
result += "}\n"
return result
def get_instrhelper():
db = get_db_dxil()
gen = db_instrhelp_gen(db)
return run_with_stdout(lambda: gen.print_body())
def get_instrs_pred(varname, pred, attr_name="dxil_opid"):
db = get_db_dxil()
if type(pred) == str:
pred_fn = lambda i: getattr(i, pred)
else:
pred_fn = pred
llvm_instrs = [i for i in db.instr if pred_fn(i)]
result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(getattr(i, attr_name)) for i in llvm_instrs]))
result += "return %s;" % build_range_code(varname, [getattr(i, attr_name) for i in llvm_instrs])
result += "\n"
return result
def get_instrs_rst():
"Create an rst table of allowed LLVM instructions."
db = get_db_dxil()
instrs = [i for i in db.instr if i.is_allowed and not i.is_dxil_op]
instrs = sorted(instrs, key=lambda v : v.llvm_id)
rows = []
rows.append(["Instruction", "Action", "Operand overloads"])
for i in instrs:
rows.append([i.name, i.doc, i.oload_types])
result = "\n\n" + format_rst_table(rows) + "\n\n"
# Add detailed instruction information where available.
for i in instrs:
if i.remarks:
result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
return result + "\n"
def get_init_passes():
"Create a series of statements to initialize passes in a registry."
db = get_db_dxil()
result = ""
for p in sorted(db.passes, key=lambda p : p.type_name):
result += "initialize%sPass(Registry);\n" % p.type_name
return result
def get_pass_arg_names():
"Return an ArrayRef of argument names based on passName"
db = get_db_dxil()
decl_result = ""
check_result = ""
for p in sorted(db.passes, key=lambda p : p.type_name):
if len(p.args):
decl_result += "static const LPCSTR %sArgs[] = { " % p.type_name
check_result += "if (strcmp(passName, \"%s\") == 0) return ArrayRef(%sArgs, _countof(%sArgs));\n" % (p.name, p.type_name, p.type_name)
sep = ""
for a in p.args:
decl_result += sep + "\"%s\"" % a.name
sep = ", "
decl_result += " };\n"
return decl_result + check_result
def get_pass_arg_descs():
"Return an ArrayRef of argument descriptions based on passName"
db = get_db_dxil()
decl_result = ""
check_result = ""
for p in sorted(db.passes, key=lambda p : p.type_name):
if len(p.args):
decl_result += "static const LPCSTR %sArgs[] = { " % p.type_name
check_result += "if (strcmp(passName, \"%s\") == 0) return ArrayRef(%sArgs, _countof(%sArgs));\n" % (p.name, p.type_name, p.type_name)
sep = ""
for a in p.args:
decl_result += sep + "\"%s\"" % a.doc
sep = ", "
decl_result += " };\n"
return decl_result + check_result
def get_is_pass_option_name():
"Create a return expression to check whether a value 'S' is a pass option name."
db = get_db_dxil()
prefix = ""
result = "return "
for k in sorted(db.pass_idx_args):
result += prefix + "S.equals(\"%s\")" % k
prefix = "\n || "
return result + ";"
def get_opcodes_rst():
"Create an rst table of opcodes"
db = get_db_dxil()
instrs = [i for i in db.instr if i.is_allowed and i.is_dxil_op]
instrs = sorted(instrs, key=lambda v : v.dxil_opid)
rows = []
rows.append(["ID", "Name", "Description"])
for i in instrs:
op_name = i.dxil_op
if i.remarks:
op_name = op_name + "_" # append _ to enable internal hyperlink on rst files
rows.append([i.dxil_opid, op_name, i.doc])
result = "\n\n" + format_rst_table(rows) + "\n\n"
# Add detailed instruction information where available.
instrs = sorted(instrs, key=lambda v : v.name)
for i in instrs:
if i.remarks:
result += i.name + "\n" + ("~" * len(i.name)) + "\n\n" + i.remarks + "\n\n"
return result + "\n"
def get_valrules_rst():
"Create an rst table of validation rules instructions."
db = get_db_dxil()
rules = [i for i in db.val_rules if not i.is_disabled]
rules = sorted(rules, key=lambda v : v.name)
rows = []
rows.append(["Rule Code", "Description"])
for i in rules:
rows.append([i.name, i.doc])
return "\n\n" + format_rst_table(rows) + "\n\n"
def get_opsigs():
# Create a list of DXIL operation signatures, sorted by ID.
db = get_db_dxil()
instrs = [i for i in db.instr if i.is_dxil_op]
instrs = sorted(instrs, key=lambda v : v.dxil_opid)
# db_dxil already asserts that the numbering is dense.
# Create the code to write out.
code = "static const char *OpCodeSignatures[] = {\n"
for inst_idx,i in enumerate(instrs):
code += " \"("
for operand in i.ops:
if operand.pos > 1: # skip 0 (the return value) and 1 (the opcode itself)
code += operand.name
if operand.pos < len(i.ops) - 1:
code += ","
code += ")\""
if inst_idx < len(instrs) - 1:
code += ","
code += " // " + i.name
code += "\n"
code += "};\n"
return code
def get_valopcode_sm_text():
db = get_db_dxil()
instrs = [i for i in db.instr if i.is_dxil_op]
instrs = sorted(instrs, key=lambda v : (v.shader_model, v.shader_stages, v.dxil_opid))
last_model = None
last_stage = None
grouped_instrs = []
code = ""
def flush_instrs(grouped_instrs, last_model, last_stage):
if len(grouped_instrs) == 0:
return ""
result = format_comment("// ", "Instructions: %s" % ", ".join([i.name + "=" + str(i.dxil_opid) for i in grouped_instrs]))
result += "if (" + build_range_code("op", [i.dxil_opid for i in grouped_instrs]) + ")\n"
result += " return "
model_cond = stage_cond = None
if last_model != (6,0):
model_cond = "pSM->GetMajor() > %d || (pSM->GetMajor() == %d && pSM->GetMinor() >= %d)" % (
last_model[0], last_model[0], last_model[1])
if last_stage != "*":
stage_cond = ' || '.join(["pSM->Is%sS()" % c.upper() for c in last_stage])
if model_cond or stage_cond:
result += '\n && '.join(
["(%s)" % expr for expr in (model_cond, stage_cond) if expr] )
return result + ";\n"
else:
# don't write these out, instead fall through
return ""
for i in instrs:
if (i.shader_model, i.shader_stages) != (last_model, last_stage):
code += flush_instrs(grouped_instrs, last_model, last_stage)
grouped_instrs = []
last_model = i.shader_model
last_stage = i.shader_stages
grouped_instrs.append(i)
code += flush_instrs(grouped_instrs, last_model, last_stage)
code += "return true;\n"
return code
def get_sigpoint_table():
db = get_db_dxil()
gen = db_sigpoint_gen(db)
return run_with_stdout(lambda: gen.print_sigpoint_table())
def get_sigpoint_rst():
"Create an rst table for SigPointKind."
db = get_db_dxil()
rows = [row[:] for row in db.sigpoint_table[:-1]] # Copy table
e = dict([(v.name, v) for v in db.enum_idx['SigPointKind'].values])
rows[0] = ['ID'] + rows[0] + ['Description']
for i in range(1, len(rows)):
row = rows[i]
v = e[row[0]]
rows[i] = [v.value] + row + [v.doc]
return "\n\n" + format_rst_table(rows) + "\n\n"
def get_sem_interpretation_enum_rst():
db = get_db_dxil()
rows = ([['ID', 'Name', 'Description']] +
[[v.value, v.name, v.doc]
for v in db.enum_idx['SemanticInterpretationKind'].values[:-1]])
return "\n\n" + format_rst_table(rows) + "\n\n"
def get_sem_interpretation_table_rst():
db = get_db_dxil()
return "\n\n" + format_rst_table(db.interpretation_table) + "\n\n"
def get_interpretation_table():
db = get_db_dxil()
gen = db_sigpoint_gen(db)
return run_with_stdout(lambda: gen.print_interpretation_table())
def RunCodeTagUpdate(file_path):
import os
import CodeTags
print(" ... updating " + file_path)
args = [file_path, file_path + ".tmp"]
result = CodeTags.main(args)
if result != 0:
print(" ... error: %d" % result)
else:
with open(file_path, 'rt') as f:
before = f.read()
with open(file_path + ".tmp", 'rt') as f:
after = f.read()
if before == after:
print(" --- no changes found")
else:
print(" +++ changes found, updating file")
with open(file_path, 'wt') as f:
f.write(after)
os.remove(file_path + ".tmp")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate code to handle instructions.")
parser.add_argument("-gen", choices=["docs-ref", "docs-spec", "inst-header", "enums", "oloads", "valfns"], help="Output type to generate.")
parser.add_argument("-update-files", action="store_const", const=True)
args = parser.parse_args()
db = get_db_dxil() # used by all generators, also handy to have it run validation
if args.gen == "docs-ref":
gen = db_docsref_gen(db)
gen.print_content()
if args.gen == "docs-spec":
import os, docutils.core
assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
spec_file = os.path.abspath(os.path.join(hlsl_src_dir, "docs/DXIL.rst"))
with open(spec_file) as f:
s = docutils.core.publish_file(f, writer_name="html")
if args.gen == "inst-header":
gen = db_instrhelp_gen(db)
gen.print_content()
if args.gen == "enums":
gen = db_enumhelp_gen(db)
gen.print_content()
if args.gen == "oloads":
gen = db_oload_gen(db)
gen.print_content()
if args.gen == "valfns":
gen = db_valfns_gen(db)
gen.print_content()
if args.update_files:
print("Updating files ...")
import CodeTags
import os
assert "HLSL_SRC_DIR" in os.environ, "Environment variable HLSL_SRC_DIR is not defined"
hlsl_src_dir = os.environ["HLSL_SRC_DIR"]
pj = lambda *parts: os.path.abspath(os.path.join(*parts))
files = [
'docs/DXIL.rst',
'lib/HLSL/DXILOperations.cpp',
'include/dxc/HLSL/DXILConstants.h',
'include/dxc/HLSL/DxilValidation.h',
'include/dxc/HLSL/DxilInstructions.h',
'lib/HLSL/DxcOptimizer.cpp',
'lib/HLSL/DxilValidation.cpp',
'tools/clang/lib/Sema/gen_intrin_main_tables_15.h',
'include/dxc/HlslIntrinsicOp.h',
'tools/clang/tools/dxcompiler/dxcdisassembler.cpp',
'include/dxc/HLSL/DxilSigPoint.inl',
]
for relative_file_path in files:
RunCodeTagUpdate(pj(hlsl_src_dir, relative_file_path))