Sfoglia il codice sorgente

Fix vec1/mat1x1 to aggregate type splats and a crash (#1979)

FXC treats scalars, vector1s and matrix1x1s the same, at least as far as conversions go. We allowed splats from scalars to aggregates, but not from vector1s or matrix1x1s.

Also fixes a bug where we would crash on conversions to aggregates where the result was ignored, because clang wouldn't allocate an AggValueSlot and we didn't handle the destination pointer being nullptr.
Tristan Labelle 6 anni fa
parent
commit
cc00f6e183

+ 22 - 15
tools/clang/lib/CodeGen/CGExprAgg.cpp

@@ -710,17 +710,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   // HLSL Change Begins.
   case CK_FlatConversion: {
     QualType Ty = E->getSubExpr()->getType();
-    llvm::Value *DestPtr = Dest.getAddr();
+
+    // We must emit the converted subexpression for any side-effects,
+    // but the conversion itself doesn't have any, so we should not
+    // emit it if we were not provided a destination aggregate value slot.
 
     if (IntegerLiteral *IL = dyn_cast<IntegerLiteral>(E->getSubExpr())) {
+      if (Dest.isIgnored()) return;
       llvm::Value *SrcVal = llvm::ConstantInt::get(CGF.getLLVMContext(), IL->getValue());
       CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversion(
-          CGF, SrcVal, DestPtr, E->getType(), Ty);
-    } else if (FloatingLiteral *FL =
-                   dyn_cast<FloatingLiteral>(E->getSubExpr())) {
+          CGF, SrcVal, Dest.getAddr(), E->getType(), Ty);
+    } else if (FloatingLiteral *FL = dyn_cast<FloatingLiteral>(E->getSubExpr())) {
+      if (Dest.isIgnored()) return;
       llvm::Value *SrcVal = llvm::ConstantFP::get(CGF.getLLVMContext(), FL->getValue());
       CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversion(
-          CGF, SrcVal, DestPtr, E->getType(), Ty);
+          CGF, SrcVal, Dest.getAddr(), E->getType(), Ty);
     } else {
       Expr *Src = E->getSubExpr();
       switch (CGF.getEvaluationKind(Ty)) {
@@ -731,21 +735,24 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
             Src = SrcCast->getSubExpr();
           }
         }
+
         // Just use decl if possible to skip useless copy.
-        if (DeclRefExpr *SrcDecl = dyn_cast<DeclRefExpr>(Src)) {
-          LValue LV = CGF.EmitLValue(SrcDecl);
-          CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversionAggregateCopy(
-              CGF, LV.getAddress(), Src->getType(), DestPtr, E->getType());
-        } else {
-          LValue LV = CGF.EmitAggExprToLValue(Src);
-          CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversionAggregateCopy(
-              CGF, LV.getAddress(), Src->getType(), DestPtr, E->getType());
-        }
+        LValue LV;
+        if (DeclRefExpr *SrcDecl = dyn_cast<DeclRefExpr>(Src))
+          LV = CGF.EmitLValue(SrcDecl);
+        else
+          LV = CGF.EmitAggExprToLValue(Src);
+
+        if (Dest.isIgnored()) return;
+        CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversionAggregateCopy(
+          CGF, LV.getAddress(), Src->getType(), Dest.getAddr(), E->getType());
       } break;
       case TEK_Scalar: {
         llvm::Value *SrcVal = CGF.EmitScalarExpr(Src);
+
+        if (Dest.isIgnored()) return;
         CGF.CGM.getHLSLRuntime().EmitHLSLFlatConversion(
-          CGF, SrcVal, DestPtr, E->getType(), Ty);
+          CGF, SrcVal, Dest.getAddr(), E->getType(), Ty);
       } break;
       default:
         assert(0 && "invalid type for flat cast");

+ 15 - 19
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -194,7 +194,7 @@ private:
                                    clang::QualType DestType,
                                    llvm::Type *Ty);
 
-  void EmitHLSLFlatConversion(CodeGenFunction &CGF, Value *SrcVal,
+  void EmitHLSLSplat(CodeGenFunction &CGF, Value *SrcVal,
                               llvm::Value *DestPtr,
                               SmallVector<Value *, 4> &idxList,
                               QualType Type, QualType SrcType,
@@ -6886,14 +6886,14 @@ static void SimpleFlatValCopy(CodeGenFunction &CGF,
     CGF.Builder.CreateStore(ResultScalar, DstPtr);
 }
 
-void CGMSHLSLRuntime::EmitHLSLFlatConversion(
+void CGMSHLSLRuntime::EmitHLSLSplat(
     CodeGenFunction &CGF, Value *SrcVal, llvm::Value *DestPtr,
     SmallVector<Value *, 4> &idxList, QualType Type, QualType SrcType,
     llvm::Type *Ty) {
   if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
     idxList.emplace_back(CGF.Builder.getInt32(0));
 
-    EmitHLSLFlatConversion(CGF, SrcVal, DestPtr, idxList, Type,
+    EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, Type,
                                       SrcType, PT->getElementType());
 
     idxList.pop_back();
@@ -6935,8 +6935,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
           Constant *idx = llvm::Constant::getIntegerValue(
               IntegerType::get(Ty->getContext(), 32), APInt(32, i));
           idxList.emplace_back(idx);
-          EmitHLSLFlatConversion(CGF, SrcVal, DestPtr, idxList,
-                                            parentTy, SrcType, ET);
+          EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, parentTy, SrcType, ET);
           idxList.pop_back();
         }
       }
@@ -6950,8 +6949,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
 
-      EmitHLSLFlatConversion(CGF, SrcVal, DestPtr, idxList,
-                                        fieldIter->getType(), SrcType, ET);
+      EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, fieldIter->getType(), SrcType, ET);
 
       idxList.pop_back();
     }
