Преглед изворни кода

Treat builtin structs like normal UDT except when necessary for intrinsic.

Tex Riddell пре 7 година
родитељ
комит
33e8b219fa

+ 3 - 1
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -7662,7 +7662,9 @@ def err_hlsl_intrinsic_template_arg_scalar_vector_16: Error<
    "Explicit template arguments on intrinsic %0 are limited one to scalar or vector type up to 16 bytes in size.">;
 }
 def err_hlsl_no_struct_user_defined_type: Error<
-   "User define type intrinsic arg must be struct">;
+   "User defined type intrinsic arg must be struct">;
+def err_hlsl_ray_desc_required: Error<
+   "Argument type must be struct RayDesc.">;
 def err_hlsl_missing_maxvertexcount_attr: Error<
    "GS entry point must have the maxvertexcount attribute">;
 def err_hlsl_missing_patchconstantfunc_attr: Error<

+ 45 - 27
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -183,6 +183,7 @@ enum ArBasicKind {
   AR_OBJECT_RAY_DESC,
   AR_OBJECT_ACCELARATION_STRUCT,
   AR_OBJECT_USER_DEFINED_TYPE,
+  AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES,
 
   AR_BASIC_MAXIMUM_COUNT
 };
@@ -442,9 +443,10 @@ const UINT g_uBasicKindProps[] =
 
   BPROP_OBJECT,   // AR_OBJECT_WAVE
 
-  LICOMPTYPE_RAYDESC,               // AR_OBJECT_WAVE
-  LICOMPTYPE_ACCELERATION_STRUCT,   // AR_OBJECT_WAVE
-  LICOMPTYPE_USER_DEFINED_TYPE,      // AR_OBJECT_WAVE
+  LICOMPTYPE_RAYDESC,               // AR_OBJECT_RAY_DESC
+  LICOMPTYPE_ACCELERATION_STRUCT,   // AR_OBJECT_ACCELARATION_STRUCT
+  LICOMPTYPE_USER_DEFINED_TYPE,      // AR_OBJECT_USER_DEFINED_TYPE
+  0,      // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
   // AR_BASIC_MAXIMUM_COUNT
 };
 
@@ -1249,6 +1251,7 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
   AR_OBJECT_WAVE,
   AR_OBJECT_RAY_DESC,
   AR_OBJECT_ACCELARATION_STRUCT,
+  AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES,
 };
 
 // Count of template arguments for basic kind of objects that look like templates (one or more type arguments).
@@ -1316,8 +1319,9 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
 
   0, // AR_OBJECT_LEGACY_EFFECT   // Used for all unsupported but ignored legacy effect types
   0, // AR_OBJECT_WAVE
-  0, // AR_OBJECT_RAY_DESC,
-  0, // AR_OBJECT_ACCELARATION_STRUCT,
+  0, // AR_OBJECT_RAY_DESC
+  0, // AR_OBJECT_ACCELARATION_STRUCT
+  0, // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
 };
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsTemplateCount));
@@ -1395,8 +1399,9 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
 
   { 0, MipsFalse, SampleFalse }, // AR_OBJECT_LEGACY_EFFECT (legacy effect objects)
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_WAVE
-  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_DESC,
-  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_ACCELARATION_STRUCT,
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_DESC
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_ACCELARATION_STRUCT
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES
 };
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsSubscripts));
@@ -1497,7 +1502,8 @@ const char* g_ArBasicTypeNames[] =
   "wave_t",
   "ray_desc",
   "RaytracingAccelerationStructure",
-  "UserDefineType"
+  "UserDefineType",
+  "BuiltInTriangleIntersectionAttributes"
 };
 
 C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT);
