Prechádzať zdrojové kódy

[spirv] Fix read/write to base struct member from derived struct (#3263)

* [spirv] Support base struct member access.

* [spirv] Bug fix for non-overridden methods.
Ehsan 4 rokov pred
rodič
commit
753e2a4c0a

+ 37 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -341,6 +341,34 @@ void getBaseClassIndices(const CastExpr *expr,
   indices->clear();
   indices->clear();
 
 
   QualType derivedType = expr->getSubExpr()->getType();
   QualType derivedType = expr->getSubExpr()->getType();
+
+  // There are two types of UncheckedDerivedToBase/HLSLDerivedToBase casts:
+  //
+  // The first is when a derived object tries to access a member in the base.
+  // For example: derived.base_member.
+  // ImplicitCastExpr 'Base' lvalue <UncheckedDerivedToBase (Base)>
+  // `-DeclRefExpr 'Derived' lvalue Var 0x1f0d9bb2890 'derived' 'Derived'
+  //
+  // The second is when a pointer of the dervied is used to access members or
+  // methods of the base. There are currently no pointers in HLSL, but the
+  // method defintions can use the "this" pointer.
+  // For example:
+  // class Base { float value; };
+  // class Derviced : Base {
+  //   float4 getBaseValue() { return value; }
+  // };
+  //
+  // In this example, the 'this' pointer (pointing to Derived) is used inside
+  // 'getBaseValue', which is then cast to a Base pointer:
+  //
+  // ImplicitCastExpr 'Base *' <UncheckedDerivedToBase (Base)>
+  // `-CXXThisExpr 'Derviced *' this
+  //
+  // Therefore in order to obtain the derivedDecl below, we must make sure that
+  // we handle the second case too by using the pointee type.
+  if (derivedType->isPointerType())
+    derivedType = derivedType->getPointeeType();
+
   const auto *derivedDecl = derivedType->getAsCXXRecordDecl();
   const auto *derivedDecl = derivedType->getAsCXXRecordDecl();
 
 
   // Go through the base cast chain: for each of the derived to base cast, find
   // Go through the base cast chain: for each of the derived to base cast, find
@@ -363,6 +391,8 @@ void getBaseClassIndices(const CastExpr *expr,
 
 
     // Continue to proceed the next base in the chain
     // Continue to proceed the next base in the chain
     derivedType = baseType;
     derivedType = baseType;
+    if (derivedType->isPointerType())
+      derivedType = derivedType->getPointeeType();
     derivedDecl = derivedType->getAsCXXRecordDecl();
     derivedDecl = derivedType->getAsCXXRecordDecl();
   }
   }
 }
 }
@@ -390,6 +420,10 @@ std::string getFnName(const FunctionDecl *fn) {
   return getNamespacePrefix(fn) + classOrStructName + fn->getName().str();
   return getNamespacePrefix(fn) + classOrStructName + fn->getName().str();
 }
 }
 
 
+bool isMemoryObjectDeclaration(SpirvInstruction *inst) {
+  return isa<SpirvVariable>(inst) || isa<SpirvFunctionParameter>(inst);
+}
+
 } // namespace
 } // namespace
 
 
 SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
 SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
@@ -2158,7 +2192,8 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
         // Based on SPIR-V spec, function parameter must always be in Function
         // Based on SPIR-V spec, function parameter must always be in Function
         // scope. If we pass a non-function scope argument, we need
         // scope. If we pass a non-function scope argument, we need
         // the legalization.
         // the legalization.
-        if (objInstr->getStorageClass() != spv::StorageClass::Function)
+        if (objInstr->getStorageClass() != spv::StorageClass::Function ||
+            !isMemoryObjectDeclaration(objInstr))
           beforeHlslLegalization = true;
           beforeHlslLegalization = true;
 
 
         args.push_back(objInstr);
         args.push_back(objInstr);
