Browse Source

Fixed bad aggregate type checks.

Tristan Labelle 6 years ago
parent
commit
333931ebb4

+ 1 - 2
tools/clang/include/clang/AST/HlslTypes.h

@@ -382,9 +382,8 @@ bool IsHLSLLineStreamType(clang::QualType type);
 bool IsHLSLTriangleStreamType(clang::QualType type);
 bool IsHLSLStreamOutputType(clang::QualType type);
 bool IsHLSLResourceType(clang::QualType type);
-bool IsHLSLNumeric(clang::QualType type);
 bool IsHLSLNumericUserDefinedType(clang::QualType type);
-bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type);
+bool IsHLSLAggregateType(clang::QualType type);
 clang::QualType GetHLSLResourceResultType(clang::QualType type);
 bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type);
 clang::QualType GetHLSLInputPatchElementType(clang::QualType type);

+ 6 - 4
tools/clang/lib/AST/HlslTypes.cpp

@@ -91,7 +91,7 @@ bool IsHLSLVecType(clang::QualType type) {
   return false;
 }
 
-bool IsHLSLNumeric(clang::QualType type) {
+static bool IsHLSLNumeric(clang::QualType type) {
   const clang::Type *Ty = type.getCanonicalType().getTypePtr();
   if (isa<RecordType>(Ty)) {
     if (IsHLSLVecMatType(type))
@@ -125,9 +125,11 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
   return false;
 }
 
-bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type) {
-  // Aggregate types are arrays and user-defined structs
-  if (context.getAsArrayType(type) != nullptr) return true;
+// Aggregate types are arrays and user-defined structs
+bool IsHLSLAggregateType(clang::QualType type) {
+  type = type.getCanonicalType();
+  if (isa<clang::ArrayType>(type)) return true;
+
   const RecordType *Record = dyn_cast<RecordType>(type);
   return Record != nullptr
     && !IsHLSLVecMatType(type) && !IsHLSLResourceType(type)

+ 1 - 1
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -1397,7 +1397,7 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
       }
     }
 
-    if (hlsl::IsHLSLAggregateType(getContext(), LV.getType())) {
+    if (hlsl::IsHLSLAggregateType(LV.getType())) {
       // We cannot load the value because we don't expect to ever have
       // user-defined struct or array-typed llvm registers, only pointers to them.
       // To preserve the snapshot semantics of LValue loads, we copy the

+ 2 - 2
tools/clang/lib/CodeGen/CGExprScalar.cpp

@@ -1825,7 +1825,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     // If the aggregate type is the cast source, it should be a pointer.
     // Aggregate to aggregate casts are handled in CGExprAgg.cpp
     auto areCompoundAndNumeric = [this](QualType lhs, QualType rhs) {
-      return hlsl::IsHLSLAggregateType(CGF.getContext(), lhs)
+      return hlsl::IsHLSLAggregateType(lhs)
         && (rhs->isBuiltinType() || hlsl::IsHLSLVecMatType(rhs));
     };
     assert(Src->getType()->isPointerTy()
@@ -1843,7 +1843,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
       return CGF.CGM.getHLSLRuntime().EmitHLSLMatrixLoad(CGF, DstPtr, DestTy);
     
     // Structs/arrays are pointers to temporaries
-    if (hlsl::IsHLSLAggregateType(CGF.getContext(), DestTy))
+    if (hlsl::IsHLSLAggregateType(DestTy))
       return DstPtr;
     
     // Scalars/vectors are loaded regularly

+ 5 - 6
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -2644,7 +2644,7 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
       EltTy = hlsl::GetHLSLVecElementType(Ty);
     } else if (hlsl::IsHLSLMatType(Ty)) {
       EltTy = hlsl::GetHLSLMatElementType(Ty);
-    } else if (resultTy->isAggregateType()) {
+    } else if (hlsl::IsHLSLAggregateType(resultTy)) {
       // Struct or array in a none-struct resource.
       std::vector<QualType> ScalarTys;
       CollectScalarTypes(ScalarTys, resultTy);
@@ -7012,7 +7012,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
     QualType ParamTy = Param->getType().getNonReferenceType();
     bool RValOnRef = false;
     if (!Param->isModifierOut()) {
-      if (!ParamTy->isAggregateType() || hlsl::IsHLSLMatType(ParamTy)) {
+      if (!hlsl::IsHLSLAggregateType(ParamTy)) {
         if (Arg->isRValue() && Param->getType()->isReferenceType()) {
           // RValue on a reference type.
           if (const CStyleCastExpr *cCast = dyn_cast<CStyleCastExpr>(Arg)) {
@@ -7108,7 +7108,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
         !isObject) {
       QualType ArgTy = Arg->getType();
       Value *outVal = nullptr;
-      bool isAggregateTy = ParamTy->isAggregateType() && !IsHLSLVecMatType(ParamTy);
+      bool isAggregateTy = hlsl::IsHLSLAggregateType(ParamTy);
       if (!isAggregateTy) {
         if (!IsHLSLMatType(ParamTy)) {
           RValue outRVal = CGF.EmitLoadOfLValue(argLV, SourceLocation());
@@ -7151,13 +7151,12 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
     
     Value *outVal = nullptr;
 
-    bool isAggrageteTy = ArgTy->isAggregateType();
-    isAggrageteTy &= !IsHLSLVecMatType(ArgTy);
+    bool isAggregateTy = hlsl::IsHLSLAggregateType(ArgTy);
 
     bool isObject = dxilutil::IsHLSLObjectType(
        tmpArgAddr->getType()->getPointerElementType());
     if (!isObject) {
-      if (!isAggrageteTy) {
+      if (!isAggregateTy) {
         if (!IsHLSLMatType(ParamTy))
           outVal = CGF.Builder.CreateLoad(tmpArgAddr);
         else

+ 1 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -4639,7 +4639,7 @@ public:
 
     // Change return type to rvalue reference type for aggregate types
     QualType retTy = parameterTypes[0];
-    if (retTy->isAggregateType() && !IsHLSLVecMatType(retTy))
+    if (hlsl::IsHLSLAggregateType(retTy))
       parameterTypes[0] = m_context->getRValueReferenceType(retTy);
 
     // Create a new specialization.

+ 15 - 0
tools/clang/test/CodeGenHLSL/declarations/functions/inout_derived_struct_no_crash.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Regression test for GitHub #1929, where we used the C++ definition
+// of an aggregate type and failed to match derived structs.
+
+// CHECK: ret void
+
+struct Base {};
+struct Derived : Base {};
+void f(inout Derived d) {}
+void main()
+{
+    Derived d;
+    f(d);
+}