Jelajahi Sumber

Support extension intrinsics on resources (#24)

This commit adds support for extension intrinsics that work as
methods on resources. For example, we could have an extension
on buffers called `MyBufferOp`

    Buffer<float2> buf;
    float2 val = buf.MyBufferOp(int2(1, 2))

To support extension methods we add a new resource lowering strategy
that does three transformations to the intrinsic call

1. Expand vectors in place in the call arguments.
2. Convert non-void return value to dx.types.ResRet.
3. Convert resource parameter to dx.types.Handle value.

For example, assuming that MyBufferOp has opcode 138. The resource
lowering strategy would convert the call as HL-dxil to dxil as
follows

    call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
    ==>
    call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
David Peixotto 8 tahun lalu
induk
melakukan
0f3789fb0d

+ 12 - 2
include/dxc/HLSL/HLOperationLowerExtension.h

@@ -14,15 +14,19 @@
 #include "dxc/HLSL/HLSLExtensionsCodegenHelper.h"
 #include "llvm/ADT/StringRef.h"
 #include <string>
+#include <unordered_map>
 
 namespace llvm {
   class Value;
   class CallInst;
   class Function;
   class StringRef;
+  class Instruction;
 }
 
 namespace hlsl {
+  class OP;
+
   // Lowers HLSL extensions from HL operation to DXIL operation.
   class ExtensionLowering {
   public:
@@ -32,11 +36,14 @@ namespace hlsl {
       NoTranslation,  // Propagate the call arguments as is down to dxil.
       Replicate,      // Scalarize the vector arguments and replicate the call.
       Pack,           // Convert the vector arguments into structs.
+      Resource,       // Convert return value to resource return and explode vectors.
     };
 
+    typedef std::unordered_map<llvm::Instruction *, llvm::Value *> HandleMap;
+
     // Create the lowering using the given strategy and custom codegen helper.
-    ExtensionLowering(llvm::StringRef strategy, HLSLExtensionsCodegenHelper *helper);
-    ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper);
+    ExtensionLowering(llvm::StringRef strategy, HLSLExtensionsCodegenHelper *helper, const HandleMap &handleMap, OP& hlslOp);
+    ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, const HandleMap &handleMap, OP& hlslOp);
 
     // Translate the HL op call to a DXIL op call.
     // Returns a new value if translation was successful.
@@ -62,11 +69,14 @@ namespace hlsl {
   private:
     Strategy m_strategy;
     HLSLExtensionsCodegenHelper *m_helper;
+    const HandleMap &m_handleMap;
+    OP &m_hlslOp;
 
     llvm::Value *Unknown(llvm::CallInst *CI);
     llvm::Value *NoTranslation(llvm::CallInst *CI);
     llvm::Value *Replicate(llvm::CallInst *CI);
     llvm::Value *Pack(llvm::CallInst *CI);
+    llvm::Value *Resource(llvm::CallInst *CI);
 
     // Translate the HL call by replicating the call for each vector element.
     //

+ 7 - 4
lib/HLSL/HLOperationLower.cpp

@@ -5723,7 +5723,11 @@ void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
   }
 }
 
