2
0
Эх сурвалжийг харах

[spirv] Add support for casting from derived to base (#1052)

This is for the explicit (Base)derived kind of cast.
Lei Zhang 7 жил өмнө
parent
commit
7a86911302

+ 4 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -440,7 +440,8 @@ inline uint32_t getNumBaseClasses(QualType type) {
 /// following the cast chain.
 void getBaseClassIndices(const CastExpr *expr,
                          llvm::SmallVectorImpl<uint32_t> *indices) {
-  assert(expr->getCastKind() == CK_UncheckedDerivedToBase);
+  assert(expr->getCastKind() == CK_UncheckedDerivedToBase ||
+         expr->getCastKind() == CK_HLSLDerivedToBase);
 
   indices->clear();
 
@@ -2177,7 +2178,8 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
         processFlatConversion(toType, evalType, subExprId, expr->getExprLoc());
     return SpirvEvalInfo(valId).setRValue();
   }
-  case CastKind::CK_UncheckedDerivedToBase: {
+  case CastKind::CK_UncheckedDerivedToBase:
+  case CastKind::CK_HLSLDerivedToBase: {
     // Find the index sequence of the base to which we are casting
     llvm::SmallVector<uint32_t, 4> baseIndices;
     getBaseClassIndices(expr, &baseIndices);

+ 14 - 0
tools/clang/test/CodeGenSPIRV/oo.inheritance.hlsl

@@ -72,6 +72,20 @@ float4 main() : SV_Target {
 // CHECK-NEXT:                   OpStore [[d]] {{%\d+}}
     dd.d  = 9.;
 
+    DerivedAgain dd2;
+
+// CHECK-NEXT:  [[dd_base_ptr:%\d+]] = OpAccessChain %_ptr_Function_Derived %dd %uint_0
+// CHECK-NEXT:      [[dd_base:%\d+]] = OpLoad %Derived [[dd_base_ptr]]
+// CHECK-NEXT: [[dd2_base_ptr:%\d+]] = OpAccessChain %_ptr_Function_Derived %dd2 %uint_0
+// CHECK-NEXT:                         OpStore [[dd2_base_ptr]] [[dd_base]]
+    (Derived)dd2 = (Derived)dd;
+
+// CHECK-NEXT:   [[d_base_ptr:%\d+]] = OpAccessChain %_ptr_Function_Base %d %uint_0
+// CHECK-NEXT:       [[d_base:%\d+]] = OpLoad %Base [[d_base_ptr]]
+// CHECK-NEXT: [[dd2_base_ptr:%\d+]] = OpAccessChain %_ptr_Function_Base %dd2 %uint_0 %uint_0
+// CHECK-NEXT:                         OpStore [[dd2_base_ptr]] [[d_base]]
+    (Base)dd2    = (Base)d;
+
     // Make sure reads are good
 // CHECK:        [[base:%\d+]] = OpAccessChain %_ptr_Function_Base %d %uint_0
 // CHECK-NEXT:        {{%\d+}} = OpAccessChain %_ptr_Function_v4float [[base]] %int_0