فهرست منبع

Merge pull request #1828 from tex3d/allow-extension-methods

Fix assert in DeduceTemplateArgumentsForHLSL for extension methods
Tex Riddell 6 سال پیش
والد
کامیت
0345684539
3فایلهای تغییر یافته به همراه81 افزوده شده و 47 حذف شده
  1. 27 34
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  2. 14 13
      tools/clang/lib/Sema/SemaHLSL.cpp
  3. 40 0
      tools/clang/unittests/HLSL/ExtensionTest.cpp

+ 27 - 34
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1107,11 +1107,27 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   // Add hlsl intrinsic attr
   // Add hlsl intrinsic attr
   unsigned intrinsicOpcode;
   unsigned intrinsicOpcode;
   StringRef intrinsicGroup;
   StringRef intrinsicGroup;
+  llvm::FunctionType *FT = F->getFunctionType();
+
+  auto AddResourceMetadata = [&](QualType qTy, llvm::Type *Ty) {
+    hlsl::DxilResourceBase::Class resClass = TypeToClass(qTy);
+    if (resClass != hlsl::DxilResourceBase::Class::Invalid) {
+      if (!resMetadataMap.count(Ty)) {
+        MDNode *Meta = GetOrAddResTypeMD(qTy);
+        DXASSERT(Meta, "else invalid resource type");
+        resMetadataMap[Ty] = Meta;
+      }
+    }
+  };
+
   if (hlsl::GetIntrinsicOp(FD, intrinsicOpcode, intrinsicGroup)) {
   if (hlsl::GetIntrinsicOp(FD, intrinsicOpcode, intrinsicGroup)) {
     AddHLSLIntrinsicOpcodeToFunction(F, intrinsicOpcode);
     AddHLSLIntrinsicOpcodeToFunction(F, intrinsicOpcode);
     F->addFnAttr(hlsl::HLPrefix, intrinsicGroup);
     F->addFnAttr(hlsl::HLPrefix, intrinsicGroup);
+    unsigned iParamOffset = 0; // skip this on llvm function
+
     // Save resource type annotation.
     // Save resource type annotation.
     if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD)) {
     if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD)) {
+      iParamOffset = 1;
       const CXXRecordDecl *RD = MD->getParent();
       const CXXRecordDecl *RD = MD->getParent();
       // For nested case like sample_slice_type.
       // For nested case like sample_slice_type.
       if (const CXXRecordDecl *PRD =
       if (const CXXRecordDecl *PRD =
@@ -1120,43 +1136,20 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       }
       }
 
 
       QualType recordTy = MD->getASTContext().getRecordType(RD);
       QualType recordTy = MD->getASTContext().getRecordType(RD);
-      hlsl::DxilResourceBase::Class resClass = TypeToClass(recordTy);
       llvm::Type *Ty = CGM.getTypes().ConvertType(recordTy);
       llvm::Type *Ty = CGM.getTypes().ConvertType(recordTy);
-      llvm::FunctionType *FT = F->getFunctionType();
-      // Save resource type metadata.
-      switch (resClass) {
-      case DXIL::ResourceClass::UAV: {
-        MDNode *MD = GetOrAddResTypeMD(recordTy);
-        DXASSERT(MD, "else invalid resource type");
-        resMetadataMap[Ty] = MD;
-      } break;
-      case DXIL::ResourceClass::SRV: {
-        MDNode *Meta = GetOrAddResTypeMD(recordTy);
-        DXASSERT(Meta, "else invalid resource type");
-        resMetadataMap[Ty] = Meta;
-        if (FT->getNumParams() > 1) {
-          QualType paramTy = MD->getParamDecl(0)->getType();
-          // Add sampler type.
-          if (TypeToClass(paramTy) == DXIL::ResourceClass::Sampler) {
-            llvm::Type *Ty = FT->getParamType(1)->getPointerElementType();
-            MDNode *MD = GetOrAddResTypeMD(paramTy);
-            DXASSERT(MD, "else invalid resource type");
-            resMetadataMap[Ty] = MD;
-          }
-        }
-      } break;
-      default:
-        // Skip OutputStream for GS.
-        break;
-      }
+      AddResourceMetadata(recordTy, Ty);
     }
     }