@@ -2450,7 +2456,7 @@ static CXXRecordDecl *CreateRayDescStruct(clang::ASTContext &context,
 // {
 //   float2 barycentrics;
 // };
-static void AddBuiltInTriangleIntersectionAttributes(ASTContext& context, QualType baryType) {
+static CXXRecordDecl *AddBuiltInTriangleIntersectionAttributes(ASTContext& context, QualType baryType) {
     DeclContext *curDC = context.getTranslationUnitDecl();
     IdentifierInfo &attributesId =
         context.Idents.get(StringRef("BuiltInTriangleIntersectionAttributes"),
@@ -2464,6 +2470,7 @@ static void AddBuiltInTriangleIntersectionAttributes(ASTContext& context, QualTy
     attributesDecl->completeDefinition();
     attributesDecl->setImplicit(true);
     curDC->addDecl(attributesDecl);
+    return attributesDecl;
 }
 
 //
@@ -3080,6 +3087,9 @@ private:
       if (kind == AR_OBJECT_RAY_DESC) {
         QualType float3Ty = LookupVectorType(HLSLScalarType::HLSLScalarType_float, 3);
         recordDecl = CreateRayDescStruct(*m_context, float3Ty);
+      } else if (kind == AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES) {
+        QualType float2Type = LookupVectorType(HLSLScalarType::HLSLScalarType_float, 2);
+        recordDecl = AddBuiltInTriangleIntersectionAttributes(*m_context, float2Type);
       } else
       if (templateArgCount == 0)
       {
@@ -3210,8 +3220,6 @@ public:
     for (auto && intrinsic : m_intrinsicTables) {
       AddIntrinsicTableMethods(intrinsic);
     }
-    QualType float2Type = LookupVectorType(HLSLScalarType::HLSLScalarType_float, 2);
-    AddBuiltInTriangleIntersectionAttributes(S.getASTContext(), float2Type);
   }
 
   void ForgetSema() override
@@ -3390,10 +3398,12 @@ public:
 
     if (typeRecordDecl && typeRecordDecl->isImplicit()) {
       if (typeRecordDecl->getDeclContext()->isFileContext()) {
-        // BuiltInTriangleIntersectionAttributes will be considered as a user
-        // defined type for diagnostic purposes.
-        if (typeRecordDecl->getName().equals("BuiltInTriangleIntersectionAttributes"))
+        int index = FindObjectBasicKindIndex(typeRecordDecl);
+        if (index != -1) {
+          ArBasicKind kind  = g_ArBasicKindsAsTypes[index];
+          if ( AR_OBJECT_RAY_DESC == kind || AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES == kind)
             return AR_TOBJ_COMPOUND;
+        }
         return AR_TOBJ_OBJECT;
       }
       else
@@ -3725,6 +3735,7 @@ public:
     case AR_OBJECT_WAVE:
     case AR_OBJECT_ACCELARATION_STRUCT:
     case AR_OBJECT_RAY_DESC:
+    case AR_OBJECT_TRIANGLE_INTERSECTION_ATTRIBUTES:
     {
         const ArBasicKind* match = std::find(g_ArBasicKindsAsTypes, &g_ArBasicKindsAsTypes[_countof(g_ArBasicKindsAsTypes)], kind);
         DXASSERT(match != &g_ArBasicKindsAsTypes[_countof(g_ArBasicKindsAsTypes)], "otherwise can't find constant in basic kinds");
@@ -5060,13 +5071,30 @@ bool HLSLExternalSource::MatchArguments(
     pIntrinsicArg = &pIntrinsic->pArgs[iArg];
     DXASSERT(pIntrinsicArg->uTemplateId != INTRIN_TEMPLATE_VARARGS, "no vararg support");
 
+    QualType pType = pCallArg->getType();
+    ArTypeObjectKind TypeInfoShapeKind = GetTypeObjectKind(pType);
+    ArBasicKind TypeInfoEltKind = GetTypeElementKind(pType);
+
+    if (pIntrinsicArg->uLegalComponentTypes == LICOMPTYPE_RAYDESC) {
+      if (TypeInfoShapeKind == AR_TOBJ_COMPOUND) {
+        if (CXXRecordDecl *pDecl = pType->getAsCXXRecordDecl()) {
+          int index = FindObjectBasicKindIndex(pDecl);
+          if (index != -1 && AR_OBJECT_RAY_DESC == g_ArBasicKindsAsTypes[index]) {
+            ++iArg;
+            continue;
+          }
+        }
+      }
+      m_sema->Diag(pCallArg->getExprLoc(),
+                   diag::err_hlsl_ray_desc_required);
+      return false;
+    }
+
     if (pIntrinsicArg->uLegalComponentTypes == LICOMPTYPE_USER_DEFINED_TYPE) {
       DXASSERT(objectElement.isNull(), "");
       QualType Ty = pCallArg->getType();
       // Must be user define type for LICOMPTYPE_USER_DEFINED_TYPE arg.
-      if (!Ty->isRecordType() ||
-          hlsl::IsHLSLVecMatType(Ty) ||
-          hlsl::IsHLSLResourceType(Ty)) {
+      if (TypeInfoShapeKind != AR_TOBJ_COMPOUND) {
         m_sema->Diag(pCallArg->getExprLoc(),
                      diag::err_hlsl_no_struct_user_defined_type);
         return false;
@@ -5082,10 +5110,6 @@ bool HLSLExternalSource::MatchArguments(
       continue;
     }
 
-    QualType pType = pCallArg->getType();
-    ArTypeObjectKind TypeInfoShapeKind = GetTypeObjectKind(pType);
-    ArBasicKind TypeInfoEltKind = GetTypeElementKind(pType);
-
     if (TypeInfoEltKind == AR_BASIC_LITERAL_INT ||
         TypeInfoEltKind == AR_BASIC_LITERAL_FLOAT) {
       bool affectRetType =
@@ -9845,12 +9869,6 @@ bool FlattenedTypeIterator::pushTrackerForType(QualType type, MultiExprArg::iter
   }
 
   ArTypeObjectKind objectKind = m_source.GetTypeObjectKind(type);
-  if (objectKind == ArTypeObjectKind::AR_TOBJ_OBJECT) {
-    // Treat ray desc as compound.
-    ArBasicKind kind = m_source.GetTypeElementKind(type);
-    if (kind == AR_OBJECT_RAY_DESC)
-      objectKind = AR_TOBJ_COMPOUND;
-  }
   QualType elementType;
   unsigned int elementCount;
   const RecordType* recordType;

+ 1 - 1
tools/clang/test/CodeGenHLSL/quick-test/raytracing_attr_struct.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T lib_6_3 %s | FileCheck %s
 
-//CHECK: User define type intrinsic arg must be struct
+//CHECK: User defined type intrinsic arg must be struct
 
 float main(float THit : t, uint HitKind : h, float2 f2 : F) {
   return ReportHit(THit, HitKind, f2);

+ 2 - 2
tools/clang/test/CodeGenHLSL/quick-test/raytracing_payload_struct.hlsl

@@ -1,8 +1,8 @@
 // RUN: %dxc -E main -T lib_6_3 %s | FileCheck %s
 
-//CHECK: User define type intrinsic arg must be struct
+//CHECK: User defined type intrinsic arg must be struct
 
-RayTracingAccelerationStructure Acc;
+RaytracingAccelerationStructure Acc;
 
 uint RayFlags;
 uint InstanceInclusionMask;