Sfoglia il codice sorgente

Fix SPIRV struct reconstruction with bitfields (#5390)

HLSL/SPIR-V structs have some layout differences due to bitfields being
squashed.
The reconstruction logic was using the AST layout, and not the new
SPIR-V layout, meaning we could generate invalid indices during
extraction/construction.

This PR fixes a potential bug with bitfields, while making the code
reusable.

---------

Signed-off-by: Nathan Gauër <[email protected]>
Nathan Gauër 2 anni fa
parent
commit
dcf754c264

+ 89 - 56
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -584,6 +584,54 @@ const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions,
   return output;
   return output;
 }
 }
 
 
+// Calls `operation` on for each field in the base and derives class defined by
+// `recordType`. The `operation` will receive the AST type linked to the field,
+// the SPIRV type linked to the field, and the index of the field in the final
+// SPIR-V representation. This index of the field can vary from the AST
+// field-index because bitfields are merged into a single field in the SPIR-V
+// representation.
+//
+// If the operation returns false, we stop processing fields.
+void forEachSpirvField(
+    const RecordType *recordType, const StructType *spirvType,
+    std::function<bool(size_t spirvFieldIndex, const QualType &fieldType,
+                       const StructType::FieldInfo &field)>
+        operation) {
+  const auto *cxxDecl = recordType->getAsCXXRecordDecl();
+  const auto *recordDecl = recordType->getDecl();
+
+  // Iterate through the base class (one field per base class).
+  // Bases cannot be melded into 1 field like bitfields, simple iteration.
+  uint32_t lastConvertedIndex = 0;
+  size_t astFieldIndex = 0;
+  for (const auto &base : cxxDecl->bases()) {
+    const auto &type = base.getType();
+    const auto &spirvField = spirvType->getFields()[astFieldIndex];
+    if (!operation(spirvField.fieldIndex, type, spirvField)) {
+      return;
+    }
+    lastConvertedIndex = spirvField.fieldIndex;
+    ++astFieldIndex;
+  }
+
+  // Iterate through the derived class fields. Field could be merged.
+  for (const auto *field : recordDecl->fields()) {
+    const auto &spirvField = spirvType->getFields()[astFieldIndex];
+    const uint32_t currentFieldIndex = spirvField.fieldIndex;
+    if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
+      ++astFieldIndex;
+      continue;
+    }
+
+    const auto &type = field->getType();
+    if (!operation(currentFieldIndex, type, spirvField)) {
+      return;
+    }
+    lastConvertedIndex = currentFieldIndex;
+    ++astFieldIndex;
+  }
+}
+
 } // namespace
 } // namespace
 
 
 SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
 SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
@@ -6410,29 +6458,26 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal,
 
 
   // Structs
   // Structs
   if (const auto *recordType = valType->getAs<RecordType>()) {
   if (const auto *recordType = valType->getAs<RecordType>()) {
-    uint32_t index = 0;
-    llvm::SmallVector<SpirvInstruction *, 4> elements;
+    assert(recordType->isStructureType());
 
 
-    // If the struct inherits from other structs, visit the bases.
-    const auto *decl = valType->getAsCXXRecordDecl();
-    for (auto baseIt = decl->bases_begin(), baseIe = decl->bases_end();
-         baseIt != baseIe; ++baseIt, ++index) {
-      SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
-          baseIt->getType(), srcVal, {index}, loc, range);
-      subSrcVal->setLayoutRule(srcVal->getLayoutRule());
-      elements.push_back(
-          reconstructValue(subSrcVal, baseIt->getType(), dstLR, loc, range));
-    }
+    LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
+    const StructType *spirvStructType =
+        lowerStructType(spirvOptions, lowerTypeVisitor, recordType->desugar());
+
+    llvm::SmallVector<SpirvInstruction *, 4> elements;
+    forEachSpirvField(
+        recordType, spirvStructType,
+        [&](size_t spirvFieldIndex, const QualType &fieldType,
+            const auto &field) {
+          SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
+              fieldType, srcVal, {static_cast<uint32_t>(spirvFieldIndex)}, loc, range);
+          subSrcVal->setLayoutRule(srcVal->getLayoutRule());
+          elements.push_back(
+              reconstructValue(subSrcVal, fieldType, dstLR, loc, range));
+
+          return true;
+        });
 
 
-    // Go over struct fields.
-    for (const auto *field : recordType->getDecl()->fields()) {
-      SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
-          field->getType(), srcVal, {index}, loc, range);
-      subSrcVal->setLayoutRule(srcVal->getLayoutRule());
-      elements.push_back(
-          reconstructValue(subSrcVal, field->getType(), dstLR, loc, range));
-      ++index;
-    }
     auto *result = spvBuilder.createCompositeConstruct(
     auto *result = spvBuilder.createCompositeConstruct(
         valType, elements, srcVal->getSourceLocation(), range);
         valType, elements, srcVal->getSourceLocation(), range);
     result->setLayoutRule(dstLR);
     result->setLayoutRule(dstLR);