@@ -6966,8 +6964,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
 
-      EmitHLSLFlatConversion(CGF, SrcVal, DestPtr, idxList, EltType,
-                                        SrcType, ET);
+      EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, EltType, SrcType, ET);
 
       idxList.pop_back();
     }
@@ -6982,13 +6979,16 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(CodeGenFunction &CGF,
                                              Value *DestPtr,
                                              QualType Ty,
                                              QualType SrcTy) {
-  if (SrcTy->isBuiltinType()) {
-    SmallVector<Value *, 4> idxList;
-    // Add first 0 for DestPtr.
-    idxList.emplace_back(CGF.Builder.getInt32(0));
+  SmallVector<Value *, 4> SrcVals;
+  SmallVector<QualType, 4> SrcQualTys;
+  FlattenValToInitList(CGF, SrcVals, SrcQualTys, SrcTy, Val);
 
-    EmitHLSLFlatConversion(
-        CGF, Val, DestPtr, idxList, Ty, SrcTy,
+  if (SrcVals.size() == 1) {
+    // Perform a splat
+    SmallVector<Value *, 4> GEPIdxStack;
+    GEPIdxStack.emplace_back(CGF.Builder.getInt32(0)); // Add first 0 for DestPtr.
+    EmitHLSLSplat(
+        CGF, SrcVals[0], DestPtr, GEPIdxStack, Ty, SrcQualTys[0],
         DestPtr->getType()->getPointerElementType());
   }
   else {
@@ -6997,10 +6997,6 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(CodeGenFunction &CGF,
     SmallVector<QualType, 4> DstQualTys;
     FlattenAggregatePtrToGepList(CGF, DestPtr, GEPIdxStack, Ty, DestPtr->getType(), DstPtrs, DstQualTys);
 
-    SmallVector<Value *, 4> SrcVals;
-    SmallVector<QualType, 4> SrcQualTys;
-    FlattenValToInitList(CGF, SrcVals, SrcQualTys, SrcTy, Val);
-
     ConvertAndStoreElements(CGF, SrcVals, SrcQualTys, DstPtrs, DstQualTys);
   }
 }

+ 22 - 21
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -8119,14 +8119,8 @@ bool HLSLExternalSource::CanConvert(
       return false;
     }
 
-    const RecordType *targetRT = target->getAsStructureType();
-    if (!targetRT)
-      targetRT = dyn_cast<RecordType>(target);
-
-    const RecordType *sourceRT = source->getAsStructureType();
-    if (!sourceRT)
-      sourceRT = dyn_cast<RecordType>(source);
-
+    const RecordType *targetRT = dyn_cast<RecordType>(target);
+    const RecordType *sourceRT = dyn_cast<RecordType>(source);
     if (targetRT && sourceRT) {
       RecordDecl *targetRD = targetRT->getDecl();
       RecordDecl *sourceRD = sourceRT->getDecl();
@@ -8149,22 +8143,29 @@ bool HLSLExternalSource::CanConvert(
       }
     }
 
-    if (const BuiltinType *BT = source->getAs<BuiltinType>()) {
-      BuiltinType::Kind kind = BT->getKind();
-      switch (kind) {
-      case BuiltinType::Kind::UInt:
-      case BuiltinType::Kind::Int:
-      case BuiltinType::Kind::Float:
-      case BuiltinType::Kind::LitFloat:
-      case BuiltinType::Kind::LitInt:
-        if (explicitConversion) {
+    // Handle explicit splats from single element numerical types (scalars, vector1s and matrix1x1s) to aggregate types.
+    if (explicitConversion) {
+      const BuiltinType *sourceSingleElementBuiltinType = source->getAs<BuiltinType>();
+      if (sourceSingleElementBuiltinType == nullptr
+        && hlsl::IsHLSLVecMatType(source)
+        && hlsl::GetElementCount(source) == 1) {
+        sourceSingleElementBuiltinType = hlsl::GetElementTypeOrType(source)->getAs<BuiltinType>();
+      }
+
+      if (sourceSingleElementBuiltinType != nullptr) {
+        BuiltinType::Kind kind = sourceSingleElementBuiltinType->getKind();
+        switch (kind) {
+        case BuiltinType::Kind::UInt:
+        case BuiltinType::Kind::Int:
+        case BuiltinType::Kind::Float:
+        case BuiltinType::Kind::LitFloat:
+        case BuiltinType::Kind::LitInt:
           Second = ICK_Flat_Conversion;
           goto lSuccess;
+        default:
+          // Only flat conversion kinds are relevant.
+          break;
         }
-        break;
-      default:
-        // Only flat conversion kinds are relevant.
-        break;
       }
     }
 

+ 8 - 8
tools/clang/test/CodeGenHLSL/expressions/conversions_and_casts/between_type_shapes.hlsl

@@ -229,22 +229,22 @@ void main()
     // DXC: i32 1, i32 1, i32 0, i32 0, i8 15)
     // FXC: l(1,1,0,0)
     output_a2((A2)i);
-    // DXC rejects (GitHub #1863)
+    // DXC: i32 1, i32 1, i32 0, i32 0, i8 15)
     // FXC: l(1,1,0,0)
-    // output_a2((A2)v1);
-    // DXC rejects (GitHub #1863)
+    output_a2((A2)v1);
+    // DXC: i32 11, i32 11, i32 0, i32 0, i8 15)
     // FXC: l(11,11,0,0)
-    // output_a2((A2)m1x1);
+    output_a2((A2)m1x1);
 
     // DXC: i32 1, i32 1, i32 0, i32 0, i8 15)
     // FXC: l(1,1,0,0)
     output_s2((S2)i);
-    // DXC rejects (GitHub #1863)
+    // DXC: i32 1, i32 1, i32 0, i32 0, i8 15)
     // FXC: l(1,1,0,0)
-    // output_s2((S2)v1);
-    // DXC rejects (GitHub #1863)
+    output_s2((S2)v1);
+    // DXC: i32 11, i32 11, i32 0, i32 0, i8 15)
     // FXC: l(11,11,0,0)
-    // output_s2((S2)m1x1);
+    output_s2((S2)m1x1);
     
     // DXC: 8888
     output_separator();

+ 14 - 0
tools/clang/test/CodeGenHLSL/expressions/conversions_and_casts/to_aggregate_ignored.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Regression test for GitHub #1978, where converting to an aggregate type
+// and ignoring the result would cause a crash, because clang would not
+// allocate an AggValueSlot, and so our destination pointer would be nullptr.
+
+// CHECK: ret void
+
+struct S { int f; };
+void main()
+{
+  (S)0;
+  (int[1])0;
+}

+ 4 - 4
tools/clang/test/HLSL/conversions-between-type-shapes.hlsl

@@ -153,18 +153,18 @@ void main()
     to_a2(i);                                               /* expected-error {{no matching function for call to 'to_a2'}} fxc-error {{X3017: 'to_a2': cannot convert from 'int' to 'typedef int[2]'}} */
     (A2)i;
     to_a2(v1);                                              /* expected-error {{no matching function for call to 'to_a2'}} fxc-error {{X3017: 'to_a2': cannot convert from 'int1' to 'typedef int[2]'}} */
-    (A2)v1;                                                 /* expected-error {{cannot convert from 'int1' to 'A2' (aka 'int [2]')}} fxc-pass {{}} */
+    (A2)v1;
     to_a2(m1x1);                                            /* expected-error {{no matching function for call to 'to_a2'}} fxc-error {{X3017: 'to_a2': cannot convert from 'int1' to 'typedef int[2]'}} */
-    (A2)m1x1;                                               /* expected-error {{cannot convert from 'int1x1' to 'A2' (aka 'int [2]')}} fxc-pass {{}} */
+    (A2)m1x1;
     (A2)a1;                                                 /* expected-error {{cannot convert from 'A1' (aka 'int [1]') to 'A2' (aka 'int [2]')}} fxc-error {{X3017: cannot convert from 'typedef int[1]' to 'typedef int[2]'}} */
     (A2)s1;                                                 /* expected-error {{cannot convert from 'S1' to 'A2' (aka 'int [2]')}} fxc-error {{X3017: cannot convert from 'struct S1' to 'typedef int[2]'}} */
 
     to_s2(i);                                               /* expected-error {{no matching function for call to 'to_s2'}} fxc-error {{X3017: 'to_s2': cannot convert from 'int' to 'struct S2'}} */
     (S2)i;
     to_s2(v1);                                              /* expected-error {{no matching function for call to 'to_s2'}} fxc-error {{X3017: 'to_s2': cannot convert from 'int1' to 'struct S2'}} */
-    (S2)v1;                                                 /* expected-error {{cannot convert from 'int1' to 'S2'}} fxc-pass {{}} */
+    (S2)v1;
     to_s2(m1x1);                                            /* expected-error {{no matching function for call to 'to_s2'}} fxc-error {{X3017: 'to_s2': cannot convert from 'int1' to 'struct S2'}} */
-    (S2)m1x1;                                               /* expected-error {{cannot convert from 'int1x1' to 'S2'}} fxc-pass {{}} */
+    (S2)m1x1;
     (S2)a1;                                                 /* expected-error {{cannot convert from 'A1' (aka 'int [1]') to 'S2'}} fxc-error {{X3017: cannot convert from 'typedef int[1]' to 'struct S2'}} */
     (S2)s1;                                                 /* expected-error {{cannot convert from 'S1' to 'S2'}} fxc-error {{X3017: cannot convert from 'struct S1' to 'struct S2'}} */