Browse Source

Remove assumption that templates are never UDTs (#4752)

* Remove assumption that templates are never UDTs

There was an assumtion in the HLSL sema code that a template
specialization could never be a UDT. This assumption is incorrect now.
I've reworked the code so that we instead assume built-in types are
marked as `implicit` (which they all should and seem to be).

Correcting this in `IsHLSLNumericUserDefinedType` resulted in some
breakge in raytracing code generation because we used that method to
deterimine if structures could be payloads or attributes. That was an
incorrect API usage because we do have some builtin types that are
allowed.

The change here does the following:
* Introduce `IsHLSLBuiltinRayAttributeStruct` which returns true for the
  builtin raytracing data types that behave like UDTs.
* Introduce `IsHLSLCopyableAnnotatableRecord` returns true for
  user-defined trivially copyable structures and the builtin ray tracing
  types.
* Adjust `IsHLSLNumericUserDefinedType` to do what the name says.
* Consolidates implementations of `IsUserDefinedRecordType` across
  the project.
* Adds new test cases for the ray tracing built in structs to cover
  diagnostic cases missed by the existing tests.

The new `IsHLSLBuiltinRayAttributeStruct` is hacky and uses the type
names (as the old code did). We should in the future insert an internal
attribute on the types that can be used to denote them so that we don't
need to match string names.

Resolves #4735
Chris B 2 năm trước cách đây
mục cha
commit
4d66ec8a9c

+ 2 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -407,6 +407,8 @@ bool IsHLSLBufferViewType(clang::QualType type);
 bool IsHLSLStructuredBufferType(clang::QualType type);
 bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type);
 bool IsHLSLNumericUserDefinedType(clang::QualType type);
+bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT);
+bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT);
 bool IsHLSLAggregateType(clang::QualType type);
 clang::QualType GetHLSLResourceResultType(clang::QualType type);
 unsigned GetHLSLResourceTemplateUInt(clang::QualType type);

+ 33 - 27
tools/clang/lib/AST/HlslTypes.cpp

@@ -97,7 +97,7 @@ bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type) {
   if (isa<RecordType>(Ty)) {
     if (IsHLSLVecMatType(type))
       return true;
-    return IsHLSLNumericUserDefinedType(type);
+    return IsHLSLCopyableAnnotatableRecord(type);
   } else if (type->isArrayType()) {
     return IsHLSLNumericOrAggregateOfNumericType(QualType(type->getArrayElementTypeNoTypeQual(), 0));
   }
@@ -111,14 +111,7 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
   const clang::Type *Ty = type.getCanonicalType().getTypePtr();
   if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
     const RecordDecl *RD = RT->getDecl();
-    if (isa<ClassTemplateSpecializationDecl>(RD)) {
-      return false;   // UDT are not templates
-    }
-    // TODO: avoid check by name
-    StringRef name = RD->getName();
-    if (name == "ByteAddressBuffer" ||
-        name == "RWByteAddressBuffer" ||
-        name == "RaytracingAccelerationStructure")
+    if (!IsUserDefinedRecordType(type))
       return false;
     for (auto member : RD->fields()) {
       if (!IsHLSLNumericOrAggregateOfNumericType(member->getType()))
@@ -129,15 +122,34 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
   return false;
 }
 
+// In some cases we need record types that are annotatable and trivially
+// copyable from outside the shader. This excludes resource types which may be
+// trivially copyable inside the shader, and builtin matrix and vector types
+// which can't be annotated. But includes UDTs of trivially copyable data and
+// the builtin trivially copyable raytracing structs.
+bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT) {
+  return IsHLSLNumericUserDefinedType(QT) ||
+         IsHLSLBuiltinRayAttributeStruct(QT);
+}
+
+bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT) {
+  QT = QT.getCanonicalType();
+  const clang::Type *Ty = QT.getTypePtr();
+  if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
+    const RecordDecl *RD = RT->getDecl();
+    if (RD->getName() == "BuiltInTriangleIntersectionAttributes" || 
+        RD->getName() == "RayDesc")
+      return true;
+  }
+  return false;
+}
+
 // Aggregate types are arrays and user-defined structs
 bool IsHLSLAggregateType(clang::QualType type) {
   type = type.getCanonicalType();
   if (isa<clang::ArrayType>(type)) return true;
 
-  const RecordType *Record = dyn_cast<RecordType>(type);
-  return Record != nullptr
-    && !IsHLSLVecMatType(type) && !IsHLSLResourceType(type)
-    && !dyn_cast<ClassTemplateSpecializationDecl>(Record->getAsCXXRecordDecl());
+  return IsUserDefinedRecordType(type);
 }
 
 clang::QualType GetElementTypeOrType(clang::QualType type) {
@@ -586,23 +598,17 @@ bool IsHLSLSubobjectType(clang::QualType type) {
   return GetHLSLSubobjectKind(type, kind, hgType);
 }
 
-bool IsUserDefinedRecordType(clang::QualType type) {
-  if (const auto *rt = type->getAs<RecordType>()) {
-    // HLSL specific types
-    if (hlsl::IsHLSLResourceType(type) || hlsl::IsHLSLVecMatType(type) ||
-        isa<ExtVectorType>(type.getTypePtr()) || type->isBuiltinType() ||
-        type->isArrayType()) {
-      return false;
-    }
-
-    // SubpassInput or SubpassInputMS type
-    if (rt->getDecl()->getName() == "SubpassInput" ||
-        rt->getDecl()->getName() == "SubpassInputMS") {
+bool IsUserDefinedRecordType(clang::QualType QT) {
+  const clang::Type *Ty = QT.getCanonicalType().getTypePtr();
+  if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
+    const RecordDecl *RD = RT->getDecl();
+    if (RD->isImplicit())
       return false;
-    }
+    if (auto TD = dyn_cast<ClassTemplateSpecializationDecl>(RD))
+      if (TD->getSpecializedTemplate()->isImplicit())
+        return false;
     return true;
   }
-
   return false;
 }
 

+ 3 - 3
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -2202,7 +2202,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
           rayShaderHaveErrors = true;
         }
         if (ArgNo < 2) {
-          if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
+          if (!IsHLSLCopyableAnnotatableRecord(parmDecl->getType())) {
             Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
               DiagnosticsEngine::Error,
               "payload and attribute structures must be user defined types with only numeric contents."));
@@ -2230,7 +2230,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
           rayShaderHaveErrors = true;
         }
         if (ArgNo < 1) {
-          if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
+          if (!IsHLSLCopyableAnnotatableRecord(parmDecl->getType())) {
             Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
               DiagnosticsEngine::Error,
               "ray payload parameter must be a user defined type with only numeric contents."));