@@ -6654,8 +6689,7 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
             varDecl->getType(), VK_LValue);
             varDecl->getType(), VK_LValue);
 
 
     const Expr *base = collectArrayStructIndices(
     const Expr *base = collectArrayStructIndices(
-        indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
-        rawIndices, indices, isMSOutAttribute);
+        indexing->getBase(), rawIndex, rawIndices, indices, isMSOutAttribute);
 
 
     if (isMSOutAttribute && base) {
     if (isMSOutAttribute && base) {
       if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
       if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {

+ 99 - 0
tools/clang/test/CodeGenSPIRV/oo.struct.derived.methods.hlsl

@@ -0,0 +1,99 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// The derived struct methods do not override the base struct methods.
+
+struct A {
+  float4 base;
+  void SetBase(float4 v) { base = v; }
+  float4 GetBase() { return base; }
+};
+
+struct B : A {
+  float4 derived;
+  float4 GetDerived() { return derived; }
+  void SetDerived(float4 v) { derived = v; }
+};
+
+struct C : B {
+  float4 c_value;
+  float4 GetCValue() { return c_value; }
+  void SetValue(float4 v) { c_value = v; }
+};
+
+float4 main() : SV_Target {
+  C c;
+
+// CHECK:  [[v4f0:%\d+]] = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+// CHECK: [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %c %uint_0 %uint_0
+// CHECK:                  OpStore %param_var_v [[v4f0]]
+// CHECK:                  OpFunctionCall %void %A_SetBase [[A_ptr]] %param_var_v
+  c.SetBase(float4(0, 0, 0, 0));
+
+// CHECK: [[B_ptr:%\d+]] = OpAccessChain %_ptr_Function_B %c %uint_0
+// CHECK:                  OpStore %param_var_v_0 [[v4f0]]
+// CHECK:                  OpFunctionCall %void %B_SetDerived [[B_ptr]] %param_var_v_0
+  c.SetDerived(float4(0, 0, 0, 0));
+
+// CHECK:                  OpStore %param_var_v_1 [[v4f0]]
+// CHECK:                  OpFunctionCall %void %C_SetValue %c %param_var_v_1
+  c.SetValue(float4(0, 0, 0, 0));
+
+  return
+// CHECK: [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %c %uint_0 %uint_0
+// CHECK:       {{%\d+}} = OpFunctionCall %v4float %A_GetBase [[A_ptr]]
+    c.GetBase() +
+// CHECK: [[B_ptr:%\d+]] = OpAccessChain %_ptr_Function_B %c %uint_0
+// CHECK:       {{%\d+}} = OpFunctionCall %v4float %B_GetDerived [[B_ptr]]
+    c.GetDerived() +
+// CHECK:       {{%\d+}} = OpFunctionCall %v4float %C_GetCValue %c
+    c.GetCValue();
+}
+
+// Definition for: void A::SetBase(float4 v) { base = v;}
+//
+// CHECK:        %A_SetBase = OpFunction
+// CHECK:       %param_this = OpFunctionParameter %_ptr_Function_A
+// CHECK:                %v = OpFunctionParameter %_ptr_Function_v4float
+// CHECK:        [[v:%\d+]] = OpLoad %v4float %v
+// CHECK: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this %int_0
+// CHECK:                     OpStore [[base_ptr]] [[v]]
+
+// Definition for: void B::SetDerived(float4 v) { derived = v; }
+//
+// CHECK:        %B_SetDerived = OpFunction
+// CHECK:        %param_this_0 = OpFunctionParameter %_ptr_Function_B
+// CHECK:                 %v_0 = OpFunctionParameter %_ptr_Function_v4float
+// CHECK:           [[v:%\d+]] = OpLoad %v4float %v_0
+// CHECK: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_0 %int_1
+// CHECK:                        OpStore [[derived_ptr]] [[v]]
+
+// Definition for: void C::SetValue(float4 v) { c_value = v; }
+//
+// CHECK:        %C_SetValue = OpFunction
+// CHECK:      %param_this_1 = OpFunctionParameter %_ptr_Function_C
+// CHECK:               %v_1 = OpFunctionParameter %_ptr_Function_v4float
+// CHECK:         [[v:%\d+]] = OpLoad %v4float %v_1
+// CHECK: [[c_val_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_1 %int_1
+// CHECK:                      OpStore [[c_val_ptr]] [[v]]
+
+// Definition for A::float4 GetBase() { return base; }
+//
+// CHECK:        %A_GetBase = OpFunction
+// CHECK:     %param_this_2 = OpFunctionParameter %_ptr_Function_A
+// CHECK: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_2 %int_0
+// CHECK:                     OpLoad %v4float [[base_ptr]]
+
+
+// Definition for B::float4 GetDerived() { return derived; }
+//
+// CHECK:        %B_GetDerived = OpFunction
+// CHECK:        %param_this_3 = OpFunctionParameter %_ptr_Function_B
+// CHECK: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_3 %int_1
+// CHECK:                        OpLoad %v4float [[derived_ptr]]
+
+// Definition for C::float4 GetCValue() { return c_value; }
+//
+// CHECK:       %C_GetCValue = OpFunction
+// CHECK:      %param_this_4 = OpFunctionParameter %_ptr_Function_C
+// CHECK: [[c_val_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_4 %int_1
+// CHECK:                      OpLoad %v4float [[c_val_ptr]]

+ 177 - 0
tools/clang/test/CodeGenSPIRV/oo.struct.derived.methods.override.hlsl

@@ -0,0 +1,177 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// The derived methods override the base struct methods.
+
+struct A {
+  float4 base;
+};
+
+struct B : A {
+  float4 derived;
+
+  float4 GetBase() {return base;}
+  float4 GetDerived() {return derived;}
+  void SetBase(float4 v) { base = v;}
+  void SetDerived(float4 v) { derived = v; }
+};
+
+struct C : B {
+  float4 c_value;
+
+  float4 GetBase() { return base; }
+  float4 GetDerived() { return derived; }
+  float4 GetCValue() { return c_value; }
+  void SetBase(float4 v) { base = v; }
+  void SetDerived(float4 v) { derived = v; }
+  void SetValue(float4 v) { c_value = v; }
+};
+
+float4 main() : SV_Target {
+  B b;
+// CHECK:    [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %b %uint_0
+// CHECK: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[A_ptr]] %int_0
+// CHECK:                     OpStore [[base_ptr]] {{%\d+}}
+  b.base = float4(1, 1, 0, 1);
+// CHECK: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %b %int_1
+// CHECK:                        OpStore [[derived_ptr]] {{%\d+}}
+  b.derived = float4(1, 0, 1, 1);
+
+// CHECK: OpFunctionCall %void %B_SetBase %b %param_var_v
+  b.SetBase(float4(1, 0, 1, 1));
+// CHECK: OpFunctionCall %void %B_SetDerived %b %param_var_v_0
+  b.SetDerived(float4(1, 0, 1, 1));
+
+  C c;
+// CHECK:       [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %c %uint_0 %uint_0
+// CHECK:    [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[A_ptr]] %int_0
+// CHECK:                        OpStore [[base_ptr]] {{%\d+}}
+  c.base = float4(0,0,0,0);
+// CHECK:       [[B_ptr:%\d+]] = OpAccessChain %_ptr_Function_B %c %uint_0
+// CHECK: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[B_ptr]] %int_1
+// CHECK:                        OpStore [[derived_ptr]] {{%\d+}}
+  c.derived = float4(0,0,0,0);
+// CHECK: [[c_value_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %c %int_1
+// CHECK:                        OpStore [[c_value_ptr]] {{%\d+}}
+  c.c_value = float4(0,0,0,0);
+
+// CHECK: OpFunctionCall %void %C_SetBase %c %param_var_v_1
+  c.SetBase(float4(0, 0, 0, 0));
+// CHECK: OpFunctionCall %void %C_SetDerived %c %param_var_v_2
+  c.SetDerived(float4(0, 0, 0, 0));
+// CHECK: OpFunctionCall %void %C_SetValue %c %param_var_v_3
+  c.SetValue(float4(0, 0, 0, 0));
+
+// CHECK: OpFunctionCall %v4float %B_GetBase %b
+// CHECK: OpFunctionCall %v4float %B_GetDerived %b
+// CHECK: OpFunctionCall %v4float %C_GetBase %c
+// CHECK: OpFunctionCall %v4float %C_GetDerived %c
+// CHECK: OpFunctionCall %v4float %C_GetCValue %c
+  return b.GetBase() + b.GetDerived() + c.GetBase() + c.GetDerived() + c.GetCValue();
+}
+
+
+// Definition for: void B::SetBase(float4 v) { base = v;}
+// CHECK:             %B_SetBase = OpFunction
+// CHECK-NEXT:       %param_this = OpFunctionParameter %_ptr_Function_B
+// CHECK-NEXT:                %v = OpFunctionParameter %_ptr_Function_v4float
+// CHECK-NEXT:       %bb_entry_0 = OpLabel
+// CHECK-NEXT:        [[v:%\d+]] = OpLoad %v4float %v
+// CHECK-NEXT:    [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %param_this %uint_0
+// CHECK-NEXT: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[A_ptr]] %int_0
+// CHECK-NEXT:                     OpStore [[base_ptr]] [[v]]
+// CHECK-NEXT:                     OpReturn
+// CHECK-NEXT:                     OpFunctionEnd
+
+// Definition for: void B::SetDerived(float4 v) { derived = v; }
+// CHECK:             %B_SetDerived = OpFunction
+// CHECK-NEXT:        %param_this_0 = OpFunctionParameter %_ptr_Function_B
+// CHECK-NEXT:                 %v_0 = OpFunctionParameter %_ptr_Function_v4float
+// CHECK-NEXT:          %bb_entry_1 = OpLabel
+// CHECK-NEXT:           [[v:%\d+]] = OpLoad %v4float %v_0
+// CHECK-NEXT: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_0 %int_1
+// CHECK-NEXT:                        OpStore [[derived_ptr]] [[v]]
+// CHECK-NEXT:                        OpReturn
+// CHECK-NEXT:                        OpFunctionEnd
+
+
+// Definition for: void C::SetBase(float4 v) { base = v; }
+// CHECK:             %C_SetBase = OpFunction
+// CHECK-NEXT:     %param_this_1 = OpFunctionParameter %_ptr_Function_C
+// CHECK-NEXT:              %v_1 = OpFunctionParameter %_ptr_Function_v4float
+// CHECK-NEXT:       %bb_entry_2 = OpLabel
+// CHECK-NEXT:        [[v:%\d+]] = OpLoad %v4float %v_1
+// CHECK-NEXT:    [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %param_this_1 %uint_0 %uint_0
+// CHECK-NEXT: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[A_ptr]] %int_0
+// CHECK-NEXT:                     OpStore [[base_ptr]] [[v]]
+// CHECK-NEXT:                     OpReturn
+// CHECK-NEXT:                     OpFunctionEnd
+
+// Definition for: void C::SetDerived(float4 v) { derived = v; }
+// CHECK:             %C_SetDerived = OpFunction
+// CHECK-NEXT:        %param_this_2 = OpFunctionParameter %_ptr_Function_C
+// CHECK-NEXT:                 %v_2 = OpFunctionParameter %_ptr_Function_v4float
+// CHECK-NEXT:          %bb_entry_3 = OpLabel
+// CHECK-NEXT:           [[v:%\d+]] = OpLoad %v4float %v_2
+// CHECK-NEXT:       [[B_ptr:%\d+]] = OpAccessChain %_ptr_Function_B %param_this_2 %uint_0
+// CHECK-NEXT: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[B_ptr]] %int_1
+// CHECK-NEXT:                        OpStore [[derived_ptr]] [[v]]
+// CHECK-NEXT:                        OpReturn
+// CHECK-NEXT:                        OpFunctionEnd
+
+// Definition for: void C::SetValue(float4 v) { c_value = v; }
+// CHECK:               %C_SetValue = OpFunction
+// CHECK-NEXT:        %param_this_3 = OpFunctionParameter %_ptr_Function_C
+// CHECK-NEXT:                 %v_3 = OpFunctionParameter %_ptr_Function_v4float
+// CHECK-NEXT:          %bb_entry_4 = OpLabel
+// CHECK-NEXT:           [[v:%\d+]] = OpLoad %v4float %v_3
+// CHECK-NEXT: [[c_value_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_3 %int_1
+// CHECK-NEXT:                        OpStore [[c_value_ptr:%\d+]]
+// CHECK-NEXT:                        OpReturn
+// CHECK-NEXT:                        OpFunctionEnd
+
+// Definition for: float4 B::GetBase() {return base;}
+// CHECK:             %B_GetBase = OpFunction
+// CHECK-NEXT:     %param_this_4 = OpFunctionParameter %_ptr_Function_B
+// CHECK-NEXT:       %bb_entry_5 = OpLabel
+// CHECK-NEXT:    [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %param_this_4 %uint_0
+// CHECK-NEXT: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[A_ptr]] %int_0
+// CHECK-NEXT:     [[base:%\d+]] = OpLoad %v4float [[base_ptr]]
+// CHECK-NEXT:                     OpReturnValue [[base]]
+// CHECK-NEXT:                     OpFunctionEnd
+
+// Definition for: float4 B::GetDerived() {return derived;}
+// CHECK:             %B_GetDerived = OpFunction
+// CHECK-NEXT:        %param_this_5 = OpFunctionParameter %_ptr_Function_B
+// CHECK-NEXT:          %bb_entry_6 = OpLabel
+// CHECK-NEXT: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_5 %int_1
+// CHECK-NEXT:     [[derived:%\d+]] = OpLoad %v4float [[derived_ptr]]
+// CHECK-NEXT:                        OpReturnValue [[derived]]
+// CHECK-NEXT:                        OpFunctionEnd
+
+// Definition for: float4 C::GetBase() {return base;}
+// CHECK:             %C_GetBase = OpFunction
+// CHECK-NEXT:     %param_this_6 = OpFunctionParameter %_ptr_Function_C
+// CHECK-NEXT:       %bb_entry_7 = OpLabel
+// CHECK-NEXT:    [[A_ptr:%\d+]] = OpAccessChain %_ptr_Function_A %param_this_6 %uint_0 %uint_0
+// CHECK-NEXT: [[base_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[A_ptr]] %int_0
+// CHECK-NEXT:     [[base:%\d+]] = OpLoad %v4float [[base_ptr]]
+// CHECK-NEXT:                     OpReturnValue [[base]]
+// CHECK-NEXT:                     OpFunctionEnd
+
+// Definition for: float4 C::GetDerived() {return derived;}
+// CHECK:             %C_GetDerived = OpFunction
+// CHECK-NEXT:        %param_this_7 = OpFunctionParameter %_ptr_Function_C
+// CHECK-NEXT:          %bb_entry_8 = OpLabel
+// CHECK-NEXT:       [[B_ptr:%\d+]] = OpAccessChain %_ptr_Function_B %param_this_7 %uint_0
+// CHECK-NEXT: [[derived_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[B_ptr]] %int_1
+// CHECK-NEXT:     [[derived:%\d+]] = OpLoad %v4float [[derived_ptr]]
+// CHECK-NEXT:                        OpReturnValue [[derived]]
+// CHECK-NEXT:                        OpFunctionEnd
+
+// CHECK:              %C_GetCValue = OpFunction
+// CHECK-NEXT:        %param_this_8 = OpFunctionParameter %_ptr_Function_C
+// CHECK-NEXT:          %bb_entry_9 = OpLabel
+// CHECK-NEXT: [[c_value_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %param_this_8 %int_1
+// CHECK-NEXT:     [[c_value:%\d+]] = OpLoad %v4float [[c_value_ptr]]
+// CHECK-NEXT:                        OpReturnValue [[c_value]]
+// CHECK-NEXT:                        OpFunctionEnd

+ 7 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -601,6 +601,13 @@ TEST_F(FileTest, StructMethodCall) {
   setBeforeHLSLLegalization();
   setBeforeHLSLLegalization();
   runFileTest("oo.struct.method.hlsl");
   runFileTest("oo.struct.method.hlsl");
 }
 }
+TEST_F(FileTest, StructDerivedMethods) {
+  setBeforeHLSLLegalization();
+  runFileTest("oo.struct.derived.methods.hlsl");
+}
+TEST_F(FileTest, StructDerivedMethodsOverride) {
+  runFileTest("oo.struct.derived.methods.override.hlsl");
+}
 TEST_F(FileTest, StructThisAlias) {
 TEST_F(FileTest, StructThisAlias) {
   setBeforeHLSLLegalization();
   setBeforeHLSLLegalization();
   runFileTest("oo.struct.this.alias.hlsl");
   runFileTest("oo.struct.this.alias.hlsl");