Преглед изворни кода

spirv: get field index from SPIR-V type, not AST (#4806)

Before this commit, struct field indices were computed from the AST
type. This was correct as SPIR-V types and AST types had the same
structure.

When adding bitfields, we started to squash some fields together, as
long as some rules were respected (bitfield, same type, doesn't span
over multiple type-sized words, etc).
This means indices now diverge between SPIR-V types and AST types.

This requires us to lower the AST type before the lowering pass, when
generating instructions like OpAccessChain.
A cleaner alternative would be to lower all types before instruction
generation, and only operate on SPIR-V type. But cost of making this
refactoring is large, and we don't believe it is worth it since we plan
to upstream to clang (hence rewrite this code).

note: this code prevents AST indices to diverge from SPIR-V indices when
computed. This is just a safeguard until bitfields are in place in case
I made a mistake with the layout rules, allowing us to catch those bugs
faster.

Signed-off-by: Nathan Gauër <[email protected]>
Nathan Gauër пре 2 година
родитељ
комит
8f279baaef
1 измењених фајлова са 77 додато и 17 уклоњено
  1. 77 17
      tools/clang/lib/SPIRV/SpirvEmitter.cpp

+ 77 - 17
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -15,6 +15,7 @@
 
 #include "AlignmentSizeCalculator.h"
 #include "InitListHandler.h"
+#include "LowerTypeVisitor.h"
 #include "RawBufferMethods.h"
 #include "dxc/DXIL/DxilConstants.h"
 #include "dxc/HlslIntrinsicOp.h"
@@ -542,6 +543,71 @@ bool isVkRawBufferLoadIntrinsic(const clang::FunctionDecl *FD) {
   return true;
 }
 
+// Takes an AST member type, and determines its index in the equivalent SPIR-V
+// struct type. This is required as the struct layout might change between the
+// AST representation and SPIR-V representation.
+uint32_t getFieldIndexInStruct(const SpirvCodeGenOptions &spirvOptions,
+                               LowerTypeVisitor &lowerTypeVisitor,
+                               const MemberExpr *expr) {
+  // If we are accessing a derived struct, we need to account for the number
+  // of base structs, since they are placed as fields at the beginning of the
+  // derived struct.
+  auto baseType = expr->getBase()->getType();
+  if (baseType->isPointerType()) {
+    baseType = baseType->getPointeeType();
+  }
+
+  const auto *fieldDecl =
+      dynamic_cast<const FieldDecl *>(expr->getMemberDecl());
+  assert(fieldDecl);
+  const uint32_t indexAST =
+      getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
+
+  // The AST type index is not representative of the SPIR-V type index
+  // because we might squash some fields (bitfields by ex.).
+  // What we need is to match each AST node with the squashed field and then,
+  // determine the real index.
+  const SpirvType *spvType = lowerTypeVisitor.lowerType(
+      baseType, spirvOptions.sBufferLayoutRule, llvm::None, SourceLocation());
+  assert(spvType);
+
+  const auto st = dynamic_cast<const StructType *>(spvType);
+  assert(st != nullptr);
+  const auto &fields = st->getFields();
+  assert(indexAST <= fields.size());
+
+  // Some fields in SPIR-V share the same index (bitfields). Computing the final
+  // index of the requested field.
+  uint32_t indexSPV = 0;
+  for (size_t i = 1; i <= indexAST; i++) {
+    // Do not remove this condition. This is required to support inheritance:
+    // 1. SPIR-V composite first element is the parent type:
+    //  by ex "OpTypeStruct %base_struct %float".
+    // 2. if the parent type is an empty class, it's size it zero, hence
+    // "%float" offset is also 0.
+    //
+    // A way to detect such cases is to check for type difference: fields cannot
+    // be merged if the type is different.
+    if (fields[i - 1].type != fields[i].type) {
+      indexSPV++;
+      continue;
+    }
+
+    if (fields[i - 1].offset.getValueOr(0) != fields[i].offset.getValueOr(0)) {
+      indexSPV++;
+      continue;
+    }
+  }
+
+  // TODO(issue #4140): remove once bitfields are implemented.
+  // This is just a safeguard until bitfield support is in. Before bitfields,
+  // AST indices were always correct, so this function should not change that
+  // behavior. Once the bitfield support is in, indices will start to diverge,
+  // and this assert should be removed.
+  assert(indexSPV == indexAST);
+  return indexSPV;
+}
+
 } // namespace
 
 SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
@@ -7614,23 +7680,17 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
       }
     }
 
-    // Append the index of the current level
-    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
-    assert(fieldDecl);
-    // If we are accessing a derived struct, we need to account for the number
-    // of base structs, since they are placed as fields at the beginning of the
-    // derived struct.
-    auto baseType = indexing->getBase()->getType();
-    if (baseType->isPointerType()) {
-      baseType = baseType->getPointeeType();
-    }
-    const uint32_t index =
-        getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
-    if (rawIndex) {
-      rawIndices->push_back(index);
-    } else {
-      indices->push_back(spvBuilder.getConstantInt(
-          astContext.IntTy, llvm::APInt(32, index, true)));
+    {
+      LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
+      const uint32_t fieldIndex =
+          getFieldIndexInStruct(spirvOptions, lowerTypeVisitor, indexing);
+
+      if (rawIndex) {
+        rawIndices->push_back(fieldIndex);
+      } else {
+        indices->push_back(spvBuilder.getConstantInt(
+            astContext.IntTy, llvm::APInt(32, fieldIndex, true)));
+      }
     }
 
     return base;