-    if (intrinsicOpcode == (unsigned)IntrinsicOp::IOP_TraceRay) {
-      QualType recordTy = FD->getParamDecl(0)->getType();
-      llvm::Type *Ty = CGM.getTypes().ConvertType(recordTy);
-      MDNode *MD = GetOrAddResTypeMD(recordTy);
-      DXASSERT(MD, "else invalid resource type");
-      resMetadataMap[Ty] = MD;
+
+    // Add metadata for any resources found in parameters
+    for (unsigned iParam = 0; iParam < FD->getNumParams(); iParam++) {
+      llvm::Type *Ty = FT->getParamType(iParam + iParamOffset);
+      if (!Ty->isPointerTy())
+        continue; // not a resource
+      Ty = Ty->getPointerElementType();
+      QualType paramTy = FD->getParamDecl(iParam)->getType();
+      AddResourceMetadata(paramTy, Ty);
     }
     }
+
     StringRef lower;
     StringRef lower;
     if (hlsl::GetIntrinsicLowering(FD, lower))
     if (hlsl::GetIntrinsicLowering(FD, lower))
       hlsl::SetHLLowerStrategy(F, lower);
       hlsl::SetHLLowerStrategy(F, lower);

+ 14 - 13
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -185,7 +185,7 @@ enum ArBasicKind {
   AR_OBJECT_WAVE,
   AR_OBJECT_WAVE,
 
 
   AR_OBJECT_RAY_DESC,
   AR_OBJECT_RAY_DESC,
-  AR_OBJECT_ACCELARATION_STRUCT,
+  AR_OBJECT_ACCELERATION_STRUCT,
   AR_OBJECT_USER_DEFINED_TYPE,
   AR_OBJECT_USER_DEFINED_TYPE,
   AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES,
   AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES,
 
 
@@ -462,7 +462,7 @@ const UINT g_uBasicKindProps[] =
   BPROP_OBJECT,   // AR_OBJECT_WAVE
   BPROP_OBJECT,   // AR_OBJECT_WAVE
 
 
   LICOMPTYPE_RAYDESC,               // AR_OBJECT_RAY_DESC
   LICOMPTYPE_RAYDESC,               // AR_OBJECT_RAY_DESC
-  LICOMPTYPE_ACCELERATION_STRUCT,   // AR_OBJECT_ACCELARATION_STRUCT
+  LICOMPTYPE_ACCELERATION_STRUCT,   // AR_OBJECT_ACCELERATION_STRUCT
   LICOMPTYPE_USER_DEFINED_TYPE,      // AR_OBJECT_USER_DEFINED_TYPE
   LICOMPTYPE_USER_DEFINED_TYPE,      // AR_OBJECT_USER_DEFINED_TYPE
   0,      // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
   0,      // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
 
 
@@ -1098,9 +1098,9 @@ static const ArBasicKind g_RayDescCT[] =
   AR_BASIC_UNKNOWN
   AR_BASIC_UNKNOWN
 };
 };
 
 
-static const ArBasicKind g_AccelarationStructCT[] =
+static const ArBasicKind g_AccelerationStructCT[] =
 {
 {
-  AR_OBJECT_ACCELARATION_STRUCT,
+  AR_OBJECT_ACCELERATION_STRUCT,
   AR_BASIC_UNKNOWN
   AR_BASIC_UNKNOWN
 };
 };
 
 
@@ -1201,7 +1201,7 @@ const ArBasicKind* g_LegalIntrinsicCompTypes[] =
   g_UInt16CT,           // LICOMPTYPE_UINT16
   g_UInt16CT,           // LICOMPTYPE_UINT16
   g_Numeric16OnlyCT,    // LICOMPTYPE_NUMERIC16_ONLY
   g_Numeric16OnlyCT,    // LICOMPTYPE_NUMERIC16_ONLY
   g_RayDescCT,          // LICOMPTYPE_RAYDESC
   g_RayDescCT,          // LICOMPTYPE_RAYDESC
