Bläddra i källkod

[spirv] Consider base structs when calculating layout. (#2405)

Ehsan 6 år sedan
förälder
incheckning
9b70ec14d4

+ 24 - 0
tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp

@@ -252,6 +252,30 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
     uint32_t maxAlignment = 0;
     uint32_t structSize = 0;
 
+    // If this struct is derived from some other structs, place an implicit
+    // field at the very beginning for the base struct.
+    if (const auto *cxxDecl = dyn_cast<CXXRecordDecl>(structType->getDecl())) {
+      for (const auto base : cxxDecl->bases()) {
+        uint32_t memberAlignment = 0, memberSize = 0;
+        std::tie(memberAlignment, memberSize) =
+            getAlignmentAndSize(base.getType(), rule, isRowMajor, stride);
+
+        if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+            rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
+            rule == SpirvLayoutRule::FxcCTBuffer) {
+          alignUsingHLSLRelaxedLayout(base.getType(), memberSize,
+                                      memberAlignment, &structSize);
+        } else {
+          structSize = roundToPow2(structSize, memberAlignment);
+        }
+
+        // The base alignment of the structure is N, where N is the largest
+        // base alignment value of any of its members...
+        maxAlignment = std::max(maxAlignment, memberAlignment);
+        structSize += memberSize;
+      }
+    }
+
     for (const auto *field : structType->getDecl()->fields()) {
       uint32_t memberAlignment = 0, memberSize = 0;
       std::tie(memberAlignment, memberSize) =

+ 49 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.cbuffer.derived-struct.hlsl

@@ -0,0 +1,49 @@
+// Run: %dxc -T vs_6_2 -E main -enable-16bit-types
+
+struct VertexInput {
+  float4 position : POSITION;
+};
+
+struct PixelOutput {
+  float4 position : SV_POSITION;
+};
+
+// CHECK: OpMemberDecorate %Base1 0 Offset 0
+struct Base1 {
+  float4 foo1;
+};
+
+// CHECK: OpMemberDecorate %Base2 0 Offset 0
+// CHECK: OpMemberDecorate %Base2 1 Offset 16
+struct Base2 : Base1 {
+  float4 foo2;
+};
+
+// CHECK: OpMemberDecorate %Derived 0 Offset 0
+// CHECK: OpMemberDecorate %Derived 1 Offset 32
+// CHECK: OpMemberDecorate %Derived 2 Offset 48
+struct Derived : Base2 {
+  float4 foo3;
+  float4 foo4;
+};
+
+// CHECK: OpMemberDecorate %type_constantData 0 Offset 0
+// CHECK: OpMemberDecorate %type_constantData 1 Offset 64
+cbuffer constantData : register(b0) {
+  Derived derivedData;
+  float4x4 MVP;
+}
+
+
+// CHECK:             %Base1 = OpTypeStruct %v4float
+// CHECK:             %Base2 = OpTypeStruct %Base1 %v4float
+// CHECK:           %Derived = OpTypeStruct %Base2 %v4float %v4float
+// CHECK: %type_constantData = OpTypeStruct %Derived %mat4v4float
+
+PixelOutput main(const VertexInput vertex) {
+  PixelOutput pixel;
+  pixel.position = mul(vertex.position, MVP);
+
+  return pixel;
+}
+

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1711,6 +1711,9 @@ TEST_F(FileTest, VulkanLayoutCBufferNestedEmptyStd140) {
 TEST_F(FileTest, VulkanLayoutCBufferBoolean) {
   runFileTest("vk.layout.cbuffer.boolean.hlsl");
 }
+TEST_F(FileTest, VulkanLayoutCBufferDerivedStruct) {
+  runFileTest("vk.layout.cbuffer.derived-struct.hlsl");
+}
 TEST_F(FileTest, VulkanLayoutRWStructuredBufferBoolean) {
   runFileTest("vk.layout.rwstructuredbuffer.boolean.hlsl");
 }