12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264 |
- ///////////////////////////////////////////////////////////////////////////////
- // //
- // HLOperationLowerExtension.cpp //
- // Copyright (C) Microsoft Corporation. All rights reserved. //
- // This file is distributed under the University of Illinois Open Source //
- // License. See LICENSE.TXT for details. //
- // //
- ///////////////////////////////////////////////////////////////////////////////
- #include "dxc/HLSL/HLOperationLowerExtension.h"
- #include "dxc/DXIL/DxilModule.h"
- #include "dxc/DXIL/DxilOperations.h"
- #include "dxc/HLSL/HLModule.h"
- #include "dxc/HLSL/HLOperationLower.h"
- #include "dxc/HLSL/HLOperations.h"
- #include "dxc/HlslIntrinsicOp.h"
- #include "llvm/ADT/StringRef.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/Module.h"
- #include "llvm/Support/raw_os_ostream.h"
- #include "llvm/Support/YAMLParser.h"
- #include "llvm/Support/SourceMgr.h"
- #include "llvm/ADT/SmallString.h"
- using namespace llvm;
- using namespace hlsl;
- LLVM_ATTRIBUTE_NORETURN static void ThrowExtensionError(StringRef Details)
- {
- std::string Msg = (Twine("Error in dxc extension api: ") + Details).str();
- throw hlsl::Exception(DXC_E_EXTENSION_ERROR, Msg);
- }
- // The lowering strategy format is a string that matches the following regex:
- //
- // [a-z](:(?P<ExtraStrategyInfo>.+))?$
- //
- // The first character indicates the strategy with an optional : followed by
- // additional lowering information specific to that strategy.
- //
- ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
- if (strategy.size() < 1)
- return Strategy::Unknown;
- switch (strategy[0]) {
- case 'n': return Strategy::NoTranslation;
- case 'r': return Strategy::Replicate;
- case 'p': return Strategy::Pack;
- case 'm': return Strategy::Resource;
- case 'd': return Strategy::Dxil;
- default: break;
- }
- return Strategy::Unknown;
- }
- llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
- switch (strategy) {
- case Strategy::NoTranslation: return "n";
- case Strategy::Replicate: return "r";
- case Strategy::Pack: return "p";
- case Strategy::Resource: return "m"; // m for resource method
- case Strategy::Dxil: return "d";
- default: break;
- }
- return "?";
- }
- static std::string ParseExtraStrategyInfo(StringRef strategy)
- {
- std::pair<StringRef, StringRef> SplitInfo = strategy.split(":");
- return SplitInfo.second;
- }
- ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &hlResourceLookup)
- : m_strategy(strategy), m_helper(helper), m_hlslOp(hlslOp), m_hlResourceLookup(hlResourceLookup)
- {}
- ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &hlResourceLookup)
- : ExtensionLowering(GetStrategy(strategy), helper, hlslOp, hlResourceLookup)
- {
- m_extraStrategyInfo = ParseExtraStrategyInfo(strategy);
- }
- llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
- switch (m_strategy) {
- case Strategy::NoTranslation: return NoTranslation(CI);
- case Strategy::Replicate: return Replicate(CI);
- case Strategy::Pack: return Pack(CI);
- case Strategy::Resource: return Resource(CI);
- case Strategy::Dxil: return Dxil(CI);
- default: break;
- }
- return Unknown(CI);
- }
- llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
- assert(false && "unknown translation strategy");
- return nullptr;
- }
- // Interface to describe how to translate types from HL-dxil to dxil.
- class FunctionTypeTranslator {
- public:
- // Arguments can be exploded into multiple copies of the same type.
- // For example a <2 x i32> could become { i32, 2 } if the vector
- // is expanded in place or { i32, 1 } if the call is replicated.
- struct ArgumentType {
- Type *type;
- int count;
- ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
- };
- virtual ~FunctionTypeTranslator() {}
- virtual Type *TranslateReturnType(CallInst *CI) = 0;
- virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
- };
- // Class to create the new function with the translated types for low-level dxil.
- class FunctionTranslator {
- public:
- template <typename TypeTranslator>
- static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
- TypeTranslator typeTranslator;
- return GetLoweredFunction(typeTranslator, CI, lower);
- }
-
- static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
- FunctionTranslator translator(typeTranslator, lower);
- return translator.GetLoweredFunction(CI);
- }
- virtual ~FunctionTranslator() {}
- protected:
- FunctionTypeTranslator &m_typeTranslator;
- ExtensionLowering &m_lower;
- FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
- : m_typeTranslator(typeTranslator)
- , m_lower(lower)
- {}
- Function *GetLoweredFunction(CallInst *CI) {
- // Ge the return type of replicated function.
- Type *RetTy = m_typeTranslator.TranslateReturnType(CI);
- if (!RetTy)
- return nullptr;
- // Get the Function type for replicated function.
- FunctionType *FTy = GetFunctionType(CI, RetTy);
- if (!FTy)
- return nullptr;
- // Create a new function that will be the replicated call.
- AttributeSet attributes = GetAttributeSet(CI);
- std::string name = m_lower.GetExtensionName(CI);
- return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
- }
- virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
- // Create a new function type with the translated argument.
- SmallVector<Type *, 10> ParamTypes;
- ParamTypes.reserve(CI->getNumArgOperands());
- for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
- Value *OrigArg = CI->getArgOperand(i);
- FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
- for (int i = 0; i < newArgType.count; ++i) {
- ParamTypes.push_back(newArgType.type);
- }
- }
- const bool IsVarArg = false;
- return FunctionType::get(RetTy, ParamTypes, IsVarArg);
- }
- AttributeSet GetAttributeSet(CallInst *CI) {
- Function *F = CI->getCalledFunction();
- AttributeSet attributes;
- auto copyAttribute = [=, &attributes](Attribute::AttrKind a) {
- if (F->hasFnAttribute(a)) {
- attributes = attributes.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, a);
- }
- };
- copyAttribute(Attribute::AttrKind::ReadOnly);
- copyAttribute(Attribute::AttrKind::ReadNone);
- copyAttribute(Attribute::AttrKind::ArgMemOnly);
- return attributes;
- }
- };
- ///////////////////////////////////////////////////////////////////////////////
- // NoTranslation Lowering.
- class NoTranslationTypeTranslator : public FunctionTypeTranslator {
- virtual Type *TranslateReturnType(CallInst *CI) override {
- return CI->getType();
- }
- virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
- return ArgumentType(OrigArg->getType());
- }
- };
- llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
- Function *NoTranslationFunction = FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
- if (!NoTranslationFunction)
- return nullptr;
- IRBuilder<> builder(CI);
- SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
- return builder.CreateCall(NoTranslationFunction, args);
- }
- ///////////////////////////////////////////////////////////////////////////////
- // Replicated Lowering.
// Sentinel returned by GetReplicatedVectorSize when the call has no single
// vector size that replication could use.
enum { NO_COMMON_VECTOR_SIZE = 0 };
- // Find the vector size that will be used for replication.
- // The function call will be replicated once for each element of the vector
- // size.
- static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
- unsigned commonVectorSize = NO_COMMON_VECTOR_SIZE;
- Type *RetTy = CI->getType();
- if (RetTy->isVectorTy())
- commonVectorSize = RetTy->getVectorNumElements();
- for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
- Type *Ty = CI->getArgOperand(i)->getType();
- if (Ty->isVectorTy()) {
- unsigned vectorSize = Ty->getVectorNumElements();
- if (commonVectorSize != NO_COMMON_VECTOR_SIZE && commonVectorSize != vectorSize) {
- // Inconsistent vector sizes; need a different strategy.
- return NO_COMMON_VECTOR_SIZE;
- }
- commonVectorSize = vectorSize;
- }
- }
- return commonVectorSize;
- }
- class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
- virtual Type *TranslateReturnType(CallInst *CI) override {
- unsigned commonVectorSize = GetReplicatedVectorSize(CI);
- if (commonVectorSize == NO_COMMON_VECTOR_SIZE)
- return nullptr;
- // Result should be vector or void.
- Type *RetTy = CI->getType();
- if (!RetTy->isVoidTy() && !RetTy->isVectorTy())
- return nullptr;
- if (RetTy->isVectorTy()) {
- RetTy = RetTy->getVectorElementType();
- }
- return RetTy;
- }
- virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
- Type *Ty = OrigArg->getType();
- if (Ty->isVectorTy()) {
- Ty = Ty->getVectorElementType();
- }
- return ArgumentType(Ty);
- }
- };
- class ReplicateCall {
- public:
- ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
- : m_CI(CI)
- , m_ReplicatedFunction(ReplicatedFunction)
- , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
- , m_ScalarizeArgIdx()
- , m_Args(CI->getNumArgOperands())
- , m_ReplicatedCalls(m_numReplicatedCalls)
- , m_Builder(CI)
- {
- assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
- }
- Value *Generate() {
- CollectReplicatedArguments();
- CreateReplicatedCalls();
- Value *retVal = GetReturnValue();
- return retVal;
- }
- private:
- CallInst *m_CI;
- Function &m_ReplicatedFunction;
- unsigned m_numReplicatedCalls;
- SmallVector<unsigned, 10> m_ScalarizeArgIdx;
- SmallVector<Value *, 10> m_Args;
- SmallVector<Value *, 10> m_ReplicatedCalls;
- IRBuilder<> m_Builder;
- // Collect replicated arguments.
- // For non-vector arguments we can add them to the args list directly.
- // These args will be shared by each replicated call. For the vector
- // arguments we remember the position it will go in the argument list.
- // We will fill in the vector args below when we replicate the call
- // (once for each vector lane).
- void CollectReplicatedArguments() {
- for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
- Type *Ty = m_CI->getArgOperand(i)->getType();
- if (Ty->isVectorTy()) {
- m_ScalarizeArgIdx.push_back(i);
- }
- else {
- m_Args[i] = m_CI->getArgOperand(i);
- }
- }
- }
- // Create replicated calls.
- // Replicate the call once for each element of the replicated vector size.
- void CreateReplicatedCalls() {
- for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
- for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
- unsigned argIdx = m_ScalarizeArgIdx[i];
- Value *arg = m_CI->getArgOperand(argIdx);
- m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
- }
- Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
- m_ReplicatedCalls[vecIdx] = EltOP;
- }
- }
- // Get the final replicated value.
- // If the function is a void type then return (arbitrarily) the first call.
- // We do not return nullptr because that indicates a failure to replicate.
- // If the function is a vector type then aggregate all of the replicated
- // call values into a new vector.
- Value *GetReturnValue() {
- if (m_CI->getType()->isVoidTy())
- return m_ReplicatedCalls.back();
- Value *retVal = llvm::UndefValue::get(m_CI->getType());
- for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
- retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);
- return retVal;
- }
- };
- // Translate the HL call by replicating the call for each vector element.
- //
- // For example,
- //
- // <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
- // ==>
- // %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
- // %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
- // <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
- // <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
- //
- // You can then RAWU %r with %r.v.2. The RAWU is not done by the translate function.
- Value *ExtensionLowering::Replicate(CallInst *CI) {
- Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
- if (!ReplicatedFunction)
- return NoTranslation(CI);
- ReplicateCall replicate(CI, *ReplicatedFunction);
- return replicate.Generate();
- }
- ///////////////////////////////////////////////////////////////////////////////
- // Packed Lowering.
- class PackCall {
- public:
- PackCall(CallInst *CI, Function &PackedFunction)
- : m_CI(CI)
- , m_packedFunction(PackedFunction)
- , m_builder(CI)
- {}
- Value *Generate() {
- SmallVector<Value *, 10> args;
- PackArgs(args);
- Value *result = CreateCall(args);
- return UnpackResult(result);
- }
-
- static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
- assert(vecTy->isVectorTy());
- Type *elementTy = vecTy->getVectorElementType();
- unsigned numElements = vecTy->getVectorNumElements();
- SmallVector<Type *, 4> elements;
- for (unsigned i = 0; i < numElements; ++i)
- elements.push_back(elementTy);
- return StructType::get(vecTy->getContext(), elements);
- }
- private:
- CallInst *m_CI;
- Function &m_packedFunction;
- IRBuilder<> m_builder;
- void PackArgs(SmallVectorImpl<Value*> &args) {
- args.clear();
- for (Value *arg : m_CI->arg_operands()) {
- if (arg->getType()->isVectorTy())
- arg = PackVectorIntoStruct(m_builder, arg);
- args.push_back(arg);
- }
- }
- Value *CreateCall(const SmallVectorImpl<Value*> &args) {
- return m_builder.CreateCall(&m_packedFunction, args);
- }
- Value *UnpackResult(Value *result) {
- if (result->getType()->isStructTy()) {
- result = PackStructIntoVector(m_builder, result);
- }
- return result;
- }
- static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
- assert(structTy->isStructTy());
- return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
- }
- static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
- StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
- Value *packed = UndefValue::get(structTy);
- unsigned numElements = structTy->getStructNumElements();
- for (unsigned i = 0; i < numElements; ++i) {
- Value *element = builder.CreateExtractElement(vec, i);
- packed = builder.CreateInsertValue(packed, element, { i });
- }
- return packed;
- }
- static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
- Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
- Value *packed = UndefValue::get(vecTy);
- unsigned numElements = vecTy->getVectorNumElements();
- for (unsigned i = 0; i < numElements; ++i) {
- Value *element = builder.CreateExtractValue(strukt, i);
- packed = builder.CreateInsertElement(packed, element, i);
- }
- return packed;
- }
- };
- class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
- virtual Type *TranslateReturnType(CallInst *CI) override {
- return TranslateIfVector(CI->getType());
- }
- virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
- return ArgumentType(TranslateIfVector(OrigArg->getType()));
- }
- Type *TranslateIfVector(Type *ty) {
- if (ty->isVectorTy())
- ty = PackCall::ConvertVectorTypeToStructType(ty);
- return ty;
- }
- };
- Value *ExtensionLowering::Pack(CallInst *CI) {
- Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
- if (!PackedFunction)
- return NoTranslation(CI);
- PackCall pack(CI, *PackedFunction);
- Value *result = pack.Generate();
- return result;
- }
- ///////////////////////////////////////////////////////////////////////////////
- // Resource Lowering.
- // Modify a call to a resource method. Makes the following transformation:
- //
- // 1. Convert non-void return value to dx.types.ResRet.
- // 2. Expand vectors in place as separate arguments.
- //
- // Example
- // -----------------------------------------------------------------------------
- //
- // %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
- // %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
- // %x = extractvalue %r, 0
- // %y = extractvalue %r, 1
- // %v = <2 x float> undef
- // %v.1 = insertelement %v, %x, 0
- // %v.2 = insertelement %v.1, %y, 1
- class ResourceMethodCall {
- public:
- ResourceMethodCall(CallInst *CI)
- : m_CI(CI)
- , m_builder(CI)
- { }
- virtual ~ResourceMethodCall() {}
- virtual Value *Generate(Function *explodedFunction) {
- SmallVector<Value *, 16> args;
- ExplodeArgs(args);
- Value *result = CreateCall(explodedFunction, args);
- result = ConvertResult(result);
- return result;
- }
-
- protected:
- CallInst *m_CI;
- IRBuilder<> m_builder;
- void ExplodeArgs(SmallVectorImpl<Value*> &args) {
- for (Value *arg : m_CI->arg_operands()) {
- // vector arg: <N x ty> -> ty, ty, ..., ty (N times)
- if (arg->getType()->isVectorTy()) {
- for (unsigned i = 0; i < arg->getType()->getVectorNumElements(); i++) {
- Value *xarg = m_builder.CreateExtractElement(arg, i);
- args.push_back(xarg);
- }
- }
- // any other value: arg -> arg
- else {
- args.push_back(arg);
- }
- }
- }
- Value *CreateCall(Function *explodedFunction, ArrayRef<Value*> args) {
- return m_builder.CreateCall(explodedFunction, args);
- }
- Value *ConvertResult(Value *result) {
- Type *origRetTy = m_CI->getType();
- if (origRetTy->isVoidTy())
- return ConvertVoidResult(result);
- else if (origRetTy->isVectorTy())
- return ConvertVectorResult(origRetTy, result);
- else
- return ConvertScalarResult(origRetTy, result);
- }
- // Void result does not need any conversion.
- Value *ConvertVoidResult(Value *result) {
- return result;
- }
- // Vector result will be populated with the elements from the resource return.
- Value *ConvertVectorResult(Type *origRetTy, Value *result) {
- Type *resourceRetTy = result->getType();
- assert(origRetTy->isVectorTy());
- assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");
-
- const unsigned vectorSize = origRetTy->getVectorNumElements();
- const unsigned structSize = resourceRetTy->getStructNumElements();
- const unsigned size = std::min(vectorSize, structSize);
- assert(vectorSize < structSize);
-
- // Copy resource struct elements to vector.
- Value *vector = UndefValue::get(origRetTy);
- for (unsigned i = 0; i < size; ++i) {
- Value *element = m_builder.CreateExtractValue(result, { i });
- vector = m_builder.CreateInsertElement(vector, element, i);
- }
- return vector;
- }
- // Scalar result will be populated with the first element of the resource return.
- Value *ConvertScalarResult(Type *origRetTy, Value *result) {
- assert(origRetTy->isSingleValueType());
- return m_builder.CreateExtractValue(result, { 0 });
- }
- };
- // Translate function return and argument types for resource method lowering.
- class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
- public:
- ResourceFunctionTypeTranslator(OP &hlslOp) : m_hlslOp(hlslOp) {}
- // Translate return type as follows:
- //
- // void -> void
- // <N x ty> -> dx.types.ResRet.ty
- // ty -> dx.types.ResRet.ty
- virtual Type *TranslateReturnType(CallInst *CI) override {
- Type *RetTy = CI->getType();
- if (RetTy->isVoidTy())
- return RetTy;
- else if (RetTy->isVectorTy())
- RetTy = RetTy->getVectorElementType();
- return m_hlslOp.GetResRetType(RetTy);
- }
-
- // Translate argument type as follows:
- //
- // resource -> dx.types.Handle
- // <N x ty> -> { ty, N }
- // ty -> { ty, 1 }
- virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
- int count = 1;
- Type *ty = OrigArg->getType();
- if (ty->isVectorTy()) {
- count = ty->getVectorNumElements();
- ty = ty->getVectorElementType();
- }
- return ArgumentType(ty, count);
- }
- private:
- OP& m_hlslOp;
- };
- Value *ExtensionLowering::Resource(CallInst *CI) {
- // Extra strategy info overrides the default lowering for resource methods.
- if (!m_extraStrategyInfo.empty())
- {
- return CustomResource(CI);
- }
- ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
- Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
- if (!resourceFunction)
- return NoTranslation(CI);
- ResourceMethodCall explode(CI);
- Value *result = explode.Generate(resourceFunction);
- return result;
- }
- // This class handles the core logic for custom lowering of resource
- // method intrinsics. The goal is to allow resource extension intrinsics
- // to be handled the same way as the core hlsl resource intrinsics.
- //
- // Specifically, we want to support:
- //
- // 1. Multiple hlsl overloads map to a single dxil intrinsic
- // 2. The hlsl overloads can take different parameters for a given resource type
- // 3. The hlsl overloads are not consistent across different resource types
- //
- // To achieve these goals we need a more complex mechanism for describing how
- // to translate the high-level arguments to arguments for a dxil function.
- // The custom lowering info describes this lowering using the following format.
- //
- // [Custom Lowering Info Format]
- // A json string encoding a map where each key is either a specific resource type or
- // the keyword "default" to be used for any other resource. The value is a
- // custom-format string encoding how high-level arguments are mapped to
- // dxil intrinsic arguments.
- //
- // [Argument Translation Format]
- // A comma separated string where the number of fields is exactly equal to the number
- // of parameters in the target dxil intrinsic. Each field describes how to generate
- // the argument for that dxil intrinsic parameter. It has the following format where
- // the hl_arg_index is mandatory, but the other two parts are optional.
- //
- // <hl_arg_index>.<vector_index>:<optional_type_info>
- //
- // The format is precisely described by the following regular expression:
- //
- // (?P<hl_arg_index>[-0-9]+)(\.(?P<vector_index>[-0-9]+))?(:(?P<optional_type_info>\?i32|\?i16|\?i8|\?float|\?half))?$
- //
- // Example
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // Say we want to define the MyTextureOp extension with the following overloads:
- //
- // Texture1D
- // MyTextureOp(uint addr, uint offset)
- // MyTextureOp(uint addr, uint offset, uint val)
- //
- // Texture2D
- // MyTextureOp(uint2 addr, uint2 val)
- //
- // And a dxil intrinsic defined as follows
- // @MyTextureOp(i32 opcode, %dx.types.Handle handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1)
- //
- // Then we would define the lowering info json as follows
- //
- // {
- // "default" : "0, 1, 2.0, 2.1, 3 , 4.0:?i32, 4.1:?i32",
- // "Texture2D" : "0, 1, 2.0, 2.1, -1:?i32, 3.0 , 3.1"
- // }
- //
- //
- // This would produce the following lowerings (assuming the MyTextureOp opcode is 17)
- //
- // hlsl: Texture1D.MyTextureOp(a, b)
- // hl: @MyTextureOp(17, handle, a, b)
- // dxil: @MyTextureOp(17, handle, a, undef, b, undef, undef)
- //
- // hlsl: Texture1D.MyTextureOp(a, b, c)
- // hl: @MyTextureOp(17, handle, a, b, c)
- // dxil: @MyTextureOp(17, handle, a, undef, b, c, undef)
- //
- // hlsl: Texture2D.MyTextureOp(a, c)
- // hl: @MyTextureOp(17, handle, a, c)
- // dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
- //
- //
- class CustomResourceLowering
- {
- public:
- CustomResourceLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
- {
- // Parse lowering info json format.
- std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
- ParseLoweringInfo(LoweringInfo, CI->getContext());
- // Lookup resource kind based on handle (first arg after hl opcode)
- enum {RESOURCE_HANDLE_ARG=1};
- const char *pName = nullptr;
- if (!ResourceLookup.GetResourceKindName(CI->getArgOperand(RESOURCE_HANDLE_ARG), &pName))
- {
- ThrowExtensionError("Failed to find resource from handle");
- }
- std::string Name(pName);
- // Select lowering info to use based on resource kind.
- const char *DefaultInfoName = "default";
- std::vector<DxilArgInfo> *pArgInfo = nullptr;
- if (LoweringInfoMap.count(Name))
- {
- pArgInfo = &LoweringInfoMap.at(Name);
- }
- else if (LoweringInfoMap.count(DefaultInfoName))
- {
- pArgInfo = &LoweringInfoMap.at(DefaultInfoName);
- }
- else
- {
- ThrowExtensionError("Unable to find lowering info for resource");
- }
- GenerateLoweredArgs(CI, *pArgInfo);
- }
- const std::vector<Value *> &GetLoweredArgs() const
- {
- return m_LoweredArgs;
- }
- private:
- struct OptionalTypeSpec
- {
- const char* TypeName;
- Type *LLVMType;
- };
- // These are the supported optional types for generating dxil parameters
- // that have no matching argument in the high-level intrinsic overload.
- // See [Argument Translation Format] for details.
- void InitOptionalTypes(LLVMContext &Ctx)
- {
- // Table of supported optional types.
- // Keep in sync with m_OptionalTypes small vector size to avoid
- // dynamic allocation.
- OptionalTypeSpec OptionalTypes[] = {
- {"?i32", Type::getInt32Ty(Ctx)},
- {"?float", Type::getFloatTy(Ctx)},
- {"?half", Type::getHalfTy(Ctx)},
- {"?i8", Type::getInt8Ty(Ctx)},
- {"?i16", Type::getInt16Ty(Ctx)},
- };
- DXASSERT(m_OptionalTypes.empty(), "Init should only be called once");
- m_OptionalTypes.clear();
- m_OptionalTypes.reserve(_countof(OptionalTypes));
- for (const OptionalTypeSpec &T : OptionalTypes)
- {
- m_OptionalTypes.push_back(T);
- }
- }
- Type *ParseOptionalType(StringRef OptionalTypeInfo)
- {
- if (OptionalTypeInfo.empty())
- {
- return nullptr;
- }
- for (OptionalTypeSpec &O : m_OptionalTypes)
- {
- if (OptionalTypeInfo == O.TypeName)
- {
- return O.LLVMType;
- }
- }
-
- ThrowExtensionError("Failed to parse optional type");
- }
-
- // Mapping from high level function arg to dxil function arg.
- //
- // The `HighLevelArgIndex` is the index of the function argument to
- // which this dxil argument maps.
- //
- // If `HasVectorIndex` is true then the `VectorIndex` contains the
- // index of the element in the vector pointed to by HighLevelArgIndex.
- //
- // The `OptionalType` is used to specify types for arguments that are not
- // present in all overloads of the high level function. This lets us
- // map multiple high level functions to a single dxil extension intrinsic.
- //
- struct DxilArgInfo
- {
- unsigned HighLevelArgIndex = 0;
- unsigned VectorIndex = 0;
- bool HasVectorIndex = false;
- Type *OptionalType = nullptr;
- };
- typedef std::string ResourceKindName;
- // Convert the lowering info to a machine-friendly format.
- // Note that we use the YAML parser to parse the JSON since JSON
- // is a subset of YAML (and this llvm has no JSON parser).
- //
- // See [Custom Lowering Info Format] for details.
- std::map<ResourceKindName, std::vector<DxilArgInfo>> ParseLoweringInfo(StringRef LoweringInfo, LLVMContext &Ctx)
- {
- InitOptionalTypes(Ctx);
- std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap;
- SourceMgr SM;
- yaml::Stream YAMLStream(LoweringInfo, SM);
- // Make sure we have a valid json input.
- llvm::yaml::document_iterator I = YAMLStream.begin();
- if (I == YAMLStream.end()) {
- ThrowExtensionError("Found empty resource lowering JSON.");
- }
- llvm::yaml::Node *Root = I->getRoot();
- if (!Root) {
- ThrowExtensionError("Error parsing resource lowering JSON.");
- }
- // Parse the top level map object.
- llvm::yaml::MappingNode *Object = dyn_cast<llvm::yaml::MappingNode>(Root);
- if (!Object) {
- ThrowExtensionError("Expected map in top level of resource lowering JSON.");
- }
- // Parse all key/value pairs from the map.
- for (llvm::yaml::MappingNode::iterator KVI = Object->begin(),
- KVE = Object->end();
- KVI != KVE; ++KVI)
- {
- // Parse key.
- llvm::yaml::ScalarNode *KeyString =
- dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getKey());
- if (!KeyString) {
- ThrowExtensionError("Expected string as key in resource lowering info JSON map.");
- }
- SmallString<32> KeyStorage;
- StringRef Key = KeyString->getValue(KeyStorage);
- // Parse value.
- llvm::yaml::ScalarNode *ValueString =
- dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getValue());
- if (!ValueString) {
- ThrowExtensionError("Expected string as value in resource lowering info JSON map.");
- }
- SmallString<128> ValueStorage;
- StringRef Value = ValueString->getValue(ValueStorage);
- // Parse dxil arg info from value.
- LoweringInfoMap[Key] = ParseDxilArgInfo(Value, Ctx);
- }
- return LoweringInfoMap;
- }
- // Parse the dxail argument translation info.
- // See [Argument Translation Format] for details.
- std::vector<DxilArgInfo> ParseDxilArgInfo(StringRef ArgSpec, LLVMContext &Ctx)
- {
- std::vector<DxilArgInfo> Args;
- SmallVector<StringRef, 14> Splits;
- ArgSpec.split(Splits, ",");
- for (const StringRef Split : Splits)
- {
- StringRef Field = Split.trim();
- StringRef HighLevelArgInfo;
- StringRef OptionalTypeInfo;
- std::tie(HighLevelArgInfo, OptionalTypeInfo) = Field.split(":");
- Type *OptionalType = ParseOptionalType(OptionalTypeInfo);
- StringRef HighLevelArgIndex;
- StringRef VectorIndex;
- std::tie(HighLevelArgIndex, VectorIndex) = HighLevelArgInfo.split(".");
- // Parse the arg and vector index.
- // Parse the values as signed integers, but store them as unsigned values to
- // allows using -1 as a shorthand for the max value.
- DxilArgInfo ArgInfo;
- ArgInfo.HighLevelArgIndex = static_cast<unsigned>(std::stoi(HighLevelArgIndex));
- if (!VectorIndex.empty())
- {
- ArgInfo.HasVectorIndex = true;
- ArgInfo.VectorIndex = static_cast<unsigned>(std::stoi(VectorIndex));
- }
- ArgInfo.OptionalType = OptionalType;
- Args.push_back(ArgInfo);
- }
- return Args;
- }
- // Create the dxil args based on custom lowering info.
- void GenerateLoweredArgs(CallInst *CI, const std::vector<DxilArgInfo> &ArgInfoRecords)
- {
- IRBuilder<> builder(CI);
- for (const DxilArgInfo &ArgInfo : ArgInfoRecords)
- {
- // Check to see if we have the corresponding high-level arg in the overload for this call.
- if (ArgInfo.HighLevelArgIndex < CI->getNumArgOperands())
- {
- Value *Arg = CI->getArgOperand(ArgInfo.HighLevelArgIndex);
- if (ArgInfo.HasVectorIndex)
- {
- // We expect a vector type here, but we handle one special case if not.
- if (Arg->getType()->isVectorTy())
- {
- // We allow multiple high-level overloads to map to a single dxil extension function.
- // If the vector index is invalid for this specific overload then use an undef
- // value as a replacement.
- if (ArgInfo.VectorIndex < Arg->getType()->getVectorNumElements())
- {
- Arg = builder.CreateExtractElement(Arg, ArgInfo.VectorIndex);
- }
- else
- {
- Arg = UndefValue::get(Arg->getType()->getVectorElementType());
- }
- }
- else
- {
- // If it is a non-vector type then we replace non-zero vector index with
- // undef. This is to handle hlsl intrinsic overloading rules that allow
- // scalars in place of single-element vectors. We assume here that a non-vector
- // means that a single element vector was already scalarized.
- //
- if (ArgInfo.VectorIndex > 0)
- {
- Arg = UndefValue::get(Arg->getType());
- }
- }
- }
- m_LoweredArgs.push_back(Arg);
- }
- else if (ArgInfo.OptionalType)
- {
- // If there was no matching high-level arg then we look for the optional
- // arg type specified by the lowering info.
- m_LoweredArgs.push_back(UndefValue::get(ArgInfo.OptionalType));
- }
- else
- {
- // No way to know how to generate the correc type for this dxil arg.
- ThrowExtensionError("Unable to map high-level arg to dxil arg");
- }
- }
- }
-
- std::vector<Value *> m_LoweredArgs;
- SmallVector<OptionalTypeSpec, 5> m_OptionalTypes;
- };
- // Boilerplate to reuse exising logic as much as possible.
- // We just want to overload GetFunctionType here.
- class CustomResourceFunctionTranslator : public FunctionTranslator {
- public:
- static Function *GetLoweredFunction(
- const CustomResourceLowering &CustomLowering,
- ResourceFunctionTypeTranslator &typeTranslator,
- CallInst *CI,
- ExtensionLowering &lower
- )
- {
- CustomResourceFunctionTranslator T(CustomLowering, typeTranslator, lower);
- return T.FunctionTranslator::GetLoweredFunction(CI);
- }
- private:
- CustomResourceFunctionTranslator(
- const CustomResourceLowering &CustomLowering,
- ResourceFunctionTypeTranslator &typeTranslator,
- ExtensionLowering &lower
- )
- : FunctionTranslator(typeTranslator, lower)
- , m_CustomLowering(CustomLowering)
- {
- }
- virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) override {
- SmallVector<Type *, 16> ParamTypes;
- for (Value *V : m_CustomLowering.GetLoweredArgs())
- {
- ParamTypes.push_back(V->getType());
- }
- const bool IsVarArg = false;
- return FunctionType::get(RetTy, ParamTypes, IsVarArg);
- }
- private:
- const CustomResourceLowering &m_CustomLowering;
- };
- // Boilerplate to reuse exising logic as much as possible.
- // We just want to overload Generate here.
- class CustomResourceMethodCall : public ResourceMethodCall
- {
- public:
- CustomResourceMethodCall(CallInst *CI, const CustomResourceLowering &CustomLowering)
- : ResourceMethodCall(CI)
- , m_CustomLowering(CustomLowering)
- {}
- virtual Value *Generate(Function *loweredFunction) override {
- Value *result = CreateCall(loweredFunction, m_CustomLowering.GetLoweredArgs());
- result = ConvertResult(result);
- return result;
- }
- private:
- const CustomResourceLowering &m_CustomLowering;
- };
- // Support custom lowering logic for resource functions.
- Value *ExtensionLowering::CustomResource(CallInst *CI) {
- CustomResourceLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
- ResourceFunctionTypeTranslator ResourceTypeTranslator(m_hlslOp);
- Function *ResourceFunction = CustomResourceFunctionTranslator::GetLoweredFunction(
- CustomLowering,
- ResourceTypeTranslator,
- CI,
- *this
- );
- if (!ResourceFunction)
- return NoTranslation(CI);
- CustomResourceMethodCall custom(CI, CustomLowering);
- Value *Result = custom.Generate(ResourceFunction);
- return Result;
- }
- ///////////////////////////////////////////////////////////////////////////////
- // Dxil Lowering.
- Value *ExtensionLowering::Dxil(CallInst *CI) {
- // Map the extension opcode to the corresponding dxil opcode.
- unsigned extOpcode = GetHLOpcode(CI);
- OP::OpCode dxilOpcode;
- if (!m_helper->GetDxilOpcode(extOpcode, dxilOpcode))
- return nullptr;
- // Find the dxil function based on the overload type.
- Type *overloadTy = OP::GetOverloadType(dxilOpcode, CI->getCalledFunction());
- Function *F = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());
- // Update the opcode in the original call so we can just copy it below.
- // We are about to delete this call anyway.
- CI->setOperand(0, m_hlslOp.GetI32Const(static_cast<unsigned>(dxilOpcode)));
- // Create the new call.
- Value *result = nullptr;
- if (overloadTy->isVectorTy()) {
- ReplicateCall replicate(CI, *F);
- result = replicate.Generate();
- }
- else {
- IRBuilder<> builder(CI);
- SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
- result = builder.CreateCall(F, args);
- }
- return result;
- }
- ///////////////////////////////////////////////////////////////////////////////
- // Computing Extension Names.
- // Compute the name to use for the intrinsic function call once it is lowered to dxil.
- // First checks to see if we have a custom name from the codegen helper and if not
- // chooses a default name based on the lowergin strategy.
- class ExtensionName {
- public:
- ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
- : m_CI(CI)
- , m_strategy(strategy)
- , m_helper(helper)
- {}
- std::string Get() {
- std::string name;
- if (m_helper)
- name = GetCustomExtensionName(m_CI, *m_helper);
- if (!HasCustomExtensionName(name))
- name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));
- return name;
- }
- private:
- CallInst *m_CI;
- ExtensionLowering::Strategy m_strategy;
- HLSLExtensionsCodegenHelper *m_helper;
- static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
- unsigned opcode = GetHLOpcode(CI);
- std::string name = helper.GetIntrinsicName(opcode);
- ReplaceOverloadMarkerWithTypeName(name, CI);
- return name;
- }
- static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
- return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
- }
- static bool HasCustomExtensionName(const std::string name) {
- return name.size() > 0;
- }
- typedef unsigned OverloadArgIndex;
- static constexpr OverloadArgIndex DefaultOverloadIndex = std::numeric_limits<OverloadArgIndex>::max();
- // Choose the (return value or argument) type that determines the overload type
- // for the intrinsic call.
- // If the overload arg index was explicitly specified (see ParseOverloadArgIndex)
- // then we use that arg to pick the overload name. Otherwise we pick a default
- // where we take the return type as the overload. If the return is void we
- // take the first (non-opcode) argument as the overload type.
- static Type *SelectOverloadSlot(CallInst *CI, OverloadArgIndex ArgIndex) {
- if (ArgIndex != DefaultOverloadIndex)
- {
- return CI->getArgOperand(ArgIndex)->getType();
- }
- Type *ty = CI->getType();
- if (ty->isVoidTy()) {
- if (CI->getNumArgOperands() > 1)
- ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
- }
- return ty;
- }
- static Type *GetOverloadType(CallInst *CI, OverloadArgIndex ArgIndex) {
- Type *ty = SelectOverloadSlot(CI, ArgIndex);
- if (ty->isVectorTy())
- ty = ty->getVectorElementType();
- return ty;
- }
- static std::string GetTypeName(Type *ty) {
- std::string typeName;
- llvm::raw_string_ostream os(typeName);
- ty->print(os);
- os.flush();
- return typeName;
- }
- static std::string GetOverloadTypeName(CallInst *CI, OverloadArgIndex ArgIndex) {
- Type *ty = GetOverloadType(CI, ArgIndex);
- return GetTypeName(ty);
- }
- // Parse the arg index out of the overload marker (if any).
- //
- // The function names use a $o to indicate that the function is overloaded
- // and we should replace $o with the overload type. The extension name can
- // explicitly set which arg to use for the overload type by adding a colon
- // and a number after the $o (e.g. $o:3 would say the overload type is
- // determined by parameter 3).
- //
- // If we find an arg index after the overload marker we update the size
- // of the marker to include the full parsed string size so that it can
- // be replaced with the selected overload type.
- //
- static OverloadArgIndex ParseOverloadArgIndex(
- const std::string& functionName,
- size_t OverloadMarkerStartIndex,
- size_t *pOverloadMarkerSize)
- {
- assert(OverloadMarkerStartIndex != std::string::npos);
- size_t StartIndex = OverloadMarkerStartIndex + *pOverloadMarkerSize;
- // Check if we have anything after the overload marker to parse.
- if (StartIndex >= functionName.size())
- {
- return DefaultOverloadIndex;
- }
- // Does it start with a ':' ?
- if (functionName[StartIndex] != ':')
- {
- return DefaultOverloadIndex;
- }
- // Skip past the :
- ++StartIndex;
- // Collect all the digits.
- std::string Digits;
- Digits.reserve(functionName.size() - StartIndex);
- for (size_t i = StartIndex; i < functionName.size(); ++i)
- {
- char c = functionName[i];
- if (!isdigit(c))
- {
- break;
- }
- Digits.push_back(c);
- }
- if (Digits.empty())
- {
- return DefaultOverloadIndex;
- }
- *pOverloadMarkerSize = *pOverloadMarkerSize + std::strlen(":") + Digits.size();
- return std::stoi(Digits);
- }
- // Find the occurence of the overload marker $o and replace it the the overload type name.
- static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
- const char *OverloadMarker = "$o";
- size_t OverloadMarkerLength = 2;
- size_t pos = functionName.find(OverloadMarker);
- if (pos != std::string::npos) {
- OverloadArgIndex ArgIndex = ParseOverloadArgIndex(functionName, pos, &OverloadMarkerLength);
- std::string typeName = GetOverloadTypeName(CI, ArgIndex);
- functionName.replace(pos, OverloadMarkerLength, typeName);
- }
- }
- };
- std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
- ExtensionName name(CI, m_strategy, m_helper);
- return name.Get();
- }
|