-  g_AccelarationStructCT,   // LICOMPTYPE_ACCELERATION_STRUCT,
+  g_AccelerationStructCT,   // LICOMPTYPE_ACCELERATION_STRUCT,
   g_UDTCT,              // LICOMPTYPE_USER_DEFINED_TYPE
   g_UDTCT,              // LICOMPTYPE_USER_DEFINED_TYPE
 };
 };
 C_ASSERT(ARRAYSIZE(g_LegalIntrinsicCompTypes) == LICOMPTYPE_COUNT);
 C_ASSERT(ARRAYSIZE(g_LegalIntrinsicCompTypes) == LICOMPTYPE_COUNT);
@@ -1275,7 +1275,7 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
 
 
   AR_OBJECT_WAVE,
   AR_OBJECT_WAVE,
   AR_OBJECT_RAY_DESC,
   AR_OBJECT_RAY_DESC,
-  AR_OBJECT_ACCELARATION_STRUCT,
+  AR_OBJECT_ACCELERATION_STRUCT,
   AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES,
   AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES,
 
 
   // subobjects
   // subobjects
@@ -1355,7 +1355,7 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
   0, // AR_OBJECT_LEGACY_EFFECT   // Used for all unsupported but ignored legacy effect types
   0, // AR_OBJECT_LEGACY_EFFECT   // Used for all unsupported but ignored legacy effect types
   0, // AR_OBJECT_WAVE
   0, // AR_OBJECT_WAVE
   0, // AR_OBJECT_RAY_DESC
   0, // AR_OBJECT_RAY_DESC
-  0, // AR_OBJECT_ACCELARATION_STRUCT
+  0, // AR_OBJECT_ACCELERATION_STRUCT
   0, // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
   0, // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
 
 
   0, // AR_OBJECT_STATE_OBJECT_CONFIG,
   0, // AR_OBJECT_STATE_OBJECT_CONFIG,
@@ -1444,7 +1444,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_LEGACY_EFFECT (legacy effect objects)
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_LEGACY_EFFECT (legacy effect objects)
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_WAVE
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_WAVE
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_DESC
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_DESC
-  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_ACCELARATION_STRUCT
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_ACCELERATION_STRUCT
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
 
 
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_STATE_OBJECT_CONFIG,
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_STATE_OBJECT_CONFIG,
@@ -3903,7 +3903,7 @@ public:
     case AR_OBJECT_APPEND_STRUCTURED_BUFFER:
     case AR_OBJECT_APPEND_STRUCTURED_BUFFER:
     case AR_OBJECT_CONSUME_STRUCTURED_BUFFER:
     case AR_OBJECT_CONSUME_STRUCTURED_BUFFER:
     case AR_OBJECT_WAVE:
     case AR_OBJECT_WAVE:
-    case AR_OBJECT_ACCELARATION_STRUCT:
+    case AR_OBJECT_ACCELERATION_STRUCT:
     case AR_OBJECT_RAY_DESC:
     case AR_OBJECT_RAY_DESC:
     case AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES:
     case AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES:
     {
     {
@@ -8983,11 +8983,12 @@ Sema::TemplateDeductionResult HLSLExternalSource::DeduceTemplateArgumentsForHLSL
   }
   }
 
 
   // Find the table of intrinsics based on the object type.
   // Find the table of intrinsics based on the object type.
-  const HLSL_INTRINSIC* intrinsics;
-  size_t intrinsicCount;
-  const char* objectName;
+  const HLSL_INTRINSIC* intrinsics = nullptr;
+  size_t intrinsicCount = 0;
+  const char* objectName = nullptr;
   FindIntrinsicTable(FunctionTemplate->getDeclContext(), &objectName, &intrinsics, &intrinsicCount);
   FindIntrinsicTable(FunctionTemplate->getDeclContext(), &objectName, &intrinsics, &intrinsicCount);
-  DXASSERT(intrinsics != nullptr,
+  DXASSERT(objectName != nullptr &&
+    (intrinsics != nullptr || m_intrinsicTables.size() > 0),
     "otherwise FindIntrinsicTable failed to lookup a valid object, "
     "otherwise FindIntrinsicTable failed to lookup a valid object, "
     "or the parser let a user-defined template object through");
     "or the parser let a user-defined template object through");
 
 