-static void TranslateHLExtension(Function *F, HLSLExtensionsCodegenHelper *helper) {
+typedef std::unordered_map<llvm::Instruction *, llvm::Value *> HandleMap;
+static void TranslateHLExtension(Function *F,
+                                 HLSLExtensionsCodegenHelper *helper,
+                                 const HandleMap &handleMap,
+                                 OP& hlslOp) {
   // Find all calls to the function F.
   // Store the calls in a vector for now to be replaced the loop below.
   // We use a two step "find then replace" to avoid removing uses while
@@ -5737,7 +5741,7 @@ static void TranslateHLExtension(Function *F, HLSLExtensionsCodegenHelper *helpe
 
   // Get the lowering strategy to use for this intrinsic.
   llvm::StringRef LowerStrategy = GetHLLowerStrategy(F);
-  ExtensionLowering lower(LowerStrategy, helper);
+  ExtensionLowering lower(LowerStrategy, helper, handleMap, hlslOp);
 
   // Replace all calls that were successfully translated.
   for (CallInst *CI : CallsToReplace) {
@@ -5773,8 +5777,7 @@ void TranslateBuiltinOperations(
       continue;
     }
     if (group == HLOpcodeGroup::HLExtIntrinsic) {
-      // TODO: consider handling extensions to object methods
-      TranslateHLExtension(F, extCodegenHelper);
+      TranslateHLExtension(F, extCodegenHelper, handleMap, helper.hlslOP);
       continue;
     }
     TranslateHLBuiltinOperation(F, helper, group, &objHelper);

+ 218 - 15
lib/HLSL/HLOperationLowerExtension.cpp

@@ -34,6 +34,7 @@ ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
     case 'n': return Strategy::NoTranslation;
     case 'r': return Strategy::Replicate;
     case 'p': return Strategy::Pack;
+    case 'm': return Strategy::Resource;
     default: break;
   }
   return Strategy::Unknown;
@@ -44,17 +45,18 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
     case Strategy::NoTranslation: return "n";
     case Strategy::Replicate:     return "r";
     case Strategy::Pack:          return "p";
+    case Strategy::Resource:      return "m"; // m for resource method
     default: break;
   }
   return "?";
 }
 
-ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper) 
-  : m_strategy(strategy), m_helper(helper)
+ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, const HandleMap &handleMap, OP& hlslOp)
+  : m_strategy(strategy), m_helper(helper), m_handleMap(handleMap), m_hlslOp(hlslOp)
   {}
 
-ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper) 
-  : ExtensionLowering(GetStrategy(strategy), helper)
+ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, const HandleMap &handleMap, OP& hlslOp)
+  : ExtensionLowering(GetStrategy(strategy), helper, handleMap, hlslOp)
   {}
 
 llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
@@ -62,6 +64,7 @@ llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
   case Strategy::NoTranslation: return NoTranslation(CI);
   case Strategy::Replicate:     return Replicate(CI);
   case Strategy::Pack:          return Pack(CI);
+  case Strategy::Resource:      return Resource(CI);
   default: break;
   }
   return Unknown(CI);
@@ -75,8 +78,17 @@ llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
 // Interface to describe how to translate types from HL-dxil to dxil.
 class FunctionTypeTranslator {
 public:
+  // Arguments can be exploded into multiple copies of the same type.
+  // For example a <2 x i32> could become { i32, 2 } if the vector
+  // is expanded in place or { i32, 1 } if the call is replicated.
+  struct ArgumentType {
+    Type *type;
+    int  count;
+
+    ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
+  };
   virtual Type *TranslateReturnType(CallInst *CI) = 0;
-  virtual Type *TranslateArgumentType(Type *OrigArgType) = 0;
+  virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
 };
 
 // Class to create the new function with the translated types for low-level dxil.
@@ -85,6 +97,10 @@ public:
   template <typename TypeTranslator>
   static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
     TypeTranslator typeTranslator;
+    return GetLoweredFunction(typeTranslator, CI, lower);
+  }
+  
+  static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
     FunctionTranslator translator(typeTranslator, lower);
     return translator.GetLoweredFunction(CI);
   }
@@ -120,9 +136,11 @@ private:
     SmallVector<Type *, 10> ParamTypes;
     ParamTypes.reserve(CI->getNumArgOperands());
     for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
-      Type *OrigTy = CI->getArgOperand(i)->getType();
-      Type *TranslatedTy = m_typeTranslator.TranslateArgumentType(OrigTy);
-      ParamTypes.push_back(TranslatedTy);
+      Value *OrigArg = CI->getArgOperand(i);
+      FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
+      for (int i = 0; i < newArgType.count; ++i) {
+        ParamTypes.push_back(newArgType.type);
+      }
     }
 
     const bool IsVarArg = false;
