Procházet zdrojové kódy

Generate GEP for DerivedToBase instead of bitcast. (#283)

1. Support base in FlattenedTypeIterator.
2. Generate GEP for DerivedToBase except empty struct.
Xiang Li před 8 roky
rodič
revize
6323ba161b

+ 1 - 0
tools/clang/include/clang/AST/OperationKinds.h

@@ -311,6 +311,7 @@ enum CastKind {
   CK_HLSLMatrixTruncationCast,
   CK_HLSLVectorToMatrixCast,
   CK_HLSLMatrixToVectorCast,
+  CK_HLSLDerivedToBase,
   // HLSL ComponentConversion (HLSLCC) Casts:
   CK_HLSLCC_IntegralCast,
   CK_HLSLCC_IntegralToBoolean,

+ 1 - 0
tools/clang/include/clang/Sema/Overload.h

@@ -96,6 +96,7 @@ namespace clang {
     ICK_Flat_Conversion,       ///< Flat assignment conversion for HLSL (inline conversion, straddled)
     ICK_HLSLVector_Splat,      ///< HLSLVector/Matrix splat
     ICK_HLSLVector_Truncation, ///< HLSLVector/Matrix truncation
+    ICK_HLSL_Derived_To_Base,  ///< HLSL Derived-to-base
     // HLSL Change Ends
 
     ICK_Num_Conversion_Kinds   ///< The number of conversion kinds

+ 28 - 1
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -3375,7 +3375,6 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
     return MakeAddrLValue(bitcast, ToType);
   }
   case CK_FlatConversion: {
-    // HLSL only single inheritance.
     // Just bitcast.
     QualType ToType = getContext().getLValueReferenceType(E->getType());
 
@@ -3387,6 +3386,34 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
     llvm::Value *bitcast = Builder.CreateBitCast(This, ResultType);
     return MakeAddrLValue(bitcast, ToType);
   }
+  case CK_HLSLDerivedToBase: {
+    // HLSL only single inheritance.
+    // Just GEP.
+    QualType ToType = getContext().getLValueReferenceType(E->getType());
+
+    LValue LV = EmitLValue(E->getSubExpr());
+    llvm::Value *This = LV.getAddress();
+
+    // gep to target type
+    llvm::Type *ResultType = ConvertType(ToType);
+    unsigned level = 0;
+    llvm::Type *ToTy = ResultType->getPointerElementType();
+    llvm::Type *FromTy = This->getType()->getPointerElementType();
+    // For empty struct, just bitcast.
+    if (!isa<llvm::StructType>(FromTy->getStructElementType(0))) {
+      llvm::Value *bitcast = Builder.CreateBitCast(This, ResultType);
+      return MakeAddrLValue(bitcast, ToType);
+    }
+
+    while (ToTy != FromTy) {
+      FromTy = FromTy->getStructElementType(0);
+      ++level;
+    }
+    llvm::Value *zeroIdx = Builder.getInt32(0);
+    SmallVector<llvm::Value *, 2> IdxList(level + 1, zeroIdx);
+    llvm::Value *GEP = Builder.CreateInBoundsGEP(This, IdxList);
+    return MakeAddrLValue(GEP, ToType);
+  }
   // HLSL Change Ends
   case CK_ZeroToOCLEvent:
     llvm_unreachable("NULL to OpenCL event lvalue cast is not valid");

+ 1 - 0
tools/clang/lib/Sema/SemaExprCXX.cpp

@@ -3433,6 +3433,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
       
   // HLSL Change Starts
   case ICK_Flat_Conversion:
+  case ICK_HLSL_Derived_To_Base:
   case ICK_HLSLVector_Splat:
   case ICK_HLSLVector_Scalar:
   case ICK_HLSLVector_Truncation:

+ 86 - 16
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -4169,26 +4169,43 @@ private:
     FK_Simple,
     FK_Fields,
     FK_Expressions,
-    FK_IncompleteArray
+    FK_IncompleteArray,
+    FK_Bases,
   };
 
   // Use this struct to represent a specific point in the tracked tree.
   struct FlattenedTypeTracker {
     QualType Type;                            // Type at this position in the tree.
     unsigned int Count;                       // Count of consecutive types
+    CXXRecordDecl::base_class_iterator CurrentBase; // Current base for a structure type.
+    CXXRecordDecl::base_class_iterator EndBase;     // STL-style end of bases.
     RecordDecl::field_iterator CurrentField;  // Current field in for a structure type.
     RecordDecl::field_iterator EndField;      // STL-style end of fields.
     MultiExprArg::iterator CurrentExpr;       // Current expression (advanceable for a list of expressions).
     MultiExprArg::iterator EndExpr;           // STL-style end of expressions.
     FlattenedIterKind IterKind;               // Kind of tracker.
-
-    FlattenedTypeTracker(QualType type) : Type(type), Count(0), CurrentExpr(nullptr), IterKind(FK_IncompleteArray) {}
-    FlattenedTypeTracker(QualType type, unsigned int count, MultiExprArg::iterator expression) :
-        Type(type), Count(count), CurrentExpr(expression), IterKind(FK_Simple) {}
-    FlattenedTypeTracker(QualType type, RecordDecl::field_iterator current, RecordDecl::field_iterator end)
-      : Type(type), Count(0), CurrentField(current), EndField(end), CurrentExpr(nullptr), IterKind(FK_Fields) {}
-    FlattenedTypeTracker(MultiExprArg::iterator current, MultiExprArg::iterator end)
-      : Count(0), CurrentExpr(current), EndExpr(end), IterKind(FK_Expressions) {}
+    bool   IsConsidered;                      // If a FlattenedTypeTracker already been considered.
+
+    FlattenedTypeTracker(QualType type)
+        : Type(type), Count(0), CurrentExpr(nullptr),
+          IterKind(FK_IncompleteArray), IsConsidered(false) {}
+    FlattenedTypeTracker(QualType type, unsigned int count,
+                         MultiExprArg::iterator expression)
+        : Type(type), Count(count), CurrentExpr(expression),
+          IterKind(FK_Simple), IsConsidered(false) {}
+    FlattenedTypeTracker(QualType type, RecordDecl::field_iterator current,
+                         RecordDecl::field_iterator end)
+        : Type(type), Count(0), CurrentField(current), EndField(end),
+          CurrentExpr(nullptr), IterKind(FK_Fields), IsConsidered(false) {}
+    FlattenedTypeTracker(MultiExprArg::iterator current,
+                         MultiExprArg::iterator end)
+        : Count(0), CurrentExpr(current), EndExpr(end),
+          IterKind(FK_Expressions), IsConsidered(false) {}
+    FlattenedTypeTracker(QualType type,
+                         CXXRecordDecl::base_class_iterator current,
+                         CXXRecordDecl::base_class_iterator end)
+        : Count(0), CurrentBase(current), EndBase(end), CurrentExpr(nullptr),
+          IterKind(FK_Bases), IsConsidered(false) {}
 
     /// <summary>Gets the current expression if one is available.</summary>
     Expr* getExprOrNull() const { return CurrentExpr ? *CurrentExpr : nullptr; }
@@ -6763,6 +6780,10 @@ clang::ExprResult HLSLExternalSource::PerformHLSLConversion(
       From = m_sema->ImpCastExprToType(From, targetType.getUnqualifiedType(), CK_FlatConversion, From->getValueKind(), /*BasePath=*/0, CCK).get();
       break;
     }
+    case ICK_HLSL_Derived_To_Base: {
+      From = m_sema->ImpCastExprToType(From, targetType.getUnqualifiedType(), CK_HLSLDerivedToBase, From->getValueKind(), /*BasePath=*/0, CCK).get();
+      break;
+    }
     case ICK_HLSLVector_Splat: {
       // 1. optionally convert from vec1 or mat1x1 to scalar
       From = HLSLImpCastToScalar(m_sema, From, SourceInfo.ShapeKind, SourceInfo.EltKind);
@@ -7052,7 +7073,7 @@ bool HLSLExternalSource::CanConvert(
           goto lSuccess;
         }
         if (sourceCXXRD->isDerivedFrom(targetCXXRD)) {
-          Second = ICK_Flat_Conversion;
+          Second = ICK_HLSL_Derived_To_Base;
           goto lSuccess;
         }
       } else {
@@ -9118,6 +9139,8 @@ bool FlattenedTypeIterator::considerLeaf()
 
   bool result = false;
   FlattenedTypeTracker& tracker = m_typeTrackers.back();
+  tracker.IsConsidered = true;
+
   switch (tracker.IterKind) {
   case FlattenedIterKind::FK_Expressions:
     if (pushTrackerForExpression(tracker.CurrentExpr)) {
@@ -9127,6 +9150,17 @@ bool FlattenedTypeIterator::considerLeaf()
   case FlattenedIterKind::FK_Fields:
     if (pushTrackerForType(tracker.CurrentField->getType(), nullptr)) {
       result = considerLeaf();
+    } else {
+      // Pop empty struct.
+      m_typeTrackers.pop_back();
+    }
+    break;
+  case FlattenedIterKind::FK_Bases:
+    if (pushTrackerForType(tracker.CurrentBase->getType(), nullptr)) {
+      result = considerLeaf();
+    } else {
+      // Pop empty base.
+      m_typeTrackers.pop_back();
     }
     break;
   case FlattenedIterKind::FK_IncompleteArray:
@@ -9158,6 +9192,11 @@ void FlattenedTypeIterator::consumeLeaf()
     }
 
     FlattenedTypeTracker& tracker = m_typeTrackers.back();
+    // Reach a leaf which is not considered before.
+    // Stop here.
+    if (!tracker.IsConsidered) {
+      break;
+    }
     switch (tracker.IterKind) {
     case FlattenedIterKind::FK_Expressions:
       ++tracker.CurrentExpr;
@@ -9169,6 +9208,7 @@ void FlattenedTypeIterator::consumeLeaf()
       }
       break;
     case FlattenedIterKind::FK_Fields:
+
       ++tracker.CurrentField;
       if (tracker.CurrentField == tracker.EndField) {
         m_typeTrackers.pop_back();
@@ -9177,6 +9217,15 @@ void FlattenedTypeIterator::consumeLeaf()
         return;
       }
       break;
+    case FlattenedIterKind::FK_Bases:
+      ++tracker.CurrentBase;
+      if (tracker.CurrentBase == tracker.EndBase) {
+        m_typeTrackers.pop_back();
+        topConsumed = false;
+      } else {
+        return;
+      }
+      break;
     case FlattenedIterKind::FK_IncompleteArray:
       if (m_draining) {
         DXASSERT(m_typeTrackers.size() == 1, "m_typeTrackers.size() == 1, otherwise incomplete array isn't topmost");
@@ -9264,18 +9313,39 @@ bool FlattenedTypeIterator::pushTrackerForType(QualType type, MultiExprArg::iter
   case ArTypeObjectKind::AR_TOBJ_BASIC:
     m_typeTrackers.push_back(FlattenedTypeIterator::FlattenedTypeTracker(type, 1, expression));
     return true;
-  case ArTypeObjectKind::AR_TOBJ_COMPOUND:
+  case ArTypeObjectKind::AR_TOBJ_COMPOUND: {
     recordType = type->getAsStructureType();
     if (recordType == nullptr)
       recordType = dyn_cast<RecordType>(type.getTypePtr());
+
     fi = recordType->getDecl()->field_begin();
     fe = recordType->getDecl()->field_end();
+
+    bool bAddTracker = false;
+
     // Skip empty struct.
-    if (fi == fe)
-      return false;
-    m_typeTrackers.push_back(FlattenedTypeIterator::FlattenedTypeTracker(type, fi, fe));
-    type = (*fi)->getType();
-    return true;
+    if (fi != fe) {
+      m_typeTrackers.push_back(
+          FlattenedTypeIterator::FlattenedTypeTracker(type, fi, fe));
+      type = (*fi)->getType();
+      bAddTracker = true;
+    }
+
+    if (CXXRecordDecl *cxxRecordDecl =
+            dyn_cast<CXXRecordDecl>(recordType->getDecl())) {
+      CXXRecordDecl::base_class_iterator bi, be;
+      bi = cxxRecordDecl->bases_begin();
+      be = cxxRecordDecl->bases_end();
+      if (bi != be) {
+        // Add type tracker for base.
+        // Add base after child to make sure base considered first.
+        m_typeTrackers.push_back(
+            FlattenedTypeIterator::FlattenedTypeTracker(type, bi, be));
+        bAddTracker = true;
+      }
+    }
+    return bAddTracker;
+  }
   case ArTypeObjectKind::AR_TOBJ_MATRIX:
     m_typeTrackers.push_back(FlattenedTypeIterator::FlattenedTypeTracker(
       m_source.GetMatrixOrVectorElementType(type),

+ 2 - 0
tools/clang/lib/Sema/SemaOverload.cpp

@@ -140,6 +140,7 @@ ImplicitConversionRank clang::GetConversionRank(ImplicitConversionKind Kind) {
     ICR_Conversion,
     ICR_Conversion,
     ICR_Conversion,
+    ICR_Conversion,
     // HLSL Change Ends
   };
   static_assert(_countof(Rank) == ICK_Num_Conversion_Kinds,
@@ -184,6 +185,7 @@ static const char* GetImplicitConversionName(ImplicitConversionKind Kind) {
     "Flat assignment conversion",
     "HLSLVector/Matrix splat",
     "HLSLVector/Matrix truncation",
+    "HLSL derived to base",
     // HLSL Change Ends
   };
   static_assert(_countof(Name) == ICK_Num_Conversion_Kinds,

+ 22 - 0
tools/clang/test/CodeGenHLSL/cast7.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T ps_6_0 %s -fcgl | FileCheck %s
+
+// Make sure no bitcast to %struct.A*.
+// CHECK-NOT: to %struct.A*
+
+struct A {
+   float a;
+};
+
+struct B : A {
+   float b;
+};
+
+A  ga;
+float2 ib;
+
+float main() : SV_Target
+{
+  B b = {ib};
+  (A)b = ga;
+  return ((A)b).a + b.b;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -387,6 +387,7 @@ public:
   TEST_METHOD(CodeGenCast4)
   TEST_METHOD(CodeGenCast5)
   TEST_METHOD(CodeGenCast6)
+  TEST_METHOD(CodeGenCast7)
   TEST_METHOD(CodeGenCbuf_init_static)
   TEST_METHOD(CodeGenCbufferCopy)
   TEST_METHOD(CodeGenCbufferCopy2)
@@ -2281,6 +2282,10 @@ TEST_F(CompilerTest, CodeGenCast6) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\cast6.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenCast7) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\cast7.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenCbuf_init_static) {
   CodeGenTest(L"..\\CodeGenHLSL\\cbuf_init_static.hlsl");
 }