Parcourir la source

Allow cstyle cast as lvalue. (#137)

Xiang Li il y a 8 ans
Parent
commit
c81ccb7f75

+ 13 - 0
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -3380,6 +3380,19 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
     llvm::Value *bitcast = Builder.CreateBitCast(LV.getAddress(), ResultType);
     return MakeAddrLValue(bitcast, ToType);
   }
+  case CK_FlatConversion: {
+    // HLSL only single inheritance.
+    // Just bitcast.
+    QualType ToType = getContext().getLValueReferenceType(E->getType());
+
+    LValue LV = EmitLValue(E->getSubExpr());
+    llvm::Value *This = LV.getAddress();
+
+    // bitcast to target type
+    llvm::Type *ResultType = ConvertType(ToType);
+    llvm::Value *bitcast = Builder.CreateBitCast(This, ResultType);
+    return MakeAddrLValue(bitcast, ToType);
+  }
   // HLSL Change Ends
   case CK_ZeroToOCLEvent:
     llvm_unreachable("NULL to OpenCL event lvalue cast is not valid");

+ 1 - 1
tools/clang/lib/Sema/SemaCast.cpp

@@ -2106,10 +2106,10 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   // HLSL Change Starts
   // Check for HLSL vector or matrix shrinking.
   if (ValueKind == VK_RValue && 
+      !FunctionalStyle &&
       !isPlaceholder(BuiltinType::Overload) &&
       Self.getLangOpts().HLSL &&
       SrcExpr.get()->isLValue() &&
-      hlsl::IsHLSLVecMatType(SrcExpr.get()->getType()) &&
       hlsl::IsConversionToLessOrEqualElements(&Self, SrcExpr, DestType, true)) {
     ValueKind = VK_LValue;
   }

+ 22 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -5944,6 +5944,28 @@ bool HLSLExternalSource::IsConversionToLessOrEqualElements(
           targetType.getCanonicalType().getUnqualifiedType()) {
     return true;
   }
+  // DerivedFrom is less.
+  if (sourceTypeInfo.ShapeKind == AR_TOBJ_COMPOUND ||
+      GetTypeObjectKind(sourceType) == AR_TOBJ_COMPOUND) {
+    const RecordType *targetRT = targetType->getAsStructureType();
+    if (!targetRT)
+      targetRT = dyn_cast<RecordType>(targetType);
+
+    const RecordType *sourceRT = sourceType->getAsStructureType();
+    if (!sourceRT)
+      sourceRT = dyn_cast<RecordType>(sourceType);
+
+    if (targetRT && sourceRT) {
+      RecordDecl *targetRD = targetRT->getDecl();
+      RecordDecl *sourceRD = sourceRT->getDecl();
+      const CXXRecordDecl *targetCXXRD = dyn_cast<CXXRecordDecl>(targetRD);
+      const CXXRecordDecl *sourceCXXRD = dyn_cast<CXXRecordDecl>(sourceRD);
+      if (targetCXXRD && sourceCXXRD) {
+        if (sourceCXXRD->isDerivedFrom(targetCXXRD))
+          return true;
+      }
+    }
+  }
 
   if (sourceTypeInfo.ShapeKind != AR_TOBJ_SCALAR &&
     sourceTypeInfo.ShapeKind != AR_TOBJ_VECTOR &&

+ 23 - 0
tools/clang/test/CodeGenHLSL/functionalCast.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: @main
+
+struct P {
+   float4 x;
+};
+
+struct C : P {
+    int4 y;
+};
+
+P p;
+
+float4 x;
+int4 y;
+float4 main(float2 a : A, float b : B) : SV_Target
+{
+    C c;
+    (P)c = p;
+    c.y = y;
+    return half4(x).x + c.x + c.y;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -346,6 +346,7 @@ public:
   TEST_METHOD(CodeGenFirstbitLo)
   TEST_METHOD(CodeGenFloatMaxtessfactor)
   TEST_METHOD(CodeGenFModPS)
+  TEST_METHOD(CodeGenFunctionalCast)
   TEST_METHOD(CodeGenGather)
   TEST_METHOD(CodeGenGatherCmp)
   TEST_METHOD(CodeGenGatherCubeOffset)
@@ -2095,6 +2096,10 @@ TEST_F(CompilerTest, CodeGenFModPS) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\fmodPS.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenFunctionalCast) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\functionalCast.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenGather) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\gather.hlsl");
 }