@@ -6947,47 +6992,35 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType,
                                                       SourceRange range) {
                                                       SourceRange range) {
   assert(astStructType->isStructureType());
   assert(astStructType->isStructureType());
 
 
-  const auto *structDecl = astStructType->getAsStructureType()->getDecl();
   LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
   LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
   const StructType *spirvStructType =
   const StructType *spirvStructType =
       lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
       lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
-
   uint32_t vectorIndex = 0;
   uint32_t vectorIndex = 0;
   uint32_t elemCount = 1;
   uint32_t elemCount = 1;
-  uint32_t lastConvertedIndex = 0;
   llvm::SmallVector<SpirvInstruction *, 4> members;
   llvm::SmallVector<SpirvInstruction *, 4> members;
-  for (auto field = structDecl->field_begin(); field != structDecl->field_end();
-       field++) {
-    // Multiple bitfields can share the same storing type. In such case, we only
-    // want to append the whole storage once.
-    const size_t astFieldIndex =
-        std::distance(structDecl->field_begin(), field);
-    const uint32_t currentFieldIndex =
-        spirvStructType->getFields()[astFieldIndex].fieldIndex;
-    if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
-      continue;
-    }
-    lastConvertedIndex = currentFieldIndex;
-
-    if (isScalarType(field->getType())) {
-      members.push_back(spvBuilder.createCompositeExtract(
-          elemType, vector, {vectorIndex++}, loc, range));
-      continue;
-    }
-
-    if (isVectorType(field->getType(), nullptr, &elemCount)) {
-      llvm::SmallVector<uint32_t, 4> indices;
-      for (uint32_t i = 0; i < elemCount; ++i)
-        indices.push_back(vectorIndex++);
-
-      members.push_back(spvBuilder.createVectorShuffle(
-          astContext.getExtVectorType(elemType, elemCount), vector, vector,
-          indices, loc, range));
-      continue;
-    }
-
-    assert(false && "unhandled type");
-  }
+  forEachSpirvField(astStructType->getAs<RecordType>(), spirvStructType,
+                    [&](size_t spirvFieldIndex, const QualType &fieldType,
+                        const auto &field) {
+                      if (isScalarType(fieldType)) {
+                        members.push_back(spvBuilder.createCompositeExtract(
+                            elemType, vector, {vectorIndex++}, loc, range));
+                        return true;
+                      }
+
+                      if (isVectorType(fieldType, nullptr, &elemCount)) {
+                        llvm::SmallVector<uint32_t, 4> indices;
+                        for (uint32_t i = 0; i < elemCount; ++i)
+                          indices.push_back(vectorIndex++);
+
+                        members.push_back(spvBuilder.createVectorShuffle(
+                            astContext.getExtVectorType(elemType, elemCount),
+                            vector, vector, indices, loc, range));
+                        return true;
+                      }
+
+                      assert(false && "unhandled type");
+                      return false;
+                    });
 
 
   return spvBuilder.createCompositeConstruct(
   return spvBuilder.createCompositeConstruct(
       astStructType, members, vector->getSourceLocation(), range);
       astStructType, members, vector->getSourceLocation(), range);

