Browse Source

Modify the extension mechansim to handle custom lowering for resource… (#3081)

* Modify the extension mechansim to handle custom lowering for resource methods

The goal is to allow resource extension intrinsics to be handled the same way
as the core hlsl resource intrinsics.

Specifically, we want to support:

 1. Multiple hlsl overloads map to a single dxil intrinsic
 2. The hlsl overloads can take different parameters for a given resource type
 3. The hlsl overloads are not consistent across different resource types

To achieve these goals we need a more complex mechanism for describing how
to translate the high-level arguments to arguments for a dxil function.

This commit implements custom lowering by allowing the mapping from high-level
args to dxil args to be specified as extra information along with the
lowering strategy.

The custom lowering info describes this lowering using the following format.

* Add missing virtual destructors
David Peixotto 5 years ago
parent
commit
55b9194ccc

+ 3 - 0
include/dxc/DXIL/DxilResourceBase.h

@@ -61,6 +61,7 @@ public:
   const char *GetResDimName() const;
   const char *GetResDimName() const;
   const char *GetResIDPrefix() const;
   const char *GetResIDPrefix() const;
   const char *GetResBindPrefix() const;
   const char *GetResBindPrefix() const;
+  const char *GetResKindName() const;
 
 
 protected:
 protected:
   void SetClass(Class C);
   void SetClass(Class C);
@@ -77,4 +78,6 @@ private:
   llvm::Value *m_pHandle;         // Cached resource handle for SM5.0- (and maybe SM5.1).
   llvm::Value *m_pHandle;         // Cached resource handle for SM5.0- (and maybe SM5.1).
 };
 };
 
 
+const char *GetResourceKindName(DXIL::ResourceKind K);
+
 } // namespace hlsl
 } // namespace hlsl

+ 12 - 2
include/dxc/HLSL/HLOperationLowerExtension.h

@@ -27,6 +27,13 @@ namespace llvm {
 namespace hlsl {
 namespace hlsl {
   class OP;
   class OP;
 
 
+  struct HLResourceLookup 
+  {
+      // Lookup resource kind based on handle. Return true on success.
+      virtual bool GetResourceKindName(llvm::Value *HLHandle, const char **ppName) = 0;
+      virtual ~HLResourceLookup() {}
+  };
+
   // Lowers HLSL extensions from HL operation to DXIL operation.
   // Lowers HLSL extensions from HL operation to DXIL operation.
   class ExtensionLowering {
   class ExtensionLowering {
   public:
   public:
@@ -41,8 +48,8 @@ namespace hlsl {
     };
     };
 
 
     // Create the lowering using the given strategy and custom codegen helper.
     // Create the lowering using the given strategy and custom codegen helper.
-    ExtensionLowering(llvm::StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp);
-    ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp);
+    ExtensionLowering(llvm::StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &resourceHelper);
+    ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &resourceHelper);
 
 
     // Translate the HL op call to a DXIL op call.
     // Translate the HL op call to a DXIL op call.
     // Returns a new value if translation was successful.
     // Returns a new value if translation was successful.
@@ -69,6 +76,8 @@ namespace hlsl {
     Strategy m_strategy;
     Strategy m_strategy;
     HLSLExtensionsCodegenHelper *m_helper;
     HLSLExtensionsCodegenHelper *m_helper;
     OP &m_hlslOp;
     OP &m_hlslOp;
+    HLResourceLookup &m_hlResourceLookup;
+    std::string m_extraStrategyInfo;
 
 
     llvm::Value *Unknown(llvm::CallInst *CI);
     llvm::Value *Unknown(llvm::CallInst *CI);
     llvm::Value *NoTranslation(llvm::CallInst *CI);
     llvm::Value *NoTranslation(llvm::CallInst *CI);
@@ -76,5 +85,6 @@ namespace hlsl {
     llvm::Value *Pack(llvm::CallInst *CI);
     llvm::Value *Pack(llvm::CallInst *CI);
     llvm::Value *Resource(llvm::CallInst *CI);
     llvm::Value *Resource(llvm::CallInst *CI);
     llvm::Value *Dxil(llvm::CallInst *CI);
     llvm::Value *Dxil(llvm::CallInst *CI);
+    llvm::Value *CustomResource(llvm::CallInst *CI);
   };
   };
 }
 }

+ 3 - 0
include/dxc/Support/ErrorCodes.h

@@ -104,3 +104,6 @@
 
 
 // 0X80AA0019 - Abort compilation error.
 // 0X80AA0019 - Abort compilation error.
 #define DXC_E_ABORT_COMPILATION_ERROR                 DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0019))
 #define DXC_E_ABORT_COMPILATION_ERROR                 DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0019))
+
+// 0X80AA001A - Error in extension mechanism.
+#define DXC_E_EXTENSION_ERROR                         DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x001A))

+ 18 - 0
lib/DXIL/DxilResourceBase.cpp

@@ -99,4 +99,22 @@ const char *DxilResourceBase::GetResDimName() const {
   return s_ResourceDimNames[(unsigned)m_Kind];
   return s_ResourceDimNames[(unsigned)m_Kind];
 }
 }
 
 
