فهرست منبع

Change some getAsStructureType() uses to getAs<RecordType>() (#4707)

In some cases, code using getAsStructureType() seemed to expect any user-
defined type to result in RecordType here, but only 'struct' types would,
leaving 'class' types to fail certain code paths.  Some code paths had an
additional getAs<RecordType>() if the getAsStructureType() returned nullptr,
but at that point, why bother with getAsStructureType() in the first place?

This updates cases in CGHLSLMS.cpp that looked to be misusing getAsStructureType
to simply use getAs<RecordType>() instead.  One case is with constructing
type annotations, where two branches are used and code is almost identical,
except skipping size return when a member is a resource was only in the struct
path.  I think removing this separate path and checking for resource on any
RecordType makes sense here.
Tex Riddell 2 سال پیش
والد
کامیت
bf1c9e4c89

+ 10 - 33
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1215,23 +1215,7 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
     AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
     // Resources don't count towards cbuffer size.
     return 0;
-  } else if (const RecordType *RT = paramTy->getAsStructureType()) {
-    RecordDecl *RD = RT->getDecl();
-    llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
-    // Skip if already created.
-    if (DxilStructAnnotation *annotation = dxilTypeSys.GetStructAnnotation(ST)) {
-      unsigned structSize = annotation->GetCBufferSize();
-      return structSize;
-    }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
-      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
-    DxilPayloadAnnotation *payloadAnnotation = nullptr;
-    if (ValidatePayloadDecl(RT->getDecl(), *m_pHLModule->GetShaderModel(), CGM.getDiags(), CGM.getCodeGenOpts()))
-      payloadAnnotation = dxilTypeSys.AddPayloadAnnotation(ST);
-    unsigned size = ConstructStructAnnotation(annotation, payloadAnnotation, RD, dxilTypeSys);
-    // Resources don't count towards cbuffer size.
-    return IsHLSLResourceType(Ty) ? 0 : size;
-  } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
+  } else if (const RecordType *RT = paramTy->getAs<RecordType>()) {
     // For this pointer.
     RecordDecl *RD = RT->getDecl();
     llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
@@ -1245,7 +1229,9 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
     DxilPayloadAnnotation* payloadAnnotation = nullptr;
     if (ValidatePayloadDecl(RT->getDecl(), *m_pHLModule->GetShaderModel(), CGM.getDiags(), CGM.getCodeGenOpts()))
          payloadAnnotation = dxilTypeSys.AddPayloadAnnotation(ST);
-    return ConstructStructAnnotation(annotation, payloadAnnotation, RD, dxilTypeSys);
+    unsigned size = ConstructStructAnnotation(annotation, payloadAnnotation, RD, dxilTypeSys);
+    // Resources don't count towards cbuffer size.
+    return IsHLSLResourceType(Ty) ? 0 : size;
   } else if (IsStringType(Ty)) {
     // string won't be included in cbuffer
     return 0;
@@ -3177,10 +3163,7 @@ static void CollectScalarTypes(std::vector<QualType> &ScalarTys, QualType Ty) {
         CollectScalarTypes(ScalarTys, EltTy);
       }
     } else {
-      const RecordType *RT = Ty->getAsStructureType();
-      // For CXXRecord.
-      if (!RT)
-        RT = Ty->getAs<RecordType>();
+      const RecordType *RT = Ty->getAs<RecordType>();
       RecordDecl *RD = RT->getDecl();
       for (FieldDecl *field : RD->fields())
         CollectScalarTypes(ScalarTys, field->getType());
@@ -3994,7 +3977,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
           elts.emplace_back(Builder.CreateLoad(val));
           eltTys.emplace_back(Ty);
         } else {
-          RecordDecl *RD = Ty->getAsStructureType()->getDecl();
+          const RecordDecl *RD = Ty->getAs<RecordType>()->getDecl();
           const CGRecordLayout& RL = CGF.getTypes().getCGRecordLayout(RD);
 
           // Take care base.
@@ -4124,10 +4107,7 @@ static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVect
       // Skip hlsl object.
       idx++;
     } else {
-      const RecordType *RT = Ty->getAsStructureType();
-      // For CXXRecord.
-      if (!RT)
-        RT = Ty->getAs<RecordType>();
+      const RecordType *RT = Ty->getAs<RecordType>();
       RecordDecl *RD = RT->getDecl();
       // Take care base.
       if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
@@ -4210,10 +4190,7 @@ static void StoreInitListToDestPtr(Value *DestPtr,
     } else {
       Constant *zero = Builder.getInt32(0);
 
-      const RecordType *RT = Type->getAsStructureType();
-      // For CXXRecord.
-      if (!RT)
-        RT = Type->getAs<RecordType>();
+      const RecordType *RT = Type->getAs<RecordType>();
       RecordDecl *RD = RT->getDecl();
       const CGRecordLayout &RL = Types.getCGRecordLayout(RD);
       // Take care base.
@@ -5353,7 +5330,7 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
       EltTyList.push_back(Type);
       return;
     }
-    const clang::RecordType *RT = Type->getAsStructureType();
+    const clang::RecordType *RT = Type->getAs<RecordType>();
     RecordDecl *RD = RT->getDecl();
 
     const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
@@ -5802,7 +5779,7 @@ void CGMSHLSLRuntime::EmitHLSLSplat(
   } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
     DXASSERT(!dxilutil::IsHLSLObjectType(ST), "cannot cast to hlsl object, Sema should reject");
 
-    const clang::RecordType *RT = Type->getAsStructureType();
+    const clang::RecordType *RT = Type->getAs<RecordType>();
     RecordDecl *RD = RT->getDecl();
 
     const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);

+ 33 - 0
tools/clang/test/HLSLFileCheck/hlsl/classes/mismatch_class_implicit_cast.hlsl

@@ -0,0 +1,33 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T ps_6_0 -HV 2021 %s | FileCheck %s -check-prefix=ERROR
+
+// CHECK: define void @main()
+// CHECK: %[[H:[^ ]+]] = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %{{[^,]+}}, i32 0)
+// CHECK: %[[f:[^ ]+]] = extractvalue %dx.types.CBufRet.f32 %[[H]], 0
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %[[f]])
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %[[f]])
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %[[f]])
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %[[f]])
+
+// ERROR: error: no matching function for call to 'badCall'
+// ERROR: note: candidate function not viable: no known conversion from 'A' to 'B' for 1st argument
+
+class A {
+  float f;
+  int i;
+};
+class B {
+  float f;
+  int i;
+};
+
+float4 badCall(B data) {
+  return (float4)data.f;
+}
+
+A g_dnc;
+
+float4  main() : SV_Target {
+  A dnc = g_dnc;
+  return badCall(dnc);
+}