+ 60 - 0
tools/clang/test/CodeGenSPIRV/op.structured-buffer.reconstruct.bitfield.hlsl

@@ -0,0 +1,60 @@
+// RUN: %dxc -T cs_6_0 -E main -HV 2021
+
+struct Base {
+  uint base;
+};
+
+struct Derived : Base {
+  uint a;
+  uint b : 3;
+  uint c : 3;
+  uint d;
+};
+
+RWStructuredBuffer<Derived> g_probes : register(u0);
+
+[numthreads(64u, 1u, 1u)]
+void main(uint3 dispatchThreadId : SV_DispatchThreadID) {
+
+// CHECK:     [[p:%\w+]] = OpVariable %_ptr_Function_Derived_0 Function
+  Derived p;
+
+// CHECK:   [[tmp:%\d+]] = OpAccessChain %_ptr_Function_Base_0 [[p]] %uint_0
+// CHECK:   [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp]] %int_0
+// CHECK:                  OpStore [[tmp]] %uint_5
+  p.base = 5;
+
+// CHECK:   [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_1
+// CHECK:                  OpStore [[tmp]] %uint_1
+  p.a = 1;
+
+// CHECK:   [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_2
+// CHECK: [[value:%\d+]] = OpLoad %uint [[tmp]]
+// CHECK: [[value:%\d+]] = OpBitFieldInsert %uint [[value]] %uint_2 %uint_0 %uint_3
+// CHECK:                  OpStore [[tmp]] [[value]]
+  p.b = 2;
+
+// CHECK:   [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_2
+// CHECK: [[value:%\d+]] = OpLoad %uint [[tmp]]
+// CHECK: [[value:%\d+]] = OpBitFieldInsert %uint [[value]] %uint_3 %uint_3 %uint_3
+// CHECK:                  OpStore [[tmp]] [[value]]
+  p.c = 3;
+
+// CHECK:   [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_3
+// CHECK:                  OpStore [[tmp]] %uint_4
+  p.d = 4;
+
+
+// CHECK:     [[p:%\d+]] = OpLoad %Derived_0 [[p]]
+// CHECK:   [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_Derived %g_probes %int_0 %uint_0
+// CHECK:   [[tmp:%\d+]] = OpCompositeExtract %Base_0 [[p]] 0
+// CHECK:   [[tmp:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+// CHECK:  [[base:%\d+]] = OpCompositeConstruct %Base [[tmp]]
+// CHECK:  [[mem1:%\d+]] = OpCompositeExtract %uint [[p]] 1
+// CHECK:  [[mem2:%\d+]] = OpCompositeExtract %uint [[p]] 2
+// CHECK:  [[mem3:%\d+]] = OpCompositeExtract %uint [[p]] 3
+// CHECK:   [[tmp:%\d+]] = OpCompositeConstruct %Derived [[base]] [[mem1]] [[mem2]] [[mem3]]
+// CHECK:                  OpStore [[ptr]] [[tmp]]
+	g_probes[0] = p;
+}
+

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -472,6 +472,9 @@ TEST_F(FileTest, OpStructuredBufferAccess) {
 TEST_F(FileTest, OpStructuredBufferAccessBitfield) {
 TEST_F(FileTest, OpStructuredBufferAccessBitfield) {
   runFileTest("op.structured-buffer.access.bitfield.hlsl");
   runFileTest("op.structured-buffer.access.bitfield.hlsl");
 }
 }
+TEST_F(FileTest, OpStructuredBufferReconstructBitfield) {
+  runFileTest("op.structured-buffer.reconstruct.bitfield.hlsl");
+}
 TEST_F(FileTest, OpRWStructuredBufferAccess) {
 TEST_F(FileTest, OpRWStructuredBufferAccess) {
   runFileTest("op.rw-structured-buffer.access.hlsl");
   runFileTest("op.rw-structured-buffer.access.hlsl");
 }
 }