+ 40 - 0
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -111,6 +111,12 @@ static const HLSL_INTRINSIC_ARGUMENT TestIBFE[] = {
   { "val",    AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
   { "val",    AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
 };
 };
 
 
+// float2 = MySamplerOp(uint2 addr)
+static const HLSL_INTRINSIC_ARGUMENT TestMySamplerOp[] = {
+  { "MySamplerOp", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_FLOAT, 1, 2 },
+  { "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
+};
+
 struct Intrinsic {
 struct Intrinsic {
   LPCWSTR hlslName;
   LPCWSTR hlslName;
   const char *dxilName;
   const char *dxilName;
@@ -146,6 +152,11 @@ Intrinsic BufferIntrinsics[] = {
   {L"MyBufferOp",   "MyBufferOp",      "m", { 12, false, true, -1, countof(TestMyBufferOp), TestMyBufferOp}},
   {L"MyBufferOp",   "MyBufferOp",      "m", { 12, false, true, -1, countof(TestMyBufferOp), TestMyBufferOp}},
 };
 };
 
 
+// Test adding a method to an object that normally has no methods (SamplerState will do).
+Intrinsic SamplerIntrinsics[] = {
+  {L"MySamplerOp",   "MySamplerOp",    "m", { 15, false, true, -1, countof(TestMySamplerOp), TestMySamplerOp}},
+};
+
 class IntrinsicTable {
 class IntrinsicTable {
 public:
 public:
   IntrinsicTable(const wchar_t *ns, Intrinsic *begin, Intrinsic *end)
   IntrinsicTable(const wchar_t *ns, Intrinsic *begin, Intrinsic *end)
@@ -214,6 +225,7 @@ public:
   TestIntrinsicTable() : m_dwRef(0) { 
   TestIntrinsicTable() : m_dwRef(0) { 
     m_tables.push_back(IntrinsicTable(L"",       std::begin(Intrinsics), std::end(Intrinsics)));
     m_tables.push_back(IntrinsicTable(L"",       std::begin(Intrinsics), std::end(Intrinsics)));
     m_tables.push_back(IntrinsicTable(L"Buffer", std::begin(BufferIntrinsics), std::end(BufferIntrinsics)));
     m_tables.push_back(IntrinsicTable(L"Buffer", std::begin(BufferIntrinsics), std::end(BufferIntrinsics)));
+    m_tables.push_back(IntrinsicTable(L"SamplerState", std::begin(SamplerIntrinsics), std::end(SamplerIntrinsics)));
   }
   }
   DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
   DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override {
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override {
@@ -441,6 +453,7 @@ public:
   TEST_METHOD(DxilLoweringVector1)
   TEST_METHOD(DxilLoweringVector1)
   TEST_METHOD(DxilLoweringVector2)
   TEST_METHOD(DxilLoweringVector2)
   TEST_METHOD(DxilLoweringScalar)
   TEST_METHOD(DxilLoweringScalar)
+  TEST_METHOD(SamplerExtensionIntrinsic)
 };
 };
 
 
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -840,3 +853,30 @@ TEST_F(ExtensionTest, DxilLoweringScalar) {
     disassembly.npos !=
     disassembly.npos !=
     disassembly.find("call i32 @dx.op.tertiary.i32(i32 51"));
     disassembly.find("call i32 @dx.op.tertiary.i32(i32 51"));
 }
 }
+
+TEST_F(ExtensionTest, SamplerExtensionIntrinsic) {
+  // Test adding methods to objects that don't have any methods normally,
+  // and therefore have null default intrinsic table.
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  auto result = c.Compile(
+    "SamplerState samp;"
+    "float2 main(uint2 v1 : V1) : SV_Target {\n"
+    "  return samp.MySamplerOp(uint2(1, 2));\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  CheckOperationResultMsgs(result, {}, true, false);
+  std::string disassembly = c.Disassemble();
+
+  // Things to check
+  // - works when SamplerState normally has no methods
+  // - return type is translated to dx.types.ResRet
+  // - buffer is translated to dx.types.Handle
+  // - vector is exploded
+  LPCSTR expected[] = {
+    "call %dx.types.ResRet.f32 @MySamplerOp\\(i32 15, %dx.types.Handle %.*, i32 1, i32 2\\)"
+  };
+  CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
+}
+