@@ -2255,7 +2255,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
           rayShaderHaveErrors = true;
         }
         if (ArgNo < 1) {
-          if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
+          if (!IsHLSLCopyableAnnotatableRecord(parmDecl->getType())) {
             Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
               DiagnosticsEngine::Error,
               "callable parameter must be a user defined type with only numeric contents."));

+ 0 - 13
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -355,19 +355,6 @@ bool isResourceType(QualType type) {
   return hlsl::IsHLSLResourceType(type);
 }
 
-bool isUserDefinedRecordType(const ASTContext &astContext, QualType type) {
-  if (const auto *rt = type->getAs<RecordType>()) {
-    if (rt->getDecl()->getName() == "mips_slice_type" ||
-        rt->getDecl()->getName() == "sample_slice_type") {
-      return false;
-    }
-  }
-  return type->getAs<RecordType>() != nullptr && !isResourceType(type) &&
-         !isMatrixOrArrayOfMatrix(astContext, type) &&
-         !isScalarOrVectorType(type, nullptr, nullptr) &&
-         !isArrayType(type, nullptr, nullptr);
-}
-
 bool isOrContains16BitType(QualType type, bool enable16BitTypesOption) {
   // Primitive types
   {

+ 2 - 1
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -20,6 +20,7 @@
 #include "dxc/DXIL/DxilConstants.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "spirv-tools/optimizer.hpp"
+#include "clang/AST/HlslTypes.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/SPIRV/AstTypeProbe.h"
 #include "clang/SPIRV/String.h"
@@ -2667,7 +2668,7 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr,
             dyn_cast<CXXMethodDecl>(operatorCall->getCalleeDecl())) {
       QualType parentType =
           QualType(cxxMethodDecl->getParent()->getTypeForDecl(), 0);
-      if (isUserDefinedRecordType(astContext, parentType)) {
+      if (hlsl::IsUserDefinedRecordType(parentType)) {
         // If the parent is a user-defined record type
         return processCall(callExpr);
       }

+ 12 - 0
tools/clang/test/HLSL/template-udt-load.hlsl

@@ -0,0 +1,12 @@
+// RUN: %clang_cc1 -fsyntax-only -ffreestanding -HV 2021 -verify %s
+
+ByteAddressBuffer In;
+RWBuffer<float> Out;
+
+
+[numthreads(1,1,1)]
+void CSMain()
+{ 
+  RWBuffer<float> FB = In.Load<RWBuffer<float> >(0); // expected-error {{Explicit template arguments on intrinsic Load must be a single numeric type}}
+  Out[0] = FB[0];
+}

+ 35 - 0
tools/clang/test/HLSLFileCheck/hlsl/template/ByteAddressBufferLoad.hlsl

@@ -0,0 +1,35 @@
+// RUN: %dxc -E CSMain -T cs_6_6 -HV 2021 -fcgl %s | FileCheck %s
+template<typename T>
+struct MyStructA
+{
+    T m_0;
+};
+
+struct MyStructB
+{
+    MyStructA<float> m_a;
+    float m_1;
+    float m_2;
+    float m_3;
+};
+
+ByteAddressBuffer g_bab;
+RWBuffer<float> result;
+
+// This test verifies that templates can be used both as the argument to
+// ByteAddressBuffer::Load and as a member of a structure passed as an argument
+// to ByteAddressLoad as long as the specialized template conforms to the rules
+// for HLSL (must only contain integral and floating point members).
+// CHECK-NOT: error
+
+[numthreads(1,1,1)]
+void CSMain()
+{
+  // CHECK: call %"struct.MyStructA<float>"* @"dx.hl.op..%\22struct.MyStructA<float>\22* (i32, %dx.types.Handle, i32)"(i32 229, %dx.types.Handle %{{[0-9]+}}, i32 0)
+  MyStructA<float> a = g_bab.Load<MyStructA<float> >(0);
+  result[0] = a.m_0;
+
+  // CHECK: call %struct.MyStructB* @"dx.hl.op..%struct.MyStructB* (i32, %dx.types.Handle, i32)"(i32 229, %dx.types.Handle %{{[0-9]+}}, i32 1)
+  MyStructB b = g_bab.Load<MyStructB>(1);
+  result[1] = b.m_a.m_0;
+}

+ 16 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/builtin-ray-types-anyhit.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -T lib_6_4 %s 2>&1 | FileCheck %s
+
+// CHECK-NOT: error
+[shader("anyhit")]
+void anyhit_param0( inout RayDesc D1, RayDesc D2 ) { }
+
+[shader("anyhit")]
+void anyhit_param1( inout BuiltInTriangleIntersectionAttributes A1, BuiltInTriangleIntersectionAttributes A2 ) { }
+
+// CHECK: builtin-ray-types-anyhit.hlsl:15:37: error: payload and attribute structures must be user defined types with only numeric contents.
+// CHECK: builtin-ray-types-anyhit.hlsl:15:48: error: payload and attribute structures must be user defined types with only numeric contents.
+// CHECK: builtin-ray-types-anyhit.hlsl:15:6: error: shader must include inout payload structure parameter.
+// CHECK: builtin-ray-types-anyhit.hlsl:15:6: error: shader must include attributes structure parameter.
+[shader("anyhit")]
+void anyhit_param2( inout Texture2D A1, float4 A2 ) { }
+// CHECK-NOT: error

+ 16 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/builtin-ray-types-callable.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -T lib_6_4 %s 2>&1 | FileCheck %s
+
+// CHECK-NOT: error
+
+[shader("callable")]
+void callable0( inout RayDesc param ) {}
+
+[shader("callable")]
+void callable1( inout BuiltInTriangleIntersectionAttributes param ) {}
+
+// CHECK: builtin-ray-types-callable.hlsl:14:33: error: callable parameter must be a user defined type with only numeric contents.
+// CHECK: builtin-ray-types-callable.hlsl:14:6: error: shader must include inout parameter structure.
+[shader("callable")]
+void callable2( inout Texture2D param ) {}
+
+// CHECK-NOT: error

+ 17 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/builtin-ray-types-miss.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -T lib_6_4 %s 2>&1 | FileCheck %s
+
+// CHECK-NOT: error
+
+[shader("miss")]
+void miss0(inout RayDesc PL) { }
+
+[shader("miss")]
+void miss1(inout BuiltInTriangleIntersectionAttributes PL) { }
+
+// CHECK: builtin-ray-types-miss.hlsl:15:28: error: ray payload parameter must be a user defined type with only numeric contents.
+// CHECK: builtin-ray-types-miss.hlsl:15:6: error: shader must include inout payload structure parameter.
+
+[shader("miss")]
+void miss2(inout Texture2D PL) { }
+
+// CHECK-NOT: error

+ 5 - 0
tools/clang/unittests/HLSL/VerifierTest.cpp

@@ -103,6 +103,7 @@ public:
   TEST_METHOD(GloballyCoherentErrors)
   TEST_METHOD(GloballyCoherentTemplateErrors)
   TEST_METHOD(RunBitFieldAnnotations)
+  TEST_METHOD(RunUDTByteAddressBufferLoad)
   void CheckVerifies(const wchar_t* path) {
     WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
     const char startMarker[] = "%clang_cc1";
@@ -448,3 +449,7 @@ TEST_F(VerifierTest, GloballyCoherentTemplateErrors) {
 TEST_F(VerifierTest, RunBitFieldAnnotations) {
   CheckVerifiesHLSL(L"bitfields-and-annotations.hlsl");
 }
+
+TEST_F(VerifierTest, RunUDTByteAddressBufferLoad) {
+  CheckVerifiesHLSL(L"template-udt-load.hlsl");
+}