+static const char *s_ResourceKindNames[] = {
+        "invalid",     "Texture1D",        "Texture2D",        "Texture2DMS",      "Texture3D",
+        "TextureCube", "Texture1DArray",   "Texture2DArray",   "Texture2DMSArray", "TextureCubeArray",
+        "TypedBuffer", "RawBuffer",        "StructuredBuffer", "CBuffer",          "Sampler",
+        "TBuffer",     "RTAccelerationStructure", "FeedbackTexture2D", "FeedbackTexture2DArray",
+        "StructuredBufferWithCounter", "SamplerComparison",
+};
+static_assert(_countof(s_ResourceKindNames) == (unsigned)DxilResourceBase::Kind::NumEntries,
+  "Resource kind names array must be updated when new resource kind enums are added.");
+
+const char *DxilResourceBase::GetResKindName() const {
+  return GetResourceKindName(m_Kind);
+}
+
+const char *GetResourceKindName(DXIL::ResourceKind K) {
+  return s_ResourceKindNames[(unsigned)K];
+}
+
 } // namespace hlsl
 } // namespace hlsl

+ 26 - 3
lib/HLSL/HLOperationLower.cpp

@@ -341,6 +341,27 @@ private:
   }
   }
 };
 };
 
 
+// Helper for lowering resource extension methods.
+struct HLObjectExtensionLowerHelper : public hlsl::HLResourceLookup {
+    explicit HLObjectExtensionLowerHelper(HLObjectOperationLowerHelper &ObjHelper)
+        : m_ObjHelper(ObjHelper)
+    { }
+
+    virtual bool GetResourceKindName(Value *HLHandle, const char **ppName)
+    {
+        DXIL::ResourceKind K = m_ObjHelper.GetRK(HLHandle);
+        bool Success = K != DXIL::ResourceKind::Invalid;
+        if (Success)
+        {
+            *ppName = hlsl::GetResourceKindName(K);
+        }
+        return Success;
+    }
+
+private:
+    HLObjectOperationLowerHelper &m_ObjHelper;
+};
+
 using IntrinsicLowerFuncTy = Value *(CallInst *CI, IntrinsicOp IOP,
 using IntrinsicLowerFuncTy = Value *(CallInst *CI, IntrinsicOp IOP,
                                      DXIL::OpCode opcode,
                                      DXIL::OpCode opcode,
                                      HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated);
                                      HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated);
@@ -7693,7 +7714,8 @@ void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
 typedef std::unordered_map<llvm::Instruction *, llvm::Value *> HandleMap;
 typedef std::unordered_map<llvm::Instruction *, llvm::Value *> HandleMap;
 static void TranslateHLExtension(Function *F,
 static void TranslateHLExtension(Function *F,
                                  HLSLExtensionsCodegenHelper *helper,
                                  HLSLExtensionsCodegenHelper *helper,
-                                 OP& hlslOp) {
+                                 OP& hlslOp,
+                                 HLObjectOperationLowerHelper &objHelper) {
   // Find all calls to the function F.
   // Find all calls to the function F.
   // Store the calls in a vector for now to be replaced the loop below.
   // Store the calls in a vector for now to be replaced the loop below.
   // We use a two step "find then replace" to avoid removing uses while
   // We use a two step "find then replace" to avoid removing uses while
@@ -7707,7 +7729,8 @@ static void TranslateHLExtension(Function *F,
 
 
   // Get the lowering strategy to use for this intrinsic.
   // Get the lowering strategy to use for this intrinsic.
   llvm::StringRef LowerStrategy = GetHLLowerStrategy(F);
   llvm::StringRef LowerStrategy = GetHLLowerStrategy(F);
-  ExtensionLowering lower(LowerStrategy, helper, hlslOp);
+  HLObjectExtensionLowerHelper extObjHelper(objHelper);
+  ExtensionLowering lower(LowerStrategy, helper, hlslOp, extObjHelper);
 
 
   // Replace all calls that were successfully translated.
   // Replace all calls that were successfully translated.
   for (CallInst *CI : CallsToReplace) {
   for (CallInst *CI : CallsToReplace) {
@@ -7745,7 +7768,7 @@ void TranslateBuiltinOperations(
       continue;
       continue;
     }
     }
     if (group == HLOpcodeGroup::HLExtIntrinsic) {
     if (group == HLOpcodeGroup::HLExtIntrinsic) {
-      TranslateHLExtension(F, extCodegenHelper, helper.hlslOP);
+      TranslateHLExtension(F, extCodegenHelper, helper.hlslOP, objHelper);
       continue;
       continue;
     }
     }
     if (group == HLOpcodeGroup::HLIntrinsic) {
     if (group == HLOpcodeGroup::HLIntrinsic) {

+ 469 - 17
lib/HLSL/HLOperationLowerExtension.cpp

@@ -21,10 +21,26 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/raw_os_ostream.h"
 #include "llvm/Support/raw_os_ostream.h"
+#include "llvm/Support/YAMLParser.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/ADT/SmallString.h"
 
 
 using namespace llvm;
 using namespace llvm;
 using namespace hlsl;
 using namespace hlsl;
 
 
+LLVM_ATTRIBUTE_NORETURN static void ThrowExtensionError(StringRef Details)
+{
+    std::string Msg = (Twine("Error in dxc extension api: ") + Details).str();
+    throw hlsl::Exception(DXC_E_EXTENSION_ERROR, Msg);
+}
+
+// The lowering strategy format is a string that matches the following regex:
+//
+//      [a-z](:(?P<ExtraStrategyInfo>.+))?$
+//
+// The first character indicates the strategy with an optional : followed by
+// additional lowering information specific to that strategy.
+//
 ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
 ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
   if (strategy.size() < 1)
   if (strategy.size() < 1)
     return Strategy::Unknown;
     return Strategy::Unknown;
@@ -52,14 +68,22 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
   return "?";
   return "?";
 }
 }
 
 
-ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp)
-  : m_strategy(strategy), m_helper(helper), m_hlslOp(hlslOp)
-  {}
+static std::string ParseExtraStrategyInfo(StringRef strategy)
+{
+    std::pair<StringRef, StringRef> SplitInfo = strategy.split(":");
+    return SplitInfo.second;
+}
 
 
-ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp)
-  : ExtensionLowering(GetStrategy(strategy), helper, hlslOp)
+ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp,  HLResourceLookup &hlResourceLookup)
+  : m_strategy(strategy), m_helper(helper), m_hlslOp(hlslOp), m_hlResourceLookup(hlResourceLookup)
   {}
   {}
 
 
+ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &hlResourceLookup)
+  : ExtensionLowering(GetStrategy(strategy), helper, hlslOp, hlResourceLookup)
+  {
+    m_extraStrategyInfo = ParseExtraStrategyInfo(strategy);
+  }
+
 llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
 llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
   switch (m_strategy) {
   switch (m_strategy) {
   case Strategy::NoTranslation: return NoTranslation(CI);
   case Strategy::NoTranslation: return NoTranslation(CI);
@@ -110,7 +134,9 @@ public:
     return translator.GetLoweredFunction(CI);
     return translator.GetLoweredFunction(CI);
   }
   }
 
 
-private:
+  virtual ~FunctionTranslator() {}
+
+protected:
   FunctionTypeTranslator &m_typeTranslator;
   FunctionTypeTranslator &m_typeTranslator;
   ExtensionLowering &m_lower;
   ExtensionLowering &m_lower;
 
 
@@ -136,7 +162,7 @@ private:
     return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
     return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
   }
   }
 
 
-  FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
+  virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
     // Create a new function type with the translated argument.
     // Create a new function type with the translated argument.
     SmallVector<Type *, 10> ParamTypes;
     SmallVector<Type *, 10> ParamTypes;
     ParamTypes.reserve(CI->getNumArgOperands());
     ParamTypes.reserve(CI->getNumArgOperands());
@@ -476,23 +502,23 @@ Value *ExtensionLowering::Pack(CallInst *CI) {
 //  %v.2 = insertelement %v.1, %y, 1
 //  %v.2 = insertelement %v.1, %y, 1
 class ResourceMethodCall {
 class ResourceMethodCall {
 public:
 public:
-  ResourceMethodCall(CallInst *CI, Function &explodedFunction)
+  ResourceMethodCall(CallInst *CI)
     : m_CI(CI)
     : m_CI(CI)
-    , m_explodedFunction(explodedFunction)
     , m_builder(CI)
     , m_builder(CI)
   { }
   { }
 
 
-  Value *Generate() {
+  virtual ~ResourceMethodCall() {}
+
+  virtual Value *Generate(Function *explodedFunction) {
     SmallVector<Value *, 16> args;
     SmallVector<Value *, 16> args;
     ExplodeArgs(args);
     ExplodeArgs(args);
-    Value *result = CreateCall(args);
+    Value *result = CreateCall(explodedFunction, args);
     result = ConvertResult(result);
     result = ConvertResult(result);
     return result;
     return result;
   }
   }
   
   
-private:
+protected:
   CallInst *m_CI;
   CallInst *m_CI;
-  Function &m_explodedFunction;
   IRBuilder<> m_builder;
   IRBuilder<> m_builder;
 
 
   void ExplodeArgs(SmallVectorImpl<Value*> &args) {
   void ExplodeArgs(SmallVectorImpl<Value*> &args) {
@@ -511,8 +537,8 @@ private:
     }
     }
   }
   }
 
 
-  Value *CreateCall(const SmallVectorImpl<Value*> &args) {
-    return m_builder.CreateCall(&m_explodedFunction, args);
+  Value *CreateCall(Function *explodedFunction, ArrayRef<Value*> args) {
+    return m_builder.CreateCall(explodedFunction, args);
   }
   }
 
 
   Value *ConvertResult(Value *result) {
   Value *ConvertResult(Value *result) {
@@ -601,16 +627,442 @@ private:
 };
 };
 
 
 Value *ExtensionLowering::Resource(CallInst *CI) {
 Value *ExtensionLowering::Resource(CallInst *CI) {
+  // Extra strategy info overrides the default lowering for resource methods.
+  if (!m_extraStrategyInfo.empty())
+  {
+    return CustomResource(CI);
+  }
+
   ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
   ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
   Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
   Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
   if (!resourceFunction)
   if (!resourceFunction)
     return NoTranslation(CI);
     return NoTranslation(CI);
 
 
-  ResourceMethodCall explode(CI, *resourceFunction);
-  Value *result = explode.Generate();
+  ResourceMethodCall explode(CI);
+  Value *result = explode.Generate(resourceFunction);
   return result;
   return result;
 }
 }
 
 
+// This class handles the core logic for custom lowering of resource
+// method intrinsics. The goal is to allow resource extension intrinsics
+// to be handled the same way as the core hlsl resource intrinsics.
+//
+// Specifically, we want to support:
+//
+//  1. Multiple hlsl overloads map to a single dxil intrinsic
+//  2. The hlsl overloads can take different parameters for a given resource type
+//  3. The hlsl overloads are not consistent across different resource types 
+//
+// To achieve these goals we need a more complex mechanism for describing how
+// to translate the high-level arguments to arguments for a dxil function.
+// The custom lowering info describes this lowering using the following format.
+//
+// [Custom Lowering Info Format]
+// A json string encoding a map where each key is either a specific resource type or
+// the keyword "default" to be used for any other resource. The value is a
+// a custom-format string encoding how high-level arguments are mapped to
+// dxil intrinsic arguments.
+//
+// [Argument Translation Format]
+// A comma separated string where the number of fields is exactly equal to the number
+// of parameters in the target dxil intrinsic. Each field describes how to generate
+// the argument for that dxil intrinsic parameter. It has the following format where
+// the hl_arg_index is mandatory, but the other two parts are optional.
+//
+//      <hl_arg_index>.<vector_index>:<optional_type_info>
+//
+// The format is precisely described by the following regular expression:
+//
+//      (?P<hl_arg_index>[-0-9]+)(.(?P<vector_index>[-0-9]+))?(:(?P<optional_type_info>\?i32|\?i16|\?i8|\?float|\?half))?$
+//
+// Example
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Say we want to define the MyTextureOp extension with the following overloads:
+//
+// Texture1D
+//  MyTextureOp(uint addr, uint offset)
+//  MyTextureOp(uint addr, uint offset, uint val)
+//
+// Texture2D
+//  MyTextureOp(uint2 addr, uint2 val)
+//  
+// And a dxil intrinsic defined as follows
+//  @MyTextureOp(i32 opcode,  %dx.types.Handle handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1)
+//
+// Then we would define the lowering info json as follows
+//
+//  {
+//      "default"   : "0, 1, 2.0, 2.1,  3     , 4.0:?i32, 4.1:?i32"
+//      "Texture2D" : "0, 1, 2.0, 2.1, -1:?i32, 3.0     , 3.1\"
+//  }
+//
+//
+//  This would produce the following lowerings (assuming the MyTextureOp opcode is 17)
+//
+//  hlsl: Texture1D.MyTextureOp(a, b)
+//  hl:   @MyTextureOp(17, handle, a, b)
+//  dxil: @MyTextureOp(17, handle, a, undef, b, undef, undef)
+//
+//  hlsl: Texture1D.MyTextureOp(a, b, c)
+//  hl:   @MyTextureOp(17, handle, a, b, c)
+//  dxil: @MyTextureOp(17, handle, a, undef, b, c, undef)
+//
+//  hlsl: Texture2D.MyTextureOp(a, c)
+//  hl:   @MyTextureOp(17, handle, a, c)
+//  dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
+//
+// 
+class CustomResourceLowering
+{
+public:
+    CustomResourceLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
+    {
+        // Parse lowering info json format.
+        std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
+            ParseLoweringInfo(LoweringInfo, CI->getContext());
+
+        // Lookup resource kind based on handle (first arg after hl opcode)
+        enum {RESOURCE_HANDLE_ARG=1};
+        const char *pName = nullptr;
+        if (!ResourceLookup.GetResourceKindName(CI->getArgOperand(RESOURCE_HANDLE_ARG), &pName))
+        {
+            ThrowExtensionError("Failed to find resource from handle");
+        }
+        std::string Name(pName);
+
+        // Select lowering info to use based on resource kind.
+        const char *DefaultInfoName = "default";
+        std::vector<DxilArgInfo> *pArgInfo = nullptr;
+        if (LoweringInfoMap.count(Name))
+        {
+            pArgInfo = &LoweringInfoMap.at(Name);
+        }
+        else if (LoweringInfoMap.count(DefaultInfoName))
+        {
+            pArgInfo = &LoweringInfoMap.at(DefaultInfoName);
+        }
+        else
+        {
+            ThrowExtensionError("Unable to find lowering info for resource");
+        }
+        GenerateLoweredArgs(CI, *pArgInfo);
+    }
+
+    const std::vector<Value *> &GetLoweredArgs() const
+    {
+        return m_LoweredArgs;
+    }
+
+private:
+    struct OptionalTypeSpec
+    {
+        const char* TypeName;
+        Type *LLVMType;
+    };
+
+    // These are the supported optional types for generating dxil parameters
+    // that have no matching argument in the high-level intrinsic overload.
+    // See [Argument Translation Format] for details.
+    void InitOptionalTypes(LLVMContext &Ctx)
+    {
+        // Table of supported optional types.
+        // Keep in sync with m_OptionalTypes small vector size to avoid
+        // dynamic allocation.
+        OptionalTypeSpec OptionalTypes[] = {
+            {"?i32",   Type::getInt32Ty(Ctx)},
+            {"?float", Type::getFloatTy(Ctx)},
+            {"?half",  Type::getHalfTy(Ctx)},
+            {"?i8",    Type::getInt8Ty(Ctx)},
+            {"?i16",   Type::getInt16Ty(Ctx)},
+        };
+        DXASSERT(m_OptionalTypes.empty(), "Init should only be called once");
+        m_OptionalTypes.clear();
+        m_OptionalTypes.reserve(_countof(OptionalTypes));
+
+        for (const OptionalTypeSpec &T : OptionalTypes)
+        {
+            m_OptionalTypes.push_back(T);
+        }
+    }
+
+    Type *ParseOptionalType(StringRef OptionalTypeInfo)
+    {
+        if (OptionalTypeInfo.empty())
+        {
+            return nullptr;
+        }
+
+        for (OptionalTypeSpec &O : m_OptionalTypes)
+        {
+            if (OptionalTypeInfo == O.TypeName)
+            {
+                return O.LLVMType;
+            }
+        }
+            
+        ThrowExtensionError("Failed to parse optional type");
+    }
+    
+    // Mapping from high level function arg to dxil function arg.
+    //
+    // The `HighLevelArgIndex` is the index of the function argument to
+    // which this dxil argument maps.
+    //
+    // If `HasVectorIndex` is true then the `VectorIndex` contains the
+    // index of the element in the vector pointed to by HighLevelArgIndex.
+    //
+    // The `OptionalType` is used to specify types for arguments that are not
+    // present in all overloads of the high level function. This lets us
+    // map multiple high level functions to a single dxil extension intrinsic.
+    //
+    struct DxilArgInfo
+    {
+        unsigned HighLevelArgIndex = 0;
+        unsigned VectorIndex = 0;
+        bool HasVectorIndex = false;
+        Type *OptionalType = nullptr;
+    };
+    typedef std::string ResourceKindName;
+
+    // Convert the lowering info to a machine-friendly format.
+    // Note that we use the YAML parser to parse the JSON since JSON
+    // is a subset of YAML (and this llvm has no JSON parser).
+    //
+    // See [Custom Lowering Info Format] for details.
+    std::map<ResourceKindName, std::vector<DxilArgInfo>> ParseLoweringInfo(StringRef LoweringInfo, LLVMContext &Ctx)
+    {
+        InitOptionalTypes(Ctx);
+        std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap;
+
+        SourceMgr SM;
+        yaml::Stream YAMLStream(LoweringInfo, SM);
+
+        // Make sure we have a valid json input.
+        llvm::yaml::document_iterator I = YAMLStream.begin();
+        if (I == YAMLStream.end()) {
+            ThrowExtensionError("Found empty resource lowering JSON.");
+        }
+        llvm::yaml::Node *Root = I->getRoot();
+        if (!Root) {
+            ThrowExtensionError("Error parsing resource lowering JSON.");
+        }
+
+        // Parse the top level map object.
+        llvm::yaml::MappingNode *Object = dyn_cast<llvm::yaml::MappingNode>(Root);
+        if (!Object) {
+            ThrowExtensionError("Expected map in top level of resource lowering JSON.");
+        }
+
+        // Parse all key/value pairs from the map.
+        for (llvm::yaml::MappingNode::iterator KVI = Object->begin(),
+            KVE = Object->end();
+            KVI != KVE; ++KVI) 
+        {
+            // Parse key.
+            llvm::yaml::ScalarNode *KeyString =
+                dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getKey());
+            if (!KeyString) {
+                ThrowExtensionError("Expected string as key in resource lowering info JSON map.");
+            }
+            SmallString<32> KeyStorage;
+            StringRef Key = KeyString->getValue(KeyStorage);
+
+            // Parse value.
+            llvm::yaml::ScalarNode *ValueString =
+                dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getValue());
+            if (!ValueString) {
+                ThrowExtensionError("Expected string as value in resource lowering info JSON map.");
+            }
+            SmallString<128> ValueStorage;
+            StringRef Value = ValueString->getValue(ValueStorage);
+
+            // Parse dxil arg info from value.
+            LoweringInfoMap[Key] = ParseDxilArgInfo(Value, Ctx);
+        }
+
+        return LoweringInfoMap;
+    }
+
+
+    // Parse the dxail argument translation info.
+    // See [Argument Translation Format] for details.
+    std::vector<DxilArgInfo> ParseDxilArgInfo(StringRef ArgSpec, LLVMContext &Ctx)
+    {
+        std::vector<DxilArgInfo> Args;
+
+        SmallVector<StringRef, 14> Splits;
+        ArgSpec.split(Splits, ",");
+
+        for (const StringRef Split : Splits)
+        {
+            StringRef Field = Split.trim();
+            StringRef HighLevelArgInfo;
+            StringRef OptionalTypeInfo;
+            std::tie(HighLevelArgInfo, OptionalTypeInfo) = Field.split(":");
+
+            Type *OptionalType = ParseOptionalType(OptionalTypeInfo);
+
+            StringRef HighLevelArgIndex;
+            StringRef VectorIndex;
+            std::tie(HighLevelArgIndex, VectorIndex) = HighLevelArgInfo.split(".");
+
+            // Parse the arg and vector index.
+            // Parse the values as signed integers, but store them as unsigned values to
+            // allows using -1 as a shorthand for the max value.
+            DxilArgInfo ArgInfo;
+            ArgInfo.HighLevelArgIndex = static_cast<unsigned>(std::stoi(HighLevelArgIndex));
+            if (!VectorIndex.empty())
+            {
+                ArgInfo.HasVectorIndex = true;
+                ArgInfo.VectorIndex = static_cast<unsigned>(std::stoi(VectorIndex));
+            }
+            ArgInfo.OptionalType = OptionalType;
+
+            Args.push_back(ArgInfo);
+        }
+
+        return Args;
+    }
+
+    // Create the dxil args based on custom lowering info.
+    void GenerateLoweredArgs(CallInst *CI, const std::vector<DxilArgInfo> &ArgInfoRecords)
+    {
+        IRBuilder<> builder(CI);
+        for (const DxilArgInfo &ArgInfo : ArgInfoRecords)
+        {
+            // Check to see if we have the corresponding high-level arg in the overload for this call.
+            if (ArgInfo.HighLevelArgIndex < CI->getNumArgOperands())
+            {
+                Value *Arg = CI->getArgOperand(ArgInfo.HighLevelArgIndex);
+                if (ArgInfo.HasVectorIndex)
+                {
+                    // We expect a vector type here, but we handle one special case if not.
+                    if (Arg->getType()->isVectorTy())
+                    {
+                        // We allow multiple high-level overloads to map to a single dxil extension function.
+                        // If the vector index is invalid for this specific overload then use an undef
+                        // value as a replacement.
+                        if (ArgInfo.VectorIndex < Arg->getType()->getVectorNumElements())
+                        {
+                            Arg = builder.CreateExtractElement(Arg, ArgInfo.VectorIndex);
+                        }
+                        else
+                        {
+                            Arg = UndefValue::get(Arg->getType()->getVectorElementType());
+                        }
+                    }
+                    else
+                    {
+                        // If it is a non-vector type then we replace non-zero vector index with
+                        // undef. This is to handle hlsl intrinsic overloading rules that allow
+                        // scalars in place of single-element vectors. We assume here that a non-vector
+                        // means that a single element vector was already scalarized.
+                        // 
+                        if (ArgInfo.VectorIndex > 0)
+                        {
+                            Arg = UndefValue::get(Arg->getType());
+                        }
+                    }
+                }
+
+                m_LoweredArgs.push_back(Arg);
+            }
+            else if (ArgInfo.OptionalType)
+            {
+                // If there was no matching high-level arg then we look for the optional
+                // arg type specified by the lowering info.
+                m_LoweredArgs.push_back(UndefValue::get(ArgInfo.OptionalType));
+            }
+            else
+            { 
+                // No way to know how to generate the correc type for this dxil arg.
+                ThrowExtensionError("Unable to map high-level arg to dxil arg");
+            }
+        }
+    }
+    
+    std::vector<Value *> m_LoweredArgs;
+    SmallVector<OptionalTypeSpec, 5> m_OptionalTypes;
+};
+
+// Boilerplate to reuse exising logic as much as possible.
+// We just want to overload GetFunctionType here.
+class CustomResourceFunctionTranslator : public FunctionTranslator {
+public:
+  static Function *GetLoweredFunction(
+        const CustomResourceLowering &CustomLowering,
+        ResourceFunctionTypeTranslator &typeTranslator,
+        CallInst *CI,
+        ExtensionLowering &lower
+    )
+  {
+      CustomResourceFunctionTranslator T(CustomLowering, typeTranslator, lower);
+      return T.FunctionTranslator::GetLoweredFunction(CI);
+  }
+
+private:
+    CustomResourceFunctionTranslator(
+        const CustomResourceLowering &CustomLowering,
+        ResourceFunctionTypeTranslator &typeTranslator,
+        ExtensionLowering &lower
+    )
+        : FunctionTranslator(typeTranslator, lower)
+        , m_CustomLowering(CustomLowering)
+    {
+    }
+
+    virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) override {
+        SmallVector<Type *, 16> ParamTypes;
+        for (Value *V : m_CustomLowering.GetLoweredArgs())
+        {
+            ParamTypes.push_back(V->getType());
+        }
+        const bool IsVarArg = false;
+        return FunctionType::get(RetTy, ParamTypes, IsVarArg);
+    }
+
+private:
+    const CustomResourceLowering &m_CustomLowering;
+};
+
+// Boilerplate to reuse exising logic as much as possible.
+// We just want to overload Generate here.
+class CustomResourceMethodCall : public ResourceMethodCall
+{
+public:
+    CustomResourceMethodCall(CallInst *CI, const CustomResourceLowering &CustomLowering)
+        : ResourceMethodCall(CI)
+        , m_CustomLowering(CustomLowering)
+    {}
+
+    virtual Value *Generate(Function *loweredFunction) override {
+        Value *result = CreateCall(loweredFunction, m_CustomLowering.GetLoweredArgs());
+        result = ConvertResult(result);
+        return result;
+    }
+
+private:
+    const CustomResourceLowering &m_CustomLowering;
+};
+
+// Support custom lowering logic for resource functions.
+Value *ExtensionLowering::CustomResource(CallInst *CI) {
+    CustomResourceLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
+    ResourceFunctionTypeTranslator ResourceTypeTranslator(m_hlslOp);
+    Function *ResourceFunction = CustomResourceFunctionTranslator::GetLoweredFunction(
+        CustomLowering,
+        ResourceTypeTranslator,
+        CI,
+        *this
+    );
+    if (!ResourceFunction)
+        return NoTranslation(CI);
+
+    CustomResourceMethodCall custom(CI, CustomLowering);
+    Value *Result = custom.Generate(ResourceFunction);
+    return Result;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 ///////////////////////////////////////////////////////////////////////////////
 // Dxil Lowering.
 // Dxil Lowering.
 
 

+ 139 - 0
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -123,6 +123,28 @@ static const HLSL_INTRINSIC_ARGUMENT WaveProcArgs[] = {
   { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
   { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
 };
 };
 
 
+// uint = Texutre1D.MyTextureOp(uint addr, uint offset)
+static const HLSL_INTRINSIC_ARGUMENT TestMyTexture1DOp_0[] = {
+  { "MyTextureOp", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_UINT, 1, 1 },
+  { "addr", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+  { "offset", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+};
+
+// uint = Texutre1D.MyTextureOp(uint addr, uint offset, uint val)
+static const HLSL_INTRINSIC_ARGUMENT TestMyTexture1DOp_1[] = {
+  { "MyTextureOp", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_UINT, 1, 1 },
+  { "addr", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+  { "offset", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+  { "val", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+};
+
+// uint2 = Texture2D.MyTextureOp(uint2 addr, uint2 val)
+static const HLSL_INTRINSIC_ARGUMENT TestMyTexture2DOp[] = {
+  { "MyTextureOp", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_UINT, 1, 1 },
+  { "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
+  { "val", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
+};
+
 struct Intrinsic {
 struct Intrinsic {
   LPCWSTR hlslName;
   LPCWSTR hlslName;
   const char *dxilName;
   const char *dxilName;
@@ -164,6 +186,39 @@ Intrinsic SamplerIntrinsics[] = {
   {L"MySamplerOp",   "MySamplerOp",    "m", { 15, false, true, false, -1, countof(TestMySamplerOp), TestMySamplerOp}},
   {L"MySamplerOp",   "MySamplerOp",    "m", { 15, false, true, false, -1, countof(TestMySamplerOp), TestMySamplerOp}},
 };
 };
 
 
+// Define a lowering string to target a common dxil extension operation defined like this:
+//
+// @MyTextureOp(i32 opcode, %dx.types.Handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1);
+//
+//  This would produce the following lowerings (assuming the MyTextureOp opcode is 17)
+//
+//  hlsl: Texture1D.MyTextureOp(a, b)
+//  hl:   @MyTextureOp(17, handle, a, b)
+//  dxil: @MyTextureOp(17, handle, a, undef, b, undef, undef)
+//
+//  hlsl: Texture1D.MyTextureOp(a, b, c)
+//  hl:   @MyTextureOp(17, handle, a, b, c)
+//  dxil: @MyTextureOp(17, handle, a, undef, b, c, undef)
+//
+//  hlsl: Texture2D.MyTextureOp(a, c)
+//  hl:   @MyTextureOp(17, handle, a, c)
+//  dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
+//
+static const char *MyTextureOp_LoweringInfo = 
+    "m:{"
+        "\"default\"   : \"0,1,2.0,2.1,3,4.0:?i32,4.1:?i32\","
+        "\"Texture2D\" : \"0,1,2.0,2.1,-1:?i32,3.0,3.1\""
+    "}";
+Intrinsic Texture1DIntrinsics[] = {
+  {L"MyTextureOp",   "MyTextureOp", MyTextureOp_LoweringInfo, { 17, false, true, false, -1, countof(TestMyTexture1DOp_0), TestMyTexture1DOp_0}},
+  {L"MyTextureOp",   "MyTextureOp", MyTextureOp_LoweringInfo, { 17, false, true, false, -1, countof(TestMyTexture1DOp_1), TestMyTexture1DOp_1}},
+};
+
+Intrinsic Texture2DIntrinsics[] = {
+  {L"MyTextureOp",   "MyTextureOp", MyTextureOp_LoweringInfo, { 17, false, true, false, -1, countof(TestMyTexture2DOp), TestMyTexture2DOp}},
+};
+
+
 class IntrinsicTable {
 class IntrinsicTable {
 public:
 public:
   IntrinsicTable(const wchar_t *ns, Intrinsic *begin, Intrinsic *end)
   IntrinsicTable(const wchar_t *ns, Intrinsic *begin, Intrinsic *end)
@@ -233,6 +288,8 @@ public:
     m_tables.push_back(IntrinsicTable(L"",       std::begin(Intrinsics), std::end(Intrinsics)));
     m_tables.push_back(IntrinsicTable(L"",       std::begin(Intrinsics), std::end(Intrinsics)));
     m_tables.push_back(IntrinsicTable(L"Buffer", std::begin(BufferIntrinsics), std::end(BufferIntrinsics)));
     m_tables.push_back(IntrinsicTable(L"Buffer", std::begin(BufferIntrinsics), std::end(BufferIntrinsics)));
     m_tables.push_back(IntrinsicTable(L"SamplerState", std::begin(SamplerIntrinsics), std::end(SamplerIntrinsics)));
     m_tables.push_back(IntrinsicTable(L"SamplerState", std::begin(SamplerIntrinsics), std::end(SamplerIntrinsics)));
+    m_tables.push_back(IntrinsicTable(L"Texture1D", std::begin(Texture1DIntrinsics), std::end(Texture1DIntrinsics)));
+    m_tables.push_back(IntrinsicTable(L"Texture2D", std::begin(Texture2DIntrinsics), std::end(Texture2DIntrinsics)));
   }
   }
   DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
   DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override {
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override {
@@ -470,6 +527,9 @@ public:
   TEST_METHOD(DxilLoweringScalar)
   TEST_METHOD(DxilLoweringScalar)
   TEST_METHOD(SamplerExtensionIntrinsic)
   TEST_METHOD(SamplerExtensionIntrinsic)
   TEST_METHOD(WaveIntrinsic)
   TEST_METHOD(WaveIntrinsic)
+  TEST_METHOD(ResourceExtensionIntrinsicCustomLowering1)
+  TEST_METHOD(ResourceExtensionIntrinsicCustomLowering2)
+  TEST_METHOD(ResourceExtensionIntrinsicCustomLowering3)
 };
 };
 
 
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -1043,3 +1103,82 @@ TEST_F(ExtensionTest, WaveIntrinsic) {
     disassembly.find("br i1 %2"));
     disassembly.find("br i1 %2"));
 }
 }
 
 
+TEST_F(ExtensionTest, ResourceExtensionIntrinsicCustomLowering1) {
+  // Test adding methods to objects that don't have any methods normally,
+  // and therefore have null default intrinsic table.
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  auto result = c.Compile(
+    "Texture1D tex1;"
+    "float2 main() : SV_Target {\n"
+    "  return tex1.MyTextureOp(1,2,3);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  CheckOperationResultMsgs(result, {}, true, false);
+  std::string disassembly = c.Disassemble();
+
+  // Things to check
+  // @MyTextureOp(i32 opcode, %dx.types.Handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1);
+  //
+  // hlsl: Texture1D.MyTextureOp(a, b, c)
+  // dxil: @MyTextureOp(17, handle, a, undef, b, c, undef)
+  //
+  LPCSTR expected[] = {
+    "call %dx.types.ResRet.i32 @MyTextureOp\\(i32 17, %dx.types.Handle %.*, i32 1, i32 undef, i32 2, i32 3, i32 undef\\)",
+  };
+  CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
+}
+
+TEST_F(ExtensionTest, ResourceExtensionIntrinsicCustomLowering2) {
+  // Test adding methods to objects that don't have any methods normally,
+  // and therefore have null default intrinsic table.
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  auto result = c.Compile(
+    "Texture2D tex2;"
+    "float2 main() : SV_Target {\n"
+    "  return tex2.MyTextureOp(uint2(4,5), uint2(6,7));\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  CheckOperationResultMsgs(result, {}, true, false);
+  std::string disassembly = c.Disassemble();
+
+  // Things to check
+  // @MyTextureOp(i32 opcode, %dx.types.Handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1);
+  //
+  // hlsl: Texture2D.MyTextureOp(a, c)
+  // dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
+  LPCSTR expected[] = {
+    "call %dx.types.ResRet.i32 @MyTextureOp\\(i32 17, %dx.types.Handle %.*, i32 4, i32 5, i32 undef, i32 6, i32 7\\)",
+  };
+  CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
+}
+
+TEST_F(ExtensionTest, ResourceExtensionIntrinsicCustomLowering3) {
+  // Test adding methods to objects that don't have any methods normally,
+  // and therefore have null default intrinsic table.
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  auto result = c.Compile(
+    "Texture1D tex1;"
+    "float2 main() : SV_Target {\n"
+    "  return tex1.MyTextureOp(1,2);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  CheckOperationResultMsgs(result, {}, true, false);
+  std::string disassembly = c.Disassemble();
+
+  // Things to check
+  // @MyTextureOp(i32 opcode, %dx.types.Handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1);
+  //
+  // hlsl: Texture1D.MyTextureOp(a, b)
+  // dxil: @MyTextureOp(17, handle, a, undef, b, undef, undef)
+  //
+  LPCSTR expected[] = {
+    "call %dx.types.ResRet.i32 @MyTextureOp\\(i32 17, %dx.types.Handle %.*, i32 1, i32 undef, i32 2, i32 undef, i32 undef\\)",
+  };
+  CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
+}