|
@@ -34,6 +34,7 @@ ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
|
|
|
case 'n': return Strategy::NoTranslation;
|
|
|
case 'r': return Strategy::Replicate;
|
|
|
case 'p': return Strategy::Pack;
|
|
|
+ case 'm': return Strategy::Resource;
|
|
|
default: break;
|
|
|
}
|
|
|
return Strategy::Unknown;
|
|
@@ -44,17 +45,18 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
|
|
|
case Strategy::NoTranslation: return "n";
|
|
|
case Strategy::Replicate: return "r";
|
|
|
case Strategy::Pack: return "p";
|
|
|
+ case Strategy::Resource: return "m"; // m for resource method
|
|
|
default: break;
|
|
|
}
|
|
|
return "?";
|
|
|
}
|
|
|
|
|
|
-ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper)
|
|
|
- : m_strategy(strategy), m_helper(helper)
|
|
|
+ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, const HandleMap &handleMap, OP& hlslOp)
|
|
|
+ : m_strategy(strategy), m_helper(helper), m_handleMap(handleMap), m_hlslOp(hlslOp)
|
|
|
{}
|
|
|
|
|
|
-ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper)
|
|
|
- : ExtensionLowering(GetStrategy(strategy), helper)
|
|
|
+ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, const HandleMap &handleMap, OP& hlslOp)
|
|
|
+ : ExtensionLowering(GetStrategy(strategy), helper, handleMap, hlslOp)
|
|
|
{}
|
|
|
|
|
|
llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
|
|
@@ -62,6 +64,7 @@ llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
|
|
|
case Strategy::NoTranslation: return NoTranslation(CI);
|
|
|
case Strategy::Replicate: return Replicate(CI);
|
|
|
case Strategy::Pack: return Pack(CI);
|
|
|
+ case Strategy::Resource: return Resource(CI);
|
|
|
default: break;
|
|
|
}
|
|
|
return Unknown(CI);
|
|
@@ -75,8 +78,17 @@ llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
|
|
|
// Interface to describe how to translate types from HL-dxil to dxil.
|
|
|
class FunctionTypeTranslator {
|
|
|
public:
|
|
|
+ // Arguments can be exploded into multiple copies of the same type.
|
|
|
+ // For example a <2 x i32> could become { i32, 2 } if the vector
|
|
|
+ // is expanded in place or { i32, 1 } if the call is replicated.
|
|
|
+ struct ArgumentType {
|
|
|
+ Type *type;
|
|
|
+ int count;
|
|
|
+
|
|
|
+ ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
|
|
|
+ };
|
|
|
virtual Type *TranslateReturnType(CallInst *CI) = 0;
|
|
|
- virtual Type *TranslateArgumentType(Type *OrigArgType) = 0;
|
|
|
+ virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
|
|
|
};
|
|
|
|
|
|
// Class to create the new function with the translated types for low-level dxil.
|
|
@@ -85,6 +97,10 @@ public:
|
|
|
template <typename TypeTranslator>
|
|
|
static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
|
|
|
TypeTranslator typeTranslator;
|
|
|
+ return GetLoweredFunction(typeTranslator, CI, lower);
|
|
|
+ }
|
|
|
+
|
|
|
+ static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
|
|
|
FunctionTranslator translator(typeTranslator, lower);
|
|
|
return translator.GetLoweredFunction(CI);
|
|
|
}
|
|
@@ -120,9 +136,11 @@ private:
|
|
|
SmallVector<Type *, 10> ParamTypes;
|
|
|
ParamTypes.reserve(CI->getNumArgOperands());
|
|
|
for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
|
|
|
- Type *OrigTy = CI->getArgOperand(i)->getType();
|
|
|
- Type *TranslatedTy = m_typeTranslator.TranslateArgumentType(OrigTy);
|
|
|
- ParamTypes.push_back(TranslatedTy);
|
|
|
+ Value *OrigArg = CI->getArgOperand(i);
|
|
|
+ FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
|
|
|
+ for (int i = 0; i < newArgType.count; ++i) {
|
|
|
+ ParamTypes.push_back(newArgType.type);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
const bool IsVarArg = false;
|
|
@@ -151,8 +169,8 @@ class NoTranslationTypeTranslator : public FunctionTypeTranslator {
|
|
|
virtual Type *TranslateReturnType(CallInst *CI) override {
|
|
|
return CI->getType();
|
|
|
}
|
|
|
- virtual Type *TranslateArgumentType(Type *OrigArgType) override {
|
|
|
- return OrigArgType;
|
|
|
+ virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
|
|
|
+ return ArgumentType(OrigArg->getType());
|
|
|
}
|
|
|
};
|
|
|
|
|
@@ -212,13 +230,13 @@ class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
|
|
|
return RetTy;
|
|
|
}
|
|
|
|
|
|
- virtual Type *TranslateArgumentType(Type *OrigArgType) override {
|
|
|
- Type *Ty = OrigArgType;
|
|
|
+ virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
|
|
|
+ Type *Ty = OrigArg->getType();
|
|
|
if (Ty->isVectorTy()) {
|
|
|
Ty = Ty->getVectorElementType();
|
|
|
}
|
|
|
|
|
|
- return Ty;
|
|
|
+ return ArgumentType(Ty);
|
|
|
}
|
|
|
|
|
|
};
|
|
@@ -404,8 +422,8 @@ class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
|
|
|
virtual Type *TranslateReturnType(CallInst *CI) override {
|
|
|
return TranslateIfVector(CI->getType());
|
|
|
}
|
|
|
- virtual Type *TranslateArgumentType(Type *OrigArgType) override {
|
|
|
- return TranslateIfVector(OrigArgType);
|
|
|
+ virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
|
|
|
+ return ArgumentType(TranslateIfVector(OrigArg->getType()));
|
|
|
}
|
|
|
|
|
|
Type *TranslateIfVector(Type *ty) {
|
|
@@ -425,6 +443,191 @@ Value *ExtensionLowering::Pack(CallInst *CI) {
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
+///////////////////////////////////////////////////////////////////////////////
|
|
|
+// Resource Lowering.
|
|
|
+
|
|
|
+// Modify a call to a resouce method. Makes the following transformation:
|
|
|
+//
|
|
|
+// 1. Convert non-void return value to dx.types.ResRet.
|
|
|
+// 2. Convert resource parameters to the corresponding dx.types.Handle value.
|
|
|
+// 3. Expand vectors in place as separate arguments.
|
|
|
+//
|
|
|
+// Example
|
|
|
+// -----------------------------------------------------------------------------
|
|
|
+//
|
|
|
+// %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
|
|
|
+// %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
|
|
|
+// %x = extractvalue %r, 0
|
|
|
+// %y = extractvalue %r, 1
|
|
|
+// %v = <2 x float> undef
|
|
|
+// %v.1 = insertelement %v, %x, 0
|
|
|
+// %v.2 = insertelement %v.1, %y, 1
|
|
|
+class ResourceMethodCall {
|
|
|
+public:
|
|
|
+ ResourceMethodCall(CallInst *CI, Function &explodedFunction, const ExtensionLowering::HandleMap &handleMap)
|
|
|
+ : m_CI(CI)
|
|
|
+ , m_explodedFunction(explodedFunction)
|
|
|
+ , m_handleMap(handleMap)
|
|
|
+ , m_builder(CI)
|
|
|
+ { }
|
|
|
+
|
|
|
+ Value *Generate() {
|
|
|
+ SmallVector<Value *, 16> args;
|
|
|
+ ExplodeArgs(args);
|
|
|
+ Value *result = CreateCall(args);
|
|
|
+ result = ConvertResult(result);
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check to see if the value is mapped to a handle in the handleMap.
|
|
|
+ static Instruction *IsResourceHandle(Value *OrigArg, const ExtensionLowering::HandleMap &handleMap) {
|
|
|
+ if (Instruction *Inst = dyn_cast<Instruction>(OrigArg)) {
|
|
|
+ if (handleMap.count(Inst))
|
|
|
+ return Inst;
|
|
|
+ }
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+private:
|
|
|
+ CallInst *m_CI;
|
|
|
+ Function &m_explodedFunction;
|
|
|
+ const ExtensionLowering::HandleMap &m_handleMap;
|
|
|
+ IRBuilder<> m_builder;
|
|
|
+
|
|
|
+ Value *GetResourceHandle(Value *OrigArg) {
|
|
|
+ if (Instruction *Inst = IsResourceHandle(OrigArg, m_handleMap))
|
|
|
+ return m_handleMap.at(Inst);
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ void ExplodeArgs(SmallVectorImpl<Value*> &args) {
|
|
|
+ for (Value *arg : m_CI->arg_operands()) {
|
|
|
+ // vector arg: <N x ty> -> ty, ty, ..., ty (N times)
|
|
|
+ if (arg->getType()->isVectorTy()) {
|
|
|
+ for (unsigned i = 0; i < arg->getType()->getVectorNumElements(); i++) {
|
|
|
+ Value *xarg = m_builder.CreateExtractElement(arg, i);
|
|
|
+ args.push_back(xarg);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // resource handle arg: handle -> dx.types.Handle
|
|
|
+ else if (Value *handle = GetResourceHandle(arg)) {
|
|
|
+ args.push_back(handle);
|
|
|
+ }
|
|
|
+ // any other value: arg -> arg
|
|
|
+ else {
|
|
|
+ args.push_back(arg);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Value *CreateCall(const SmallVectorImpl<Value*> &args) {
|
|
|
+ return m_builder.CreateCall(&m_explodedFunction, args);
|
|
|
+ }
|
|
|
+
|
|
|
+ Value *ConvertResult(Value *result) {
|
|
|
+ Type *origRetTy = m_CI->getType();
|
|
|
+ if (origRetTy->isVoidTy())
|
|
|
+ return ConvertVoidResult(result);
|
|
|
+ else if (origRetTy->isVectorTy())
|
|
|
+ return ConvertVectorResult(origRetTy, result);
|
|
|
+ else
|
|
|
+ return ConvertScalarResult(origRetTy, result);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Void result does not need any conversion.
|
|
|
+ Value *ConvertVoidResult(Value *result) {
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Vector result will be populated with the elements from the resource return.
|
|
|
+ Value *ConvertVectorResult(Type *origRetTy, Value *result) {
|
|
|
+ Type *resourceRetTy = result->getType();
|
|
|
+ assert(origRetTy->isVectorTy());
|
|
|
+ assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");
|
|
|
+
|
|
|
+ const unsigned vectorSize = origRetTy->getVectorNumElements();
|
|
|
+ const unsigned structSize = resourceRetTy->getStructNumElements();
|
|
|
+ const unsigned size = std::min(vectorSize, structSize);
|
|
|
+ assert(vectorSize < structSize);
|
|
|
+
|
|
|
+ // Copy resource struct elements to vector.
|
|
|
+ Value *vector = UndefValue::get(origRetTy);
|
|
|
+ for (unsigned i = 0; i < size; ++i) {
|
|
|
+ Value *element = m_builder.CreateExtractValue(result, { i });
|
|
|
+ vector = m_builder.CreateInsertElement(vector, element, i);
|
|
|
+ }
|
|
|
+
|
|
|
+ return vector;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Scalar result will be populated with the first element of the resource return.
|
|
|
+ Value *ConvertScalarResult(Type *origRetTy, Value *result) {
|
|
|
+ assert(origRetTy->isSingleValueType());
|
|
|
+ return m_builder.CreateExtractValue(result, { 0 });
|
|
|
+ }
|
|
|
+
|
|
|
+};
|
|
|
+
|
|
|
+// Translate function return and argument types for resource method lowering.
|
|
|
+class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
|
|
|
+public:
|
|
|
+ ResourceFunctionTypeTranslator(const ExtensionLowering::HandleMap &handleMap, OP& hlslOp)
|
|
|
+ : m_handleMap(handleMap)
|
|
|
+ , m_hlslOp(hlslOp)
|
|
|
+ { }
|
|
|
+
|
|
|
+ // Translate return type as follows:
|
|
|
+ //
|
|
|
+ // void -> void
|
|
|
+ // <N x ty> -> dx.types.ResRet.ty
|
|
|
+ // ty -> dx.types.ResRet.ty
|
|
|
+ virtual Type *TranslateReturnType(CallInst *CI) override {
|
|
|
+ Type *RetTy = CI->getType();
|
|
|
+ if (RetTy->isVoidTy())
|
|
|
+ return RetTy;
|
|
|
+ else if (RetTy->isVectorTy())
|
|
|
+ RetTy = RetTy->getVectorElementType();
|
|
|
+
|
|
|
+ return m_hlslOp.GetResRetType(RetTy);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Translate argument type as follows:
|
|
|
+ //
|
|
|
+ // resource -> dx.types.Handle
|
|
|
+ // <N x ty> -> { ty, N }
|
|
|
+ // ty -> { ty, 1 }
|
|
|
+ virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
|
|
|
+ int count = 1;
|
|
|
+ Type *ty = OrigArg->getType();
|
|
|
+
|
|
|
+ if (ty->isVectorTy()) {
|
|
|
+ count = ty->getVectorNumElements();
|
|
|
+ ty = ty->getVectorElementType();
|
|
|
+ }
|
|
|
+ else if (ResourceMethodCall::IsResourceHandle(OrigArg, m_handleMap)) {
|
|
|
+ ty = m_hlslOp.GetHandleType();
|
|
|
+ }
|
|
|
+
|
|
|
+ return ArgumentType(ty, count);
|
|
|
+ }
|
|
|
+
|
|
|
+private:
|
|
|
+ const ExtensionLowering::HandleMap &m_handleMap;
|
|
|
+ OP& m_hlslOp;
|
|
|
+};
|
|
|
+
|
|
|
+Value *ExtensionLowering::Resource(CallInst *CI) {
|
|
|
+ ResourceFunctionTypeTranslator resourceTypeTranslator(m_handleMap, m_hlslOp);
|
|
|
+ Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
|
|
|
+ if (!resourceFunction)
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ ResourceMethodCall explode(CI, *resourceFunction, m_handleMap);
|
|
|
+ Value *result = explode.Generate();
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
// Computing Extension Names.
|
|
|
|