@@ -151,8 +169,8 @@ class NoTranslationTypeTranslator : public FunctionTypeTranslator {
   virtual Type *TranslateReturnType(CallInst *CI) override {
     return CI->getType();
   }
-  virtual Type *TranslateArgumentType(Type *OrigArgType) override {
-    return OrigArgType;
+  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
+    return ArgumentType(OrigArg->getType());
   }
 };
 
@@ -212,13 +230,13 @@ class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
     return RetTy;
   }
 
-  virtual Type *TranslateArgumentType(Type *OrigArgType) override {
-    Type *Ty = OrigArgType;
+  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
+    Type *Ty = OrigArg->getType();
     if (Ty->isVectorTy()) {
       Ty = Ty->getVectorElementType();
     }
 
-    return Ty;
+    return ArgumentType(Ty);
   }
 
 };
@@ -404,8 +422,8 @@ class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
   virtual Type *TranslateReturnType(CallInst *CI) override {
     return TranslateIfVector(CI->getType());
   }
-  virtual Type *TranslateArgumentType(Type *OrigArgType) override {
-    return TranslateIfVector(OrigArgType);
+  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
+    return ArgumentType(TranslateIfVector(OrigArg->getType()));
   }
 
   Type *TranslateIfVector(Type *ty) {
@@ -425,6 +443,191 @@ Value *ExtensionLowering::Pack(CallInst *CI) {
   return result;
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Resource Lowering.
+
+// Modify a call to a resouce method. Makes the following transformation:
+//
+// 1. Convert non-void return value to dx.types.ResRet.
+// 2. Convert resource parameters to the corresponding dx.types.Handle value.
+// 3. Expand vectors in place as separate arguments.
+//
+// Example
+// -----------------------------------------------------------------------------
+//
+//  %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
+//  %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
+//  %x = extractvalue %r, 0
+//  %y = extractvalue %r, 1
+//  %v = <2 x float> undef
+//  %v.1 = insertelement %v,   %x, 0
+//  %v.2 = insertelement %v.1, %y, 1
+class ResourceMethodCall {
+public:
+  ResourceMethodCall(CallInst *CI, Function &explodedFunction, const ExtensionLowering::HandleMap &handleMap)
+    : m_CI(CI)
+    , m_explodedFunction(explodedFunction)
+    , m_handleMap(handleMap)
+    , m_builder(CI)
+  { }
+
+  Value *Generate() {
+    SmallVector<Value *, 16> args;
+    ExplodeArgs(args);
+    Value *result = CreateCall(args);
+    result = ConvertResult(result);
+    return result;
+  }
+  
+  // Check to see if the value is mapped to a handle in the handleMap.
+  static Instruction *IsResourceHandle(Value *OrigArg, const ExtensionLowering::HandleMap &handleMap) {
+    if (Instruction *Inst = dyn_cast<Instruction>(OrigArg)) {
+      if (handleMap.count(Inst))
+        return Inst;
+    }
+    return nullptr;
+  }
+  
+private:
+  CallInst *m_CI;
+  Function &m_explodedFunction;
+  const ExtensionLowering::HandleMap &m_handleMap;
+  IRBuilder<> m_builder;
+  
+  Value *GetResourceHandle(Value *OrigArg) {
+    if (Instruction *Inst = IsResourceHandle(OrigArg, m_handleMap))
+      return m_handleMap.at(Inst);
+    return nullptr;
+    
+  }
+
+  void ExplodeArgs(SmallVectorImpl<Value*> &args) {
+    for (Value *arg : m_CI->arg_operands()) {
+      // vector arg: <N x ty> -> ty, ty, ..., ty (N times)
+      if (arg->getType()->isVectorTy()) {
+        for (unsigned i = 0; i < arg->getType()->getVectorNumElements(); i++) {
+          Value *xarg = m_builder.CreateExtractElement(arg, i);
+          args.push_back(xarg);
+        }
+      }
+      // resource handle arg: handle -> dx.types.Handle
+      else if (Value *handle = GetResourceHandle(arg)) {
+        args.push_back(handle);
+      }
+      // any other value: arg -> arg
+      else {
+        args.push_back(arg);
+      }
+    }
+  }
+
+  Value *CreateCall(const SmallVectorImpl<Value*> &args) {
+    return m_builder.CreateCall(&m_explodedFunction, args);
+  }
+
+  Value *ConvertResult(Value *result) {
+    Type *origRetTy = m_CI->getType();
+    if (origRetTy->isVoidTy())
+      return ConvertVoidResult(result);
+    else if (origRetTy->isVectorTy())
+      return ConvertVectorResult(origRetTy, result);
+    else
+      return ConvertScalarResult(origRetTy, result);
+  }
+
+  // Void result does not need any conversion.
+  Value *ConvertVoidResult(Value *result) {
+    return result;
+  }
+
+  // Vector result will be populated with the elements from the resource return.
+  Value *ConvertVectorResult(Type *origRetTy, Value *result) {
+    Type *resourceRetTy = result->getType();
+    assert(origRetTy->isVectorTy());
+    assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");
+    
+    const unsigned vectorSize = origRetTy->getVectorNumElements();
+    const unsigned structSize = resourceRetTy->getStructNumElements();
+    const unsigned size = std::min(vectorSize, structSize);
+    assert(vectorSize < structSize);
+    
+    // Copy resource struct elements to vector.
+    Value *vector = UndefValue::get(origRetTy);
+    for (unsigned i = 0; i < size; ++i) {
+      Value *element = m_builder.CreateExtractValue(result, { i });
+      vector = m_builder.CreateInsertElement(vector, element, i);
+    }
+
+    return vector;
+  }
+
+  // Scalar result will be populated with the first element of the resource return.
+  Value *ConvertScalarResult(Type *origRetTy, Value *result) {
+    assert(origRetTy->isSingleValueType());
+    return m_builder.CreateExtractValue(result, { 0 });
+  }
+
+};
+
+// Translate function return and argument types for resource method lowering.
+class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
+public:
+  ResourceFunctionTypeTranslator(const ExtensionLowering::HandleMap &handleMap, OP& hlslOp)
+    : m_handleMap(handleMap)
+    , m_hlslOp(hlslOp)
+  { }
+
+  // Translate return type as follows:
+  //
+  // void     -> void
+  // <N x ty> -> dx.types.ResRet.ty
+  //  ty      -> dx.types.ResRet.ty
+  virtual Type *TranslateReturnType(CallInst *CI) override {
+    Type *RetTy = CI->getType();
+    if (RetTy->isVoidTy())
+      return RetTy;
+    else if (RetTy->isVectorTy())
+      RetTy = RetTy->getVectorElementType();
+
+    return m_hlslOp.GetResRetType(RetTy);
+  }
+  
+  // Translate argument type as follows:
+  //
+  // resource -> dx.types.Handle
+  // <N x ty> -> { ty, N }
+  //  ty      -> { ty, 1 }
+  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
+    int count = 1;
+    Type *ty = OrigArg->getType();
+
+    if (ty->isVectorTy()) {
+      count = ty->getVectorNumElements();
+      ty = ty->getVectorElementType();
+    }
+    else if (ResourceMethodCall::IsResourceHandle(OrigArg, m_handleMap)) {
+      ty = m_hlslOp.GetHandleType();
+    }
+
+    return ArgumentType(ty, count);
+  }
+
+private:
+  const ExtensionLowering::HandleMap &m_handleMap;
+  OP& m_hlslOp;
+};
+
+Value *ExtensionLowering::Resource(CallInst *CI) {
+  ResourceFunctionTypeTranslator resourceTypeTranslator(m_handleMap, m_hlslOp);
+  Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
+  if (!resourceFunction)
+    return nullptr;
+
+  ResourceMethodCall explode(CI, *resourceFunction, m_handleMap);
+  Value *result = explode.Generate();
+  return result;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Computing Extension Names.
 

+ 6 - 5
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -3073,15 +3073,15 @@ public:
       const HLSL_INTRINSIC *pPrior = nullptr;
       UINT64 lookupCookie = 0;
       CA2W wideTypeName(typeName);
-      table->LookupIntrinsic(wideTypeName, L"*", &pIntrinsic, &lookupCookie);
-      while (pIntrinsic != nullptr) {
+      HRESULT found = table->LookupIntrinsic(wideTypeName, L"*", &pIntrinsic, &lookupCookie);
+      while (pIntrinsic != nullptr && SUCCEEDED(found)) {
         if (!AreIntrinsicTemplatesEquivalent(pIntrinsic, pPrior)) {
           AddObjectIntrinsicTemplate(recordDecl, startDepth, pIntrinsic);
           // NOTE: this only works with the current implementation because
           // intrinsics are alive as long as the table is alive.
           pPrior = pIntrinsic;
         }
-        table->LookupIntrinsic(wideTypeName, L"*", &pIntrinsic, &lookupCookie);
+        found = table->LookupIntrinsic(wideTypeName, L"*", &pIntrinsic, &lookupCookie);
       }
     }
   }
@@ -3868,6 +3868,7 @@ public:
 
   FunctionDecl* AddHLSLIntrinsicMethod(
     LPCSTR tableName,
+    LPCSTR lowering,
     _In_ const HLSL_INTRINSIC* intrinsic,
     _In_ FunctionTemplateDecl *FunctionTemplate,
     ArrayRef<Expr *> Args,
@@ -3956,7 +3957,7 @@ public:
       SC_Extern, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
 
     // Add intrinsic attr
-    AddHLSLIntrinsicAttr(method, *m_context, tableName, "", intrinsic);
+    AddHLSLIntrinsicAttr(method, *m_context, tableName, lowering, intrinsic);
 
     // Record this function template specialization.
     TemplateArgumentList *argListCopy = TemplateArgumentList::CreateCopy(
@@ -7791,7 +7792,7 @@ Sema::TemplateDeductionResult HLSLExternalSource::DeduceTemplateArgumentsForHLSL
       continue;
     }
 
-    Specialization = AddHLSLIntrinsicMethod(cursor.GetTableName(), *cursor, FunctionTemplate, Args, argTypes, argCount);
+    Specialization = AddHLSLIntrinsicMethod(cursor.GetTableName(), cursor.GetLoweringStrategy(), *cursor, FunctionTemplate, Args, argTypes, argCount);
     DXASSERT_NOMSG(Specialization->getPrimaryTemplate()->getCanonicalDecl() ==
       FunctionTemplate->getCanonicalDecl());
 

+ 133 - 21
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -15,6 +15,7 @@
 #include "dxc/dxcapi.internal.h"
 #include "dxc/HLSL/HLOperationLowerExtension.h"
 #include "dxc/HlslIntrinsicOp.h"
+#include "llvm/Support/Regex.h"
 
 ///////////////////////////////////////////////////////////////////////////////
 // Support for test intrinsics.
@@ -90,6 +91,12 @@ static const HLSL_INTRINSIC_ARGUMENT TestUnsigned[] = {
   { "x", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 1},
 };
 
+// float2 = MyBufferOp(uint2 addr)
+static const HLSL_INTRINSIC_ARGUMENT TestMyBufferOp[] = {
+  { "MyBufferOp", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_FLOAT, 1, 2 },
+  { "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
+};
+
 struct Intrinsic {
   LPCWSTR hlslName;
   const char *dxilName;
@@ -119,11 +126,79 @@ Intrinsic Intrinsics[] = {
   {L"test_unsigned","test_unsigned",   "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, -1, countof(TestUnsigned), TestUnsigned}},
 };
 
+Intrinsic BufferIntrinsics[] = {
+  {L"MyBufferOp",   "MyBufferOp",      "m", { 12, false, true, -1, countof(TestMyBufferOp), TestMyBufferOp}},
+};
+
+class IntrinsicTable {
+public:
+  IntrinsicTable(wchar_t *ns, Intrinsic *begin, Intrinsic *end)
+    :  m_namespace(ns), m_begin(begin), m_end(end)
+  { }
+
+  struct SearchResult {
+    Intrinsic *intrinsic;
+    uint64_t index;
+
+    SearchResult() : SearchResult(nullptr, 0) {}
+    SearchResult(Intrinsic *i, uint64_t n) : intrinsic(i), index(n) {}
+    operator bool() { return intrinsic != nullptr; }
+  };
+
+  SearchResult Search(const wchar_t *name, std::ptrdiff_t startIndex) const {
+    Intrinsic *begin = m_begin + startIndex;
+    assert(std::distance(begin, m_end) >= 0);
+    if (IsStar(name))
+      return BuildResult(begin);
+
+    Intrinsic *found = std::find_if(begin, m_end, [name](const Intrinsic &i) {
+      return wcscmp(i.hlslName, name) == 0;
+    });
+
+    return BuildResult(found);
+  }
+  
+  SearchResult Search(unsigned opcode) const {
+    Intrinsic *begin = m_begin;
+    assert(std::distance(begin, m_end) >= 0);
+    
+    Intrinsic *found = std::find_if(begin, m_end, [opcode](const Intrinsic &i) {
+      return i.hlsl.Op == opcode;
+    });
+
+    return BuildResult(found);
+  }
+  
+  bool MatchesNamespace(const wchar_t *ns) const {
+    return wcscmp(m_namespace, ns) == 0;
+  }
+
+private:
+  const wchar_t *m_namespace;
+  Intrinsic *m_begin;
+  Intrinsic *m_end;
+
+  bool IsStar(const wchar_t *name) const {
+    return wcscmp(name, L"*") == 0;
+  }
+
+  SearchResult BuildResult(Intrinsic *found) const {
+    if (found == m_end)
+      return SearchResult{ nullptr, std::numeric_limits<uint64_t>::max() };
+
+    return SearchResult{ found, static_cast<uint64_t>(std::distance(m_begin, found)) };
+  }
+};
+
 class TestIntrinsicTable : public IDxcIntrinsicTable {
 private:
   DXC_MICROCOM_REF_FIELD(m_dwRef);
+  std::vector<IntrinsicTable> m_tables;
 public:
-  TestIntrinsicTable() : m_dwRef(0) { }
+  TestIntrinsicTable() : m_dwRef(0) { 
+    m_tables.push_back(IntrinsicTable(L"",       std::begin(Intrinsics), std::end(Intrinsics)));
+    m_tables.push_back(IntrinsicTable(L"Buffer", std::begin(BufferIntrinsics), std::end(BufferIntrinsics)));
+  }
   DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
   __override HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) {
     return DoBasicQueryInterface<IDxcIntrinsicTable>(this, iid, ppvObject);
@@ -138,47 +213,61 @@ public:
   __override HRESULT STDMETHODCALLTYPE LookupIntrinsic(
       LPCWSTR typeName, LPCWSTR functionName, const HLSL_INTRINSIC **pIntrinsic,
       _Inout_ UINT64 *pLookupCookie) {
-    if (typeName != nullptr && *typeName) return E_FAIL;
-    Intrinsic *intrinsic =
-      std::find_if(std::begin(Intrinsics), std::end(Intrinsics), [functionName](const Intrinsic &i) {
-        return wcscmp(i.hlslName, functionName) == 0;
-    });
-    if (intrinsic == std::end(Intrinsics))
+    if (typeName == nullptr)
       return E_FAIL;
 
-    *pIntrinsic = &intrinsic->hlsl;
-    *pLookupCookie = 0;
-    return S_OK;
+    // Search for matching intrinsic name in matching namespace.
+    IntrinsicTable::SearchResult result;
+    for (const IntrinsicTable &table : m_tables) {
+      if (table.MatchesNamespace(typeName)) {
+        result = table.Search(functionName, *pLookupCookie);
+        break;
+      }
+    }
+
+    if (result) {
+      *pIntrinsic = &result.intrinsic->hlsl;
+      *pLookupCookie = result.index + 1;
+    }
+    else {
+      *pIntrinsic = nullptr;
+      *pLookupCookie = 0;
+    }
+
+    return result.intrinsic ? S_OK : E_FAIL;
   }
 
   __override HRESULT STDMETHODCALLTYPE
   GetLoweringStrategy(UINT opcode, _Outptr_ LPCSTR *pStrategy) {
-    Intrinsic *intrinsic =
-      std::find_if(std::begin(Intrinsics), std::end(Intrinsics), [opcode](const Intrinsic &i) {
-      return i.hlsl.Op == opcode;
-    });
+    Intrinsic *intrinsic = FindByOpcode(opcode);
     
-    if (intrinsic == std::end(Intrinsics))
+    if (!intrinsic)
       return E_FAIL;
 
     *pStrategy = intrinsic->strategy;
-
     return S_OK;
   }
 
   __override HRESULT STDMETHODCALLTYPE
   GetIntrinsicName(UINT opcode, _Outptr_ LPCSTR *pName) {
-    Intrinsic *intrinsic =
-      std::find_if(std::begin(Intrinsics), std::end(Intrinsics), [opcode](const Intrinsic &i) {
-      return i.hlsl.Op == opcode;
-    });
+    Intrinsic *intrinsic = FindByOpcode(opcode);
 
-    if (intrinsic == std::end(Intrinsics))
+    if (!intrinsic)
       return E_FAIL;
 
     *pName = intrinsic->dxilName;
     return S_OK;
   }
+
+  Intrinsic *FindByOpcode(UINT opcode) {
+    IntrinsicTable::SearchResult result;
+    for (const IntrinsicTable &table : m_tables) {
+      result = table.Search(opcode);
+      if (result)
+        break;
+    }
+    return result.intrinsic;
+  }
 };
 
 // A class to test semantic define validation.
@@ -312,6 +401,7 @@ public:
   TEST_METHOD(PackedLowering);
   TEST_METHOD(ReplicateLoweringWhenOnlyVectorIsResult);
   TEST_METHOD(UnsignedOpcodeIsUnchanged);
+  TEST_METHOD(ResourceExtensionIntrinsic);
 };
 
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -591,3 +681,25 @@ TEST_F(ExtensionTest, UnsignedOpcodeIsUnchanged) {
     disassembly.npos !=
     disassembly.find("call i32 @test_unsigned(i32 113, "));
 }
+
+TEST_F(ExtensionTest, ResourceExtensionIntrinsic) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "Buffer<float2> buf;"
+    "float2 main(uint2 v1 : V1) : SV_Target {\n"
+    "  return buf.MyBufferOp(uint2(1, 2));\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // Things to check
+  // - return type is translated to dx.types.ResRet
+  // - buffer is translated to dx.types.Handle
+  // - vector is exploded
+  llvm::Regex regex("call %dx.types.ResRet.f32 @MyBufferOp\\(i32 12, %dx.types.Handle %.*, i32 1, i32 2\\)");
+  std::string regexErrors;
+  VERIFY_IS_TRUE(regex.isValid(regexErrors));
+  VERIFY_IS_TRUE(regex.match(disassembly));
+}