瀏覽代碼

[SPIR-V] HLSL2021: initial bitfield implementation (#4831)

This commit adds the logic in the SPIR-V backend to generate proper
bitfields. Bitfield are packed using a first-fit method, linearly
packing them, but not mixing types. Goal is to follow C/C++ rules.

Bitfield merging was initially guessed from the offset stored in the FieldInfo.
This offset is not always available and has a very specific meaning.
When the struct is a function local variable, the layout rule is Void,
meaning we shouldn't assume any kind of byte offset, but rely on
construct index.
This commit adds a fieldIndex member to the FieldInfo struct, and this
field is used to determine if 2 fields are merged.

When doing a buffer texture load, the struct must be extracted from a
vector type, and rebuilt. This commit adds support for bitfield extraction for such
types. Fixing this helped me see scalar assignment were also failling in
some cases. Addressing bitfield extraction/insertion issues on with
commit.

Signed-off-by: Nathan Gauër <[email protected]>
Co-authored-by: Cassandra Beckley <[email protected]>
Nathan Gauër 2 年之前
父節點
當前提交
cce6fe0f43

+ 7 - 6
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -174,11 +174,11 @@ public:
                                           SourceLocation loc,
                                           SourceRange range = {});
 
-  /// \brief Creates a load instruction loading the value of the given
-  /// <result-type> from the given pointer. Returns the instruction pointer for
-  /// the loaded value.
-  SpirvLoad *createLoad(QualType resultType, SpirvInstruction *pointer,
-                        SourceLocation loc, SourceRange range = {});
+  /// \brief Creates a load sequence loading the value of the given
+  /// <result-type> from the given pointer (load + optional extraction,
+  /// ex:bitfield). Returns the instruction pointer for the loaded value.
+  SpirvInstruction *createLoad(QualType resultType, SpirvInstruction *pointer,
+                               SourceLocation loc, SourceRange range = {});
   SpirvLoad *createLoad(const SpirvType *resultType, SpirvInstruction *pointer,
                         SourceLocation loc, SourceRange range = {});
 
@@ -186,8 +186,9 @@ public:
   SpirvCopyObject *createCopyObject(QualType resultType,
                                     SpirvInstruction *pointer, SourceLocation);
 
-  /// \brief Creates a store instruction storing the given value into the given
+  /// \brief Creates a store sequence storing the given value into the given
   /// address. Returns the instruction pointer for the store instruction.
+  /// This function handles storing to bitfields.
   SpirvStore *createStore(SpirvInstruction *address, SpirvInstruction *value,
                    SourceLocation loc, SourceRange range = {});
 

+ 6 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -14,6 +14,7 @@
 #include "clang/AST/APValue.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/SourceLocation.h"
+#include "clang/SPIRV/SpirvType.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Optional.h"
@@ -203,6 +204,7 @@ public:
 
   void setRValue(bool rvalue = true) { isRValue_ = rvalue; }
   bool isRValue() const { return isRValue_; }
+  bool isLValue() const { return !isRValue_; }
 
   void setRelaxedPrecision() { isRelaxedPrecision_ = true; }
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
@@ -213,6 +215,9 @@ public:
   void setPrecise(bool p = true) { isPrecise_ = p; }
   bool isPrecise() const { return isPrecise_; }
 
+  void setBitfieldInfo(const BitfieldInfo &info) { bitfieldInfo = info; }
+  llvm::Optional<BitfieldInfo> getBitfieldInfo() const { return bitfieldInfo; }
+
   /// Legalization-specific code
   ///
   /// Note: the following two functions are currently needed in order to support
@@ -255,6 +260,7 @@ protected:
   bool isRelaxedPrecision_;
   bool isNonUniform_;
   bool isPrecise_;
+  llvm::Optional<BitfieldInfo> bitfieldInfo;
 };
 
 /// \brief OpCapability instruction

+ 26 - 6
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -30,6 +30,13 @@ enum class StructInterfaceType : uint32_t {
   UniformBuffer = 2,
 };
 
+struct BitfieldInfo {
+  // Offset of the bitfield, in bits, from the basetype start.
+  uint32_t offsetInBits;
+  // Size of the bitfield, in bits.
+  uint32_t sizeInBits;
+};
+
 class SpirvType {
 public:
   enum Kind {
@@ -290,14 +297,16 @@ class StructType : public SpirvType {
 public:
   struct FieldInfo {
   public:
-    FieldInfo(const SpirvType *type_, llvm::StringRef name_ = "",
+    FieldInfo(const SpirvType *type_, uint32_t fieldIndex_,
+              llvm::StringRef name_ = "",
               llvm::Optional<uint32_t> offset_ = llvm::None,
               llvm::Optional<uint32_t> matrixStride_ = llvm::None,
               llvm::Optional<bool> isRowMajor_ = llvm::None,
               bool relaxedPrecision = false, bool precise = false)
-        : type(type_), name(name_), offset(offset_), sizeInBytes(llvm::None),
-          matrixStride(matrixStride_), isRowMajor(isRowMajor_),
-          isRelaxedPrecision(relaxedPrecision), isPrecise(precise) {
+        : type(type_), fieldIndex(fieldIndex_), name(name_), offset(offset_),
+          sizeInBytes(llvm::None), matrixStride(matrixStride_),
+          isRowMajor(isRowMajor_), isRelaxedPrecision(relaxedPrecision),
+          isPrecise(precise) {
       // A StructType may not contain any hybrid types.
       assert(!isa<HybridType>(type_));
     }
@@ -306,6 +315,10 @@ public:
 
     // The field's type.
     const SpirvType *type;
+    // The index of this field in the composite construct.
+    // When the struct contains bitfields, StructType index and construct index
+    // can diverge as we merge bitfields together.
+    uint32_t fieldIndex;
     // The field's name.
     std::string name;
     // The integer offset in bytes for this field.
@@ -320,6 +333,8 @@ public:
     bool isRelaxedPrecision;
     // Whether this field is marked as 'precise'.
     bool isPrecise;
+    // Information about the bitfield (if applicable).
+    llvm::Optional<BitfieldInfo> bitfield;
   };
 
   StructType(
@@ -467,9 +482,11 @@ public:
               clang::VKOffsetAttr *offset = nullptr,
               hlsl::ConstantPacking *packOffset = nullptr,
               const hlsl::RegisterAssignment *regC = nullptr,
-              bool precise = false)
+              bool precise = false,
+              llvm::Optional<BitfieldInfo> bitfield = llvm::None)
         : astType(astType_), name(name_), vkOffsetAttr(offset),
-          packOffsetAttr(packOffset), registerC(regC), isPrecise(precise) {}
+          packOffsetAttr(packOffset), registerC(regC), isPrecise(precise),
+          bitfield(std::move(bitfield)) {}
 
     // The field's type.
     QualType astType;
@@ -483,6 +500,9 @@ public:
     const hlsl::RegisterAssignment *registerC;
     // Whether this field is marked as 'precise'.
     bool isPrecise;
+    // Whether this field is a bitfield or not. If set to false, bitfield width
+    // value is undefined.
+    llvm::Optional<BitfieldInfo> bitfield;
   };
 
   HybridStructType(

+ 3 - 2
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -4292,8 +4292,9 @@ void DeclResultIdMapper::storeOutStageVarsToStorage(
     }
     auto *ptrToOutputStageVar = spvBuilder.createAccessChain(
         outputControlPointType, found->second, {ctrlPointID}, /*loc=*/{});
-    auto *load = spvBuilder.createLoad(outputControlPointType,
-                                       ptrToOutputStageVar, /*loc=*/{});
+    auto *load =
+        spvBuilder.createLoad(outputControlPointType, ptrToOutputStageVar,
+                              /*loc=*/{});
     spvBuilder.createStore(ptr, load, /*loc=*/{});
     return;
   }

+ 31 - 10
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -23,6 +23,8 @@
 #include "clang/SPIRV/String.h"
 // clang-format on
 
+#include <functional>
+
 namespace clang {
 namespace spirv {
 
@@ -2326,6 +2328,17 @@ EmitTypeHandler::getOrCreateConstantComposite(SpirvConstantComposite *inst) {
   return inst->getResultId();
 }
 
+static inline bool
+isFieldMergeWithPrevious(const StructType::FieldInfo &previous,
+                         const StructType::FieldInfo &field) {
+  if (previous.fieldIndex == field.fieldIndex) {
+    // Right now, the only reason for those indices to be shared is if both
+    // are merged bitfields.
+    assert(previous.bitfield.hasValue() && field.bitfield.hasValue());
+  }
+  return previous.fieldIndex == field.fieldIndex;
+}
+
 uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
   // First get the decorations that would apply to this type.
   bool alreadyExists = false;
@@ -2447,24 +2460,32 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
   }
   // Structure types
   else if (const auto *structType = dyn_cast<StructType>(type)) {
-    llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
-    size_t numFields = fields.size();
+    std::vector<std::reference_wrapper<const StructType::FieldInfo>>
+        fieldsToGenerate;
+    {
+      llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
+      for (size_t i = 0; i < fields.size(); ++i) {
+        if (i > 0 && isFieldMergeWithPrevious(fields[i - 1], fields[i]))
+          continue;
+        fieldsToGenerate.push_back(std::ref(fields[i]));
+      }
+    }
 
     // Emit OpMemberName for the struct members.
-    for (size_t i = 0; i < numFields; ++i)
-      emitNameForType(fields[i].name, id, i);
+    for (size_t i = 0; i < fieldsToGenerate.size(); ++i)
+      emitNameForType(fieldsToGenerate[i].get().name, id, i);
 
     llvm::SmallVector<uint32_t, 4> fieldTypeIds;
-    for (auto &field : fields) {
-      fieldTypeIds.push_back(emitType(field.type));
-    }
+    for (auto &field : fieldsToGenerate)
+      fieldTypeIds.push_back(emitType(field.get().type));
 
-    for (size_t i = 0; i < numFields; ++i) {
-      auto &field = fields[i];
+    for (size_t i = 0; i < fieldsToGenerate.size(); ++i) {
+      const auto &field = fieldsToGenerate[i].get();
       // Offset decorations
-      if (field.offset.hasValue())
+      if (field.offset.hasValue()) {
         emitDecoration(id, spv::Decoration::Offset, {field.offset.getValue()},
                        i);
+      }
 
       // MatrixStride decorations
       if (field.matrixStride.hasValue())

+ 6 - 0
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "LiteralTypeVisitor.h"
+#include "LowerTypeVisitor.h"
 #include "clang/SPIRV/AstTypeProbe.h"
 #include "clang/SPIRV/SpirvFunction.h"
 
@@ -389,6 +390,11 @@ bool LiteralTypeVisitor::updateTypeForCompositeMembers(
       const auto *decl = structType->getDecl();
       size_t i = 0;
       for (const auto *field : decl->fields()) {
+        // If the field is a bitfield, it might be squashed later when building
+        // the SPIR-V type depending on context. This means indices starting
+        // from this bitfield are not guaranteed, and we shouldn't touch them.
+        if (field->isBitField())
+          break;
         tryToUpdateInstLitType(constituents[i], field->getType());
         ++i;
       }

+ 123 - 63
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -33,10 +33,12 @@ inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) {
   return (val + pow2 - 1) & ~(pow2 - 1);
 }
 
+} // end anonymous namespace
+
 // This method sorts a field list in the following order:
 //  - fields with register annotation first, sorted by register index.
 //  - then fields without annotation, in order of declaration.
-std::vector<const HybridStructType::FieldInfo *>
+static std::vector<const HybridStructType::FieldInfo *>
 sortFields(llvm::ArrayRef<HybridStructType::FieldInfo> fields) {
   std::vector<const HybridStructType::FieldInfo *> output;
   output.resize(fields.size());
@@ -60,22 +62,34 @@ sortFields(llvm::ArrayRef<HybridStructType::FieldInfo> fields) {
   return output;
 }
 
+static void setDefaultFieldSize(const AlignmentSizeCalculator &alignmentCalc,
+                                const SpirvLayoutRule rule,
+                                const HybridStructType::FieldInfo *currentField,
+                                StructType::FieldInfo *field) {
+
+  const auto &fieldType = currentField->astType;
+  uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
+  std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
+      fieldType, rule, /*isRowMajor*/ llvm::None, &stride);
+  field->sizeInBytes = memberSize;
+  return;
+}
+
 // Correctly determine a field offset/size/padding depending on its neighbors
 // and other rules.
-void setDefaultFieldOffsetAndSize(
-    const AlignmentSizeCalculator &alignmentCalc, const SpirvLayoutRule rule,
-    const uint32_t previousFieldEnd,
-    const HybridStructType::FieldInfo *currentField,
-    StructType::FieldInfo *field) {
+static void
+setDefaultFieldOffset(const AlignmentSizeCalculator &alignmentCalc,
+                      const SpirvLayoutRule rule,
+                      const uint32_t previousFieldEnd,
+                      const HybridStructType::FieldInfo *currentField,
+                      StructType::FieldInfo *field) {
 
   const auto &fieldType = currentField->astType;
   uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
   std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
       fieldType, rule, /*isRowMajor*/ llvm::None, &stride);
-  field->sizeInBytes = memberSize;
 
   const uint32_t baseOffset = previousFieldEnd;
-
   // The next avaiable location after laying out the previous members
   if (rule != SpirvLayoutRule::RelaxedGLSLStd140 &&
       rule != SpirvLayoutRule::RelaxedGLSLStd430 &&
@@ -90,8 +104,6 @@ void setDefaultFieldOffsetAndSize(
   field->offset = newOffset;
 }
 
-} // end anonymous namespace
-
 bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
   if (phase == Visitor::Phase::Done) {
     // Lower the function return type.
@@ -223,8 +235,8 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
   case spv::Op::OpImageSparseRead: {
     const auto *uintType = spvContext.getUIntType(32);
     const auto *sparseResidencyStruct = spvContext.getStructType(
-        {StructType::FieldInfo(uintType, "Residency.Code"),
-         StructType::FieldInfo(resultType, "Result.Type")},
+        {StructType::FieldInfo(uintType, /* fieldIndex*/ 0, "Residency.Code"),
+         StructType::FieldInfo(resultType, /* fieldIndex*/ 1, "Result.Type")},
         "SparseResidencyStruct");
     instr->setResultType(sparseResidencyStruct);
     break;
@@ -510,12 +522,20 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
 
     // Create fields for all members of this struct
     for (const auto *field : decl->fields()) {
+      llvm::Optional<BitfieldInfo> bitfieldInfo;
+      if (field->isBitField()) {
+        bitfieldInfo = BitfieldInfo();
+        bitfieldInfo->sizeInBits =
+            field->getBitWidthValue(field->getASTContext());
+      }
+
       fields.push_back(HybridStructType::FieldInfo(
           field->getType(), field->getName(),
           /*vkoffset*/ field->getAttr<VKOffsetAttr>(),
           /*packoffset*/ getPackOffset(field),
           /*RegisterAssignment*/ nullptr,
-          /*isPrecise*/ field->hasAttr<HLSLPreciseAttr>()));
+          /*isPrecise*/ field->hasAttr<HLSLPreciseAttr>(),
+          /*bitfield*/ bitfieldInfo));
     }
 
     auto loweredFields = populateLayoutInformation(fields, rule);
@@ -719,8 +739,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
 
     const std::string typeName = "type." + name.str() + "." + getAstTypeName(s);
     const auto *valType = spvContext.getStructType(
-        {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0, matrixStride,
-                               isRowMajor)},
+        {StructType::FieldInfo(raType, /* fieldIndex*/ 0, /*name*/ "",
+                               /*offset*/ 0, matrixStride, isRowMajor)},
         typeName, isReadOnly, StructInterfaceType::StorageBuffer);
 
     if (asAlias) {
@@ -862,12 +882,13 @@ LowerTypeVisitor::translateSampledTypeToImageFormat(QualType sampledType,
 
 StructType::FieldInfo
 LowerTypeVisitor::lowerField(const HybridStructType::FieldInfo *field,
-                             SpirvLayoutRule rule) {
+                             SpirvLayoutRule rule, const uint32_t fieldIndex) {
   auto fieldType = field->astType;
   // Lower the field type fist. This call will populate proper matrix
   // majorness information.
   StructType::FieldInfo loweredField(
-      lowerType(fieldType, rule, /*isRowMajor*/ llvm::None, {}), field->name);
+      lowerType(fieldType, rule, /*isRowMajor*/ llvm::None, {}), fieldIndex,
+      field->name);
 
   // Set RelaxedPrecision information for the lowered field.
   if (isRelaxedPrecisionType(fieldType, spvOptions)) {
@@ -876,6 +897,7 @@ LowerTypeVisitor::lowerField(const HybridStructType::FieldInfo *field,
   if (field->isPrecise) {
     loweredField.isPrecise = true;
   }
+  loweredField.bitfield = field->bitfield;
 
   // We only need layout information for structures with non-void layout rule.
   if (rule == SpirvLayoutRule::Void) {
@@ -912,59 +934,89 @@ LowerTypeVisitor::populateLayoutInformation(
 
   auto fieldVisitor = [this,
                        &rule](const StructType::FieldInfo *previousField,
-                              const HybridStructType::FieldInfo *currentField) {
-    StructType::FieldInfo loweredField = lowerField(currentField, rule);
-    // We only need layout information for structures with non-void layout rule.
-    if (rule == SpirvLayoutRule::Void) {
+                              const HybridStructType::FieldInfo *currentField,
+                              const uint32_t nextFieldIndex) {
+    StructType::FieldInfo loweredField =
+        lowerField(currentField, rule, nextFieldIndex);
+    setDefaultFieldSize(alignmentCalc, rule, currentField, &loweredField);
+
+    // We only need size information for structures with non-void layout &
+    // non-bitfield fields.
+    if (rule == SpirvLayoutRule::Void && !currentField->bitfield.hasValue())
       return loweredField;
-    }
 
-    const uint32_t previousFieldEnd =
-        previousField ? previousField->offset.getValue() +
-                            previousField->sizeInBytes.getValue()
-                      : 0;
-    setDefaultFieldOffsetAndSize(alignmentCalc, rule, previousFieldEnd,
-                                 currentField, &loweredField);
+    // We only need layout information for structures with non-void layout rule.
+    if (rule != SpirvLayoutRule::Void) {
+      const uint32_t previousFieldEnd =
+          previousField ? previousField->offset.getValue() +
+                              previousField->sizeInBytes.getValue()
+                        : 0;
+      setDefaultFieldOffset(alignmentCalc, rule, previousFieldEnd, currentField,
+                            &loweredField);
+
+      // The vk::offset attribute takes precedence over all.
+      if (currentField->vkOffsetAttr) {
+        loweredField.offset = currentField->vkOffsetAttr->getOffset();
+        return loweredField;
+      }
 
-    // The vk::offset attribute takes precedence over all.
-    if (currentField->vkOffsetAttr) {
-      loweredField.offset = currentField->vkOffsetAttr->getOffset();
-      return loweredField;
-    }
+      // The :packoffset() annotation takes precedence over normal layout
+      // calculation.
+      if (currentField->packOffsetAttr) {
+        const uint32_t offset =
+            currentField->packOffsetAttr->Subcomponent * 16 +
+            currentField->packOffsetAttr->ComponentOffset * 4;
+        // Do minimal check to make sure the offset specified by packoffset does
+        // not cause overlap.
+        if (offset < previousFieldEnd) {
+          emitError("packoffset caused overlap with previous members",
+                    currentField->packOffsetAttr->Loc);
+        }
 
-    // The :packoffset() annotation takes precedence over normal layout
-    // calculation.
-    if (currentField->packOffsetAttr) {
-      const uint32_t offset = currentField->packOffsetAttr->Subcomponent * 16 +
-                              currentField->packOffsetAttr->ComponentOffset * 4;
-      // Do minimal check to make sure the offset specified by packoffset does
-      // not cause overlap.
-      if (offset < previousFieldEnd) {
-        emitError("packoffset caused overlap with previous members",
-                  currentField->packOffsetAttr->Loc);
+        loweredField.offset = offset;
+        return loweredField;
       }
 
-      loweredField.offset = offset;
-      return loweredField;
-    }
+      // The :register(c#) annotation takes precedence over normal layout
+      // calculation.
+      if (currentField->registerC) {
+        const uint32_t offset = 16 * currentField->registerC->RegisterNumber;
+        // Do minimal check to make sure the offset specified by :register(c#)
+        // does not cause overlap.
+        if (offset < previousFieldEnd) {
+          emitError(
+              "found offset overlap when processing register(c%0) assignment",
+              currentField->registerC->Loc)
+              << currentField->registerC->RegisterNumber;
+        }
 
-    // The :register(c#) annotation takes precedence over normal layout
-    // calculation.
-    if (currentField->registerC) {
-      const uint32_t offset = 16 * currentField->registerC->RegisterNumber;
-      // Do minimal check to make sure the offset specified by :register(c#)
-      // does not cause overlap.
-      if (offset < previousFieldEnd) {
-        emitError(
-            "found offset overlap when processing register(c%0) assignment",
-            currentField->registerC->Loc)
-            << currentField->registerC->RegisterNumber;
+        loweredField.offset = offset;
+        return loweredField;
       }
+    }
 
-      loweredField.offset = offset;
+    if (!currentField->bitfield.hasValue())
       return loweredField;
-    }
 
+    // Previous field is a full type, cannot merge.
+    if (!previousField || !previousField->bitfield.hasValue())
+      return loweredField;
+
+    // Bitfields can only be merged if they have the exact base type.
+    // (SPIR-V cannot handle mixed-types bitfields).
+    if (previousField->type != loweredField.type)
+      return loweredField;
+
+    const uint32_t basetypeSize = previousField->sizeInBytes.getValue() * 8;
+    const auto &previousBitfield = previousField->bitfield.getValue();
+    const uint32_t nextAvailableBit =
+        previousBitfield.offsetInBits + previousBitfield.sizeInBits;
+    if (nextAvailableBit + currentField->bitfield->sizeInBits > basetypeSize)
+      return loweredField;
+
+    loweredField.bitfield->offsetInBits = nextAvailableBit;
+    loweredField.offset = previousField->offset;
+    loweredField.fieldIndex = previousField->fieldIndex;
     return loweredField;
   };
 
@@ -982,18 +1034,26 @@ LowerTypeVisitor::populateLayoutInformation(
 
   // The resulting vector of fields with proper layout information.
   // Second, build each field, and determine their actual offset in the
-  // structure.
+  // structure (explicit layout, bitfield merging, etc).
   llvm::SmallVector<StructType::FieldInfo, 4> loweredFields;
   llvm::DenseMap<const HybridStructType::FieldInfo *, uint32_t> fieldToIndexMap;
 
+  // This stores the index of the field in the actual SPIR-V construct.
+  // When bitfields are merged, this index will be the same for merged fields.
+  uint32_t fieldIndexInConstruct = 0;
   for (size_t i = 0; i < sortedFields.size(); i++) {
     const StructType::FieldInfo *previousField =
         i > 0 ? &loweredFields.back() : nullptr;
     const HybridStructType::FieldInfo *currentField = sortedFields[i];
-    const size_t field_index = loweredFields.size();
+    const size_t fieldIndexForMap = loweredFields.size();
 
-    loweredFields.emplace_back(fieldVisitor(previousField, currentField));
-    fieldToIndexMap[sortedFields[i]] = field_index;
+    loweredFields.emplace_back(
+        fieldVisitor(previousField, currentField, fieldIndexInConstruct));
+    if (!previousField ||
+        previousField->fieldIndex != loweredFields.back().fieldIndex) {
+      fieldIndexInConstruct++;
+    }
+    fieldToIndexMap[sortedFields[i]] = fieldIndexForMap;
   }
 
   // Re-order the sorted fields back to their original order.

+ 2 - 1
tools/clang/lib/SPIRV/LowerTypeVisitor.h

@@ -94,7 +94,8 @@ private:
   /// This function only considers the field as standalone.
   /// Offset and layout constraint from the parent struct are not considered.
   StructType::FieldInfo lowerField(const HybridStructType::FieldInfo *field,
-                                   SpirvLayoutRule rule);
+                                   SpirvLayoutRule rule,
+                                   const uint32_t fieldIndex);
 
 private:
   ASTContext &astContext;                /// AST context

+ 45 - 5
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -190,9 +190,10 @@ SpirvVectorShuffle *SpirvBuilder::createVectorShuffle(
   return instruction;
 }
 
-SpirvLoad *SpirvBuilder::createLoad(QualType resultType,
-                                    SpirvInstruction *pointer,
-                                    SourceLocation loc, SourceRange range) {
+SpirvInstruction *SpirvBuilder::createLoad(QualType resultType,
+                                           SpirvInstruction *pointer,
+                                           SourceLocation loc,
+                                           SourceRange range) {
   assert(insertPoint && "null insert point");
   auto *instruction = new (context) SpirvLoad(resultType, loc, pointer, range);
   instruction->setStorageClass(pointer->getStorageClass());
@@ -210,7 +211,20 @@ SpirvLoad *SpirvBuilder::createLoad(QualType resultType,
   }
 
   insertPoint->addInstruction(instruction);
-  return instruction;
+
+  const auto &bitfieldInfo = pointer->getBitfieldInfo();
+  if (!bitfieldInfo.hasValue())
+    return instruction;
+
+  auto *offset =
+      getConstantInt(astContext.UnsignedIntTy,
+                     llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->offsetInBits), /* isSigned= */ false));
+  auto *count =
+      getConstantInt(astContext.UnsignedIntTy,
+                     llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->sizeInBits), /* isSigned= */ false));
+  return createBitFieldExtract(
+      resultType, instruction, offset, count,
+      pointer->getAstResultType()->isSignedIntegerOrEnumerationType(), loc);
 }
 
 SpirvCopyObject *SpirvBuilder::createCopyObject(QualType resultType,
@@ -255,8 +269,32 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address,
                                       SpirvInstruction *value,
                                       SourceLocation loc, SourceRange range) {
   assert(insertPoint && "null insert point");
+  // Safeguard. If this happens, it means we leak non-extracted bitfields.
+  assert(false == value->getBitfieldInfo().hasValue());
+
+  SpirvInstruction *source = value;
+  const auto &bitfieldInfo = address->getBitfieldInfo();
+  if (bitfieldInfo.hasValue()) {
+    // Generate SPIR-V type for value. This is required to know the final
+    // layout.
+    LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
+    lowerTypeVisitor.visitInstruction(value);
+    context.addToInstructionsWithLoweredType(value);
+
+    auto *base = createLoad(value->getResultType(), address, loc, range);
+    auto *offset =
+        getConstantInt(astContext.UnsignedIntTy,
+                       llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->offsetInBits), false));
+    auto *count =
+        getConstantInt(astContext.UnsignedIntTy,
+                       llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->sizeInBits), false));
+    source =
+        createBitFieldInsert(/*QualType*/ {}, base, value, offset, count, loc);
+    source->setResultType(value->getResultType());
+  }
+
   auto *instruction =
-      new (context) SpirvStore(loc, address, value, llvm::None, range);
+      new (context) SpirvStore(loc, address, source, llvm::None, range);
   insertPoint->addInstruction(instruction);
   return instruction;
 }
@@ -814,6 +852,7 @@ SpirvBitFieldInsert *SpirvBuilder::createBitFieldInsert(
   auto *inst = new (context)
       SpirvBitFieldInsert(resultType, loc, base, insert, offset, count);
   insertPoint->addInstruction(inst);
+  inst->setRValue(true);
   return inst;
 }
 
@@ -824,6 +863,7 @@ SpirvBitFieldExtract *SpirvBuilder::createBitFieldExtract(
   auto *inst = new (context)
       SpirvBitFieldExtract(resultType, loc, base, offset, count, isSigned);
   insertPoint->addInstruction(inst);
+  inst->setRValue(true);
   return inst;
 }
 

+ 7 - 5
tools/clang/lib/SPIRV/SpirvContext.cpp

@@ -343,10 +343,11 @@ const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
       getRuntimeArrayType(getUIntType(32), /* ArrayStride */ 4);
 
   // Create a struct containing the runtime array as its only member.
-  return getStructType(
-      {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0)},
-      isWritable ? "type.RWByteAddressBuffer" : "type.ByteAddressBuffer",
-      !isWritable, StructInterfaceType::StorageBuffer);
+  return getStructType({StructType::FieldInfo(raType, /*fieldIndex*/ 0,
+                                              /*name*/ "", /*offset*/ 0)},
+                       isWritable ? "type.RWByteAddressBuffer"
+                                  : "type.ByteAddressBuffer",
+                       !isWritable, StructInterfaceType::StorageBuffer);
 }
 
 const StructType *SpirvContext::getACSBufferCounterType() {
@@ -355,7 +356,8 @@ const StructType *SpirvContext::getACSBufferCounterType() {
 
   // Create a struct containing the integer counter as its only member.
   const StructType *type =
-      getStructType({StructType::FieldInfo(int32Type, "counter", /*offset*/ 0)},
+      getStructType({StructType::FieldInfo(int32Type, /*fieldIndex*/ 0,
+                                           "counter", /*offset*/ 0)},
                     "type.ACSBuffer.counter",
                     /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
 

+ 165 - 64
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -15,10 +15,12 @@
 
 #include "AlignmentSizeCalculator.h"
 #include "InitListHandler.h"
+#include "LowerTypeVisitor.h"
 #include "RawBufferMethods.h"
 #include "dxc/DXIL/DxilConstants.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "spirv-tools/optimizer.hpp"
+#include "clang/AST/RecordLayout.h"
 #include "clang/SPIRV/AstTypeProbe.h"
 #include "clang/SPIRV/String.h"
 #include "clang/Sema/Sema.h"
@@ -542,6 +544,45 @@ bool isVkRawBufferLoadIntrinsic(const clang::FunctionDecl *FD) {
   return true;
 }
 
+// Takes an AST member type, and determines its index in the equivalent SPIR-V
+// struct type. This is required as the struct layout might change between the
+// AST representation and SPIR-V representation.
+uint32_t getFieldIndexInStruct(const StructType *spirvStructType,
+                               const QualType &astStructType,
+                               const FieldDecl *fieldDecl) {
+  assert(fieldDecl);
+  const uint32_t indexAST =
+      getNumBaseClasses(astStructType) + fieldDecl->getFieldIndex();
+
+  const auto &fields = spirvStructType->getFields();
+  assert(indexAST <= fields.size());
+  return fields[indexAST].fieldIndex;
+}
+
+// Takes an AST struct type, and lowers is to the equivalent SPIR-V type.
+const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions,
+                                  LowerTypeVisitor &lowerTypeVisitor,
+                                  const QualType &structType) {
+  // If we are accessing a derived struct, we need to account for the number
+  // of base structs, since they are placed as fields at the beginning of the
+  // derived struct.
+  auto baseType = structType;
+  if (baseType->isPointerType()) {
+    baseType = baseType->getPointeeType();
+  }
+
+  // The AST type index is not representative of the SPIR-V type index
+  // because we might squash some fields (bitfields by ex.).
+  // What we need is to match each AST node with the squashed field and then,
+  // determine the real index.
+  const SpirvType *spvType = lowerTypeVisitor.lowerType(
+      baseType, spirvOptions.sBufferLayoutRule, llvm::None, SourceLocation());
+
+  const StructType *output = dyn_cast<StructType>(spvType);
+  assert(output != nullptr);
+  return output;
+}
+
 } // namespace
 
 SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
@@ -2584,12 +2625,12 @@ SpirvEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr,
   SourceRange range =
       (rangeOverride != SourceRange()) ? rangeOverride : expr->getSourceRange();
 
-  if (!indices.empty()) {
-    info = turnIntoElementPtr(base->getType(), info, expr->getType(), indices,
-                              base->getExprLoc(), range);
+  if (indices.empty()) {
+    return info;
   }
 
-  return info;
+  return derefOrCreatePointerToValue(base->getType(), info, expr->getType(),
+                                     indices, base->getExprLoc(), range);
 }
 
 SpirvInstruction *SpirvEmitter::doBinaryOperator(const BinaryOperator *expr) {
@@ -3281,9 +3322,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
           astContext.UnsignedIntTy, llvm::APInt(32, baseIndices[i]));
 
     auto *derivedInfo = doExpr(subExpr);
-    return turnIntoElementPtr(subExpr->getType(), derivedInfo, expr->getType(),
-                              baseIndexInstructions, subExpr->getExprLoc(),
-                              range);
+    return derefOrCreatePointerToValue(subExpr->getType(), derivedInfo,
+                                       expr->getType(), baseIndexInstructions,
+                                       subExpr->getExprLoc(), range);
   }
   case CastKind::CK_ArrayToPointerDecay: {
     // Literal string to const string conversion falls under this category.
@@ -4355,8 +4396,9 @@ SpirvEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
   auto *zero = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0));
   auto *index = doExpr(expr->getArg(0));
 
-  return turnIntoElementPtr(buffer->getType(), info, structType, {zero, index},
-                            buffer->getExprLoc(), range);
+  return derefOrCreatePointerToValue(buffer->getType(), info, structType,
+                                     {zero, index}, buffer->getExprLoc(),
+                                     range);
 }
 
 SpirvInstruction *
@@ -4574,7 +4616,8 @@ SpirvEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
   }
 
   const auto range = expr->getSourceRange();
-  bufferInfo = turnIntoElementPtr(object->getType(), bufferInfo, bufferElemTy,
+  bufferInfo =
+      derefOrCreatePointerToValue(object->getType(), bufferInfo, bufferElemTy,
                                   {zero, index}, object->getExprLoc(), range);
 
   if (isAppend) {
@@ -5598,8 +5641,8 @@ SpirvEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr,
   SourceRange range =
       (rangeOverride != SourceRange()) ? rangeOverride : expr->getSourceRange();
 
-  return turnIntoElementPtr(baseExpr->getType(), base, expr->getType(), indices,
-                            baseExpr->getExprLoc(), range);
+  return derefOrCreatePointerToValue(baseExpr->getType(), base, expr->getType(),
+                                     indices, baseExpr->getExprLoc(), range);
 }
 
 SpirvInstruction *
@@ -5794,16 +5837,50 @@ SpirvInstruction *SpirvEmitter::doMemberExpr(const MemberExpr *expr,
   llvm::SmallVector<SpirvInstruction *, 4> indices;
   const Expr *base = collectArrayStructIndices(
       expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
-  SourceRange range =
+  const SourceRange &range =
       (rangeOverride != SourceRange()) ? rangeOverride : expr->getSourceRange();
   auto *instr = loadIfAliasVarRef(base, range);
+  const auto &loc = base->getExprLoc();
 
-  if (instr && !indices.empty()) {
-    instr = turnIntoElementPtr(base->getType(), instr, expr->getType(), indices,
-                               base->getExprLoc(), range);
+  if (!instr || indices.empty()) {
+    return instr;
   }
 
-  return instr;
+  const auto *fieldDecl = dyn_cast<FieldDecl>(expr->getMemberDecl());
+  if (!fieldDecl || !fieldDecl->isBitField()) {
+    return derefOrCreatePointerToValue(base->getType(), instr, expr->getType(),
+                                       indices, loc, range);
+  }
+
+  auto baseType = expr->getBase()->getType();
+  if (baseType->isPointerType()) {
+    baseType = baseType->getPointeeType();
+  }
+  const uint32_t indexAST =
+      getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
+  LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
+  const StructType *spirvStructType =
+      lowerStructType(spirvOptions, lowerTypeVisitor, baseType);
+  assert(spirvStructType);
+
+  const uint32_t bitfieldOffset =
+      spirvStructType->getFields()[indexAST].bitfield->offsetInBits;
+  const uint32_t bitfieldSize =
+      spirvStructType->getFields()[indexAST].bitfield->sizeInBits;
+  BitfieldInfo bitfieldInfo{bitfieldOffset, bitfieldSize};
+
+  if (instr->isRValue()) {
+    SpirvVariable *variable = turnIntoLValue(base->getType(), instr, loc);
+    SpirvInstruction *chain = spvBuilder.createAccessChain(
+        expr->getType(), variable, indices, loc, range);
+    chain->setBitfieldInfo(bitfieldInfo);
+    return spvBuilder.createLoad(expr->getType(), chain, loc);
+  }
+
+  SpirvInstruction *chain =
+      spvBuilder.createAccessChain(expr->getType(), instr, indices, loc, range);
+  chain->setBitfieldInfo(bitfieldInfo);
+  return chain;
 }
 
 SpirvVariable *SpirvEmitter::createTemporaryVar(QualType type,
@@ -6833,23 +6910,42 @@ void SpirvEmitter::splitVecLastElement(QualType vecType, SpirvInstruction *vec,
       spvBuilder.createCompositeExtract(elemType, vec, {count - 1}, loc);
 }
 
-SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType structType,
+SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType,
                                                       QualType elemType,
                                                       SpirvInstruction *vector,
                                                       SourceLocation loc,
                                                       SourceRange range) {
-  assert(structType->isStructureType());
+  assert(astStructType->isStructureType());
+
+  const auto *structDecl = astStructType->getAsStructureType()->getDecl();
+  LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
+  const StructType *spirvStructType =
+      lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
 
-  const auto *structDecl = structType->getAsStructureType()->getDecl();
   uint32_t vectorIndex = 0;
   uint32_t elemCount = 1;
+  uint32_t lastConvertedIndex = 0;
   llvm::SmallVector<SpirvInstruction *, 4> members;
+  for (auto field = structDecl->field_begin(); field != structDecl->field_end();
+       field++) {
+    // Multiple bitfields can share the same storing type. In such case, we only
+    // want to append the whole storage once.
+    const size_t astFieldIndex =
+        std::distance(structDecl->field_begin(), field);
+    const uint32_t currentFieldIndex =
+        spirvStructType->getFields()[astFieldIndex].fieldIndex;
+    if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
+      continue;
+    }
+    lastConvertedIndex = currentFieldIndex;
 
-  for (const auto *field : structDecl->fields()) {
     if (isScalarType(field->getType())) {
       members.push_back(spvBuilder.createCompositeExtract(
           elemType, vector, {vectorIndex++}, loc, range));
-    } else if (isVectorType(field->getType(), nullptr, &elemCount)) {
+      continue;
+    }
+
+    if (isVectorType(field->getType(), nullptr, &elemCount)) {
       llvm::SmallVector<uint32_t, 4> indices;
       for (uint32_t i = 0; i < elemCount; ++i)
         indices.push_back(vectorIndex++);
@@ -6857,13 +6953,14 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType structType,
       members.push_back(spvBuilder.createVectorShuffle(
           astContext.getExtVectorType(elemType, elemCount), vector, vector,
           indices, loc, range));
-    } else {
-      assert(false && "unhandled type");
+      continue;
     }
+
+    assert(false && "unhandled type");
   }
 
   return spvBuilder.createCompositeConstruct(
-      structType, members, vector->getSourceLocation(), range);
+      astStructType, members, vector->getSourceLocation(), range);
 }
 
 SpirvInstruction *
@@ -7614,23 +7711,24 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
       }
     }
 
-    // Append the index of the current level
-    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
-    assert(fieldDecl);
-    // If we are accessing a derived struct, we need to account for the number
-    // of base structs, since they are placed as fields at the beginning of the
-    // derived struct.
-    auto baseType = indexing->getBase()->getType();
-    if (baseType->isPointerType()) {
-      baseType = baseType->getPointeeType();
-    }
-    const uint32_t index =
-        getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
-    if (rawIndex) {
-      rawIndices->push_back(index);
-    } else {
-      indices->push_back(spvBuilder.getConstantInt(
-          astContext.IntTy, llvm::APInt(32, index, true)));
+    {
+      LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
+      const auto &astStructType =
+          /* structType */ indexing->getBase()->getType();
+      const StructType *spirvStructType =
+          lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
+      assert(spirvStructType != nullptr);
+      const uint32_t fieldIndex = getFieldIndexInStruct(
+          spirvStructType, astStructType,
+          /* fieldDecl */
+          dyn_cast<FieldDecl>(indexing->getMemberDecl()));
+
+      if (rawIndex) {
+        rawIndices->push_back(fieldIndex);
+      } else {
+        indices->push_back(spvBuilder.getConstantInt(
+            astContext.IntTy, llvm::APInt(32, fieldIndex, true)));
+      }
     }
 
     return base;
@@ -7758,26 +7856,31 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
   return expr;
 }
 
-SpirvInstruction *SpirvEmitter::turnIntoElementPtr(
+SpirvVariable *SpirvEmitter::turnIntoLValue(QualType type,
+                                            SpirvInstruction *source,
+                                            SourceLocation loc) {
+  assert(source->isRValue());
+  const auto varName = getAstTypeName(type);
+  const auto var = createTemporaryVar(type, varName, source, loc);
+  var->setLayoutRule(SpirvLayoutRule::Void);
+  var->setStorageClass(spv::StorageClass::Function);
+  var->setContainsAliasComponent(source->containsAliasComponent());
+  return var;
+}
+
+SpirvInstruction *SpirvEmitter::derefOrCreatePointerToValue(
     QualType baseType, SpirvInstruction *base, QualType elemType,
     const llvm::SmallVector<SpirvInstruction *, 4> &indices, SourceLocation loc,
     SourceRange range) {
-  // If this is a rvalue, we need a temporary object to hold it
-  // so that we can get access chain from it.
-  const bool needTempVar = base->isRValue();
-  SpirvInstruction *accessChainBase = base;
-
-  if (needTempVar) {
-    auto varName = getAstTypeName(baseType);
-    const auto var = createTemporaryVar(baseType, varName, base, loc);
-    var->setLayoutRule(SpirvLayoutRule::Void);
-    var->setStorageClass(spv::StorageClass::Function);
-    var->setContainsAliasComponent(base->containsAliasComponent());
-    accessChainBase = var;
+  if (base->isLValue()) {
+    return spvBuilder.createAccessChain(elemType, base, indices, loc, range);
   }
 
-  base = spvBuilder.createAccessChain(elemType, accessChainBase, indices, loc,
-                                      range);
+  // If this is a rvalue, we need a temporary object to hold it
+  // so that we can get access chain from it.
+  SpirvVariable *variable = turnIntoLValue(baseType, base, loc);
+  SpirvInstruction *chain =
+      spvBuilder.createAccessChain(elemType, variable, indices, loc, range);
 
   // Okay, this part seems weird, but it is intended:
   // If the base is originally a rvalue, the whole AST involving the base
@@ -7788,11 +7891,7 @@ SpirvInstruction *SpirvEmitter::turnIntoElementPtr(
   // to rely on to load the access chain if a rvalue is expected. Therefore,
   // we must do the load here. Otherwise, it's up to the consumer of this
   // access chain to do the load, and that can be everywhere.
-  if (needTempVar) {
-    base = spvBuilder.createLoad(elemType, base, loc);
-  }
-
-  return base;
+  return spvBuilder.createLoad(elemType, chain, loc);
 }
 
 SpirvInstruction *SpirvEmitter::castToBool(SpirvInstruction *fromVal,
@@ -11342,7 +11441,7 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
                                 /*isDevice*/ false,
                                 /*groupSync*/ true,
                                 /*isAllBarrier*/ false);
-  
+
   // 2) create PerTaskNV out attribute block and store MeshPayload info.
   const auto *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
@@ -13476,7 +13575,9 @@ SpirvEmitter::loadDataFromRawAddress(SpirvInstruction *addressInUInt64,
       spv::Op::OpBitcast, bufferPtrType, addressInUInt64, loc);
   address->setStorageClass(spv::StorageClass::PhysicalStorageBuffer);
 
-  SpirvLoad *loadInst = spvBuilder.createLoad(bufferType, address, loc);
+  SpirvLoad *loadInst = dyn_cast<SpirvLoad>(
+      spvBuilder.createLoad(bufferType, address, loc));
+  assert(loadInst);
   loadInst->setAlignment(alignment);
   loadInst->setRValue();
   return loadInst;

+ 11 - 7
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -369,13 +369,17 @@ private:
                             llvm::SmallVectorImpl<SpirvInstruction *> *indices,
                             bool *isMSOutAttribute = nullptr);
 
-  /// Creates an access chain to index into the given SPIR-V evaluation result
-  /// and returns the new SPIR-V evaluation result.
-  SpirvInstruction *
-  turnIntoElementPtr(QualType baseType, SpirvInstruction *base,
-                     QualType elemType,
-                     const llvm::SmallVector<SpirvInstruction *, 4> &indices,
-                     SourceLocation loc, SourceRange range = {});
+  /// For L-values, creates an access chain to index into the given SPIR-V
+  /// evaluation result and returns the new SPIR-V evaluation result.
+  /// For R-values, stores it in a variable, then create the access chain and
+  /// return the evaluation result.
+  SpirvInstruction *derefOrCreatePointerToValue(
+      QualType baseType, SpirvInstruction *base, QualType elemType,
+      const llvm::SmallVector<SpirvInstruction *, 4> &indices,
+      SourceLocation loc, SourceRange range = {});
+
+  SpirvVariable *turnIntoLValue(QualType type, SpirvInstruction *source,
+                                SourceLocation loc);
 
 private:
   /// Validates that vk::* attributes are used correctly and returns false if

+ 33 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferload.bitfield.hlsl

@@ -0,0 +1,33 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+// CHECK: OpCapability PhysicalStorageBufferAddresses
+// CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
+// CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+
+struct S {
+  uint f1;
+  uint f2 : 1;
+  uint f3 : 1;
+  uint f4;
+};
+
+uint64_t Address;
+
+// CHECK: [[type_S:%\w+]] = OpTypeStruct %uint %uint %uint
+// CHECK: [[ptr_f_S:%\w+]] = OpTypePointer Function [[type_S]]
+// CHECK: [[ptr_p_S:%\w+]] = OpTypePointer PhysicalStorageBuffer [[type_S]]
+
+void main() : B {
+// CHECK: [[tmp_S:%\w+]] = OpVariable [[ptr_f_S]] Function
+// CHECK: [[value:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+// CHECK: [[value:%\d+]] = OpLoad %ulong [[value]]
+// CHECK: [[value:%\d+]] = OpBitcast [[ptr_p_S]] [[value]]
+// CHECK: [[value:%\d+]] = OpLoad [[type_S]] [[value]] Aligned 4
+// CHECK: OpStore [[tmp_S]] [[value]]
+// CHECK: [[value:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp_S]] %int_1
+// CHECK: [[value:%\d+]] = OpLoad %uint [[value]]
+// CHECK: [[value:%\d+]] = OpBitFieldUExtract %uint [[value]] %uint_1 %uint_1
+// CHECK: OpStore %tmp [[value]]
+  uint tmp = vk::RawBufferLoad<S>(Address).f3;
+}
+

+ 22 - 0
tools/clang/test/CodeGenSPIRV/op.buffer.access.bitfield.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -T ps_6_6 -E main -HV 2021
+
+struct S1 {
+    uint f1 : 1;
+    uint f2 : 1;
+};
+
+Buffer<S1> input_1;
+
+void main() {
+// CHECK: [[img:%\d+]] = OpLoad %type_buffer_image %input_1
+// CHECK: [[tmp:%\d+]] = OpImageFetch %v4uint [[img]] %uint_0 None
+// CHECK: [[tmp:%\d+]] = OpVectorShuffle %v2uint [[tmp]] [[tmp]] 0 1
+// CHECK: [[tmp_f1:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+// CHECK: [[tmp_s1:%\d+]] = OpCompositeConstruct %S1 [[tmp_f1]]
+// CHECK: OpStore [[tmp_var_S1:%\w+]] [[tmp_s1]]
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp_var_S1]] %int_0
+// CHECK: [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK: [[extract:%\d+]] = OpBitFieldUExtract %uint [[load]] %uint_1 %uint_1
+// CHECK: OpStore %tmp [[extract]]
+  uint tmp = input_1[0].f2;
+}

+ 104 - 0
tools/clang/test/CodeGenSPIRV/op.struct.access.bitfield.hlsl

@@ -0,0 +1,104 @@
+// RUN: %dxc -T ps_6_6 -E main -HV 2021
+
+struct S1 {
+  uint f1 : 1;
+  uint f2 : 8;
+};
+
+struct S2 {
+  int f1 : 2;
+  int f2 : 9;
+};
+
+struct S3 {
+  int f1;
+  int f2 : 1;
+  int f3;
+};
+
+// CHECK: OpMemberName %S1 0 "f1"
+// CHECK-NOT: OpMemberName %S1 1 "f2"
+// CHECK: OpMemberName %S2 0 "f1"
+// CHECK-NOT: OpMemberName %S2 1 "f2"
+
+// CHECK: %S1 = OpTypeStruct %uint
+// CHECK: %S2 = OpTypeStruct %int
+
+void main() {
+  // CHECK: [[s1_var:%\w+]] = OpVariable %_ptr_Function_S1 Function
+  // CHECK: [[s2_var:%\w+]] = OpVariable %_ptr_Function_S2 Function
+  // CHECK: [[s3_var:%\w+]] = OpVariable %_ptr_Function_S3 Function
+  S1 s1;
+  S2 s2;
+  S3 s3;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint [[s1_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_1 %uint_0 %uint_1
+  // CHECK: OpStore [[ptr]] [[insert]]
+  s1.f1 = 1;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint [[s1_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_2 %uint_1 %uint_8
+  // CHECK: OpStore [[ptr]] [[insert]]
+  s1.f2 = 2;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s2_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %int [[ptr]]
+  // CHECK: [[insert:%\d+]] = OpBitFieldInsert %int [[load]] %int_3 %uint_0 %uint_2
+  // CHECK: OpStore [[ptr]] [[insert]]
+  s2.f1 = 3;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s2_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %int [[ptr]]
+  // CHECK: [[insert:%\d+]] = OpBitFieldInsert %int [[load]] %int_4 %uint_2 %uint_9
+  // CHECK: OpStore [[ptr]] [[insert]]
+  s2.f2 = 4;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint [[s1_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[extract:%\d+]] = OpBitFieldUExtract %uint [[load]] %uint_0 %uint_1
+  // CHECK: OpStore %t1 [[extract]]
+  uint t1 = s1.f1;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint [[s1_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[extract:%\d+]] = OpBitFieldUExtract %uint [[load]] %uint_1 %uint_8
+  // CHECK: OpStore %t2 [[extract]]
+  uint t2 = s1.f2;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s2_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %int [[ptr]]
+  // CHECK: [[extract:%\d+]] = OpBitFieldSExtract %int [[load]] %uint_0 %uint_2
+  // CHECK: OpStore %t3 [[extract]]
+  int t3 = s2.f1;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s2_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %int [[ptr]]
+  // CHECK: [[extract:%\d+]] = OpBitFieldSExtract %int [[load]] %uint_2 %uint_9
+  // CHECK: OpStore %t4 [[extract]]
+  int t4 = s2.f2;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s2_var]] %int_0
+  // CHECK: [[load:%\d+]] = OpLoad %int [[ptr]]
+  // CHECK: [[extract:%\d+]] = OpBitFieldSExtract %int [[load]] %uint_2 %uint_9
+  // CHECK: [[cast:%\d+]] = OpBitcast %uint [[extract]]
+  // CHECK: OpStore %t5 [[cast]]
+  uint t5 = s2.f2;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s3_var]] %int_0
+  // CHECK: OpStore [[ptr]] %int_3
+  s3.f1 = 3;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s3_var]] %int_1
+  // CHECK: [[load:%\d+]] = OpLoad %int [[ptr]]
+  // CHECK: [[insert:%\d+]] = OpBitFieldInsert %int [[load]] %int_4 %uint_0 %uint_1
+  // CHECK: OpStore [[ptr]] [[insert]]
+  s3.f2 = 4;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int [[s3_var]] %int_2
+  // CHECK: OpStore [[ptr]] %int_5
+  s3.f3 = 5;
+}
+

+ 24 - 0
tools/clang/test/CodeGenSPIRV/op.structured-buffer.access.bitfield.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -T vs_6_6 -E main -HV 2021
+
+// CHECK: [[type_S:%\w+]] = OpTypeStruct %uint %uint %uint
+// CHECK: [[rarr_S:%\w+]] = OpTypeRuntimeArray [[type_S]]
+// CHECK: [[buffer:%\w+]] = OpTypeStruct [[rarr_S]]
+// CHECK: [[ptr_buffer:%\w+]] = OpTypePointer Uniform [[buffer]]
+struct S {
+    uint f1;
+    uint f2 : 1;
+    uint f3 : 1;
+    uint f4;
+};
+
+// CHECK: [[var_buffer:%\w+]] = OpVariable [[ptr_buffer]] Uniform
+StructuredBuffer<S> buffer;
+
+void main(uint id : A) {
+  // CHECK: [[id:%\d+]] = OpLoad %uint %id
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[var_buffer]] %int_0 [[id]] %int_1
+  // CHECK: [[value:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[value:%\d+]] = OpBitFieldUExtract %uint [[value]] %uint_1 %uint_1
+  // CHECK: OpStore %tmp [[value]]
+  uint tmp = buffer[id].f3;
+}

+ 78 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.struct.bitfield.assignment.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -T vs_6_0 -E main -HV 2021
+
+// Sanity check.
+struct S1 {
+    uint f1 : 1;
+};
+
+struct S2 {
+    uint f1 : 1;
+    uint f2 : 3;
+    uint f3 : 8;
+    uint f4 : 1;
+};
+
+struct S3 {
+    uint f1 : 1;
+     int f2 : 1;
+    uint f3 : 1;
+};
+
+void main() : A {
+  S1 s1;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s1 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_1 %uint_0 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s1.f1 = 1;
+
+  S2 s2;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s2 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_1 %uint_0 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s2.f1 = 1;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s2 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_5 %uint_1 %uint_3
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s2.f2 = 5;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s2 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_2 %uint_4 %uint_8
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s2.f3 = 2;
+
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s2 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_1 %uint_12 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s2.f4 = 1;
+
+  S3 s3;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s3 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_1 %uint_0 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s3.f1 = 1;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int %s3 %int_1
+// CHECK:     [[load:%\d+]] = OpLoad %int [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %int [[load]] %int_0 %uint_0 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s3.f2 = 0;
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s3 %int_2
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %uint [[load]] %uint_1 %uint_0 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s3.f3 = 1;
+
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %s2 %int_0
+// CHECK:     [[load:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK:     [[s2f4_extract:%\d+]] = OpBitFieldUExtract %uint [[load]] %uint_12 %uint_1
+// CHECK:     [[s2f4_sext:%\d+]] = OpBitcast %int [[s2f4_extract]]
+// CHECK:     [[ptr:%\d+]] = OpAccessChain %_ptr_Function_int %s3 %int_1
+// CHECK:     [[load:%\d+]] = OpLoad %int [[ptr]]
+// CHECK:     [[insert:%\d+]] = OpBitFieldInsert %int [[load]] [[s2f4_sext]] %uint_0 %uint_1
+// CHECK:     OpStore [[ptr]] [[insert]]
+  s3.f2 = s2.f4;
+}

+ 179 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.struct.bitfield.hlsl

@@ -0,0 +1,179 @@
+// RUN: %dxc -T vs_6_0 -E main -HV 2021
+
+// CHECK:     OpMemberDecorate %S1 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S1 1 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S2 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S2 1 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S3 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S3 1 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S4 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S4 1 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S5 0 Offset 0
+// CHECK:     OpMemberDecorate %S5 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S5 2 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S6 0 Offset 0
+// CHECK:     OpMemberDecorate %S6 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S6 2 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S7 0 Offset 0
+// CHECK:     OpMemberDecorate %S7 1 Offset 4
+// CHECK:     OpMemberDecorate %S7 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S7 3 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S8 0 Offset 0
+// CHECK:     OpMemberDecorate %S8 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S8 2 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S9 0 Offset 0
+// CHECK:     OpMemberDecorate %S9 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S9 2 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S10 0 Offset 0
+// CHECK:     OpMemberDecorate %S10 1 Offset 4
+// CHECK:     OpMemberDecorate %S10 2 Offset 8
+// CHECK:     OpMemberDecorate %S10 3 Offset 12
+// CHECK-NOT: OpMemberDecorate %S10 4 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S11 0 Offset 0
+// CHECK:     OpMemberDecorate %S11 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S11 2 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S12 0 Offset 0
+// CHECK:     OpMemberDecorate %S12 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S12 2 Offset {{.+}}
+
+// CHECK:     OpMemberDecorate %S13 0 Offset 0
+// CHECK:     OpMemberDecorate %S13 1 Offset 16
+// CHECK-NOT: OpMemberDecorate %S13 2 Offset {{.+}}
+
+// CHECK: OpMemberDecorate %type_buff 0 Offset 0
+// CHECK: OpMemberDecorate %type_buff 1 Offset 16
+// CHECK: OpMemberDecorate %type_buff 2 Offset 32
+// CHECK: OpMemberDecorate %type_buff 3 Offset 48
+// CHECK: OpMemberDecorate %type_buff 4 Offset 64
+// CHECK: OpMemberDecorate %type_buff 5 Offset 80
+// CHECK: OpMemberDecorate %type_buff 6 Offset 96
+// CHECK: OpMemberDecorate %type_buff 7 Offset 112
+// CHECK: OpMemberDecorate %type_buff 8 Offset 128
+// CHECK: OpMemberDecorate %type_buff 9 Offset 144
+// CHECK: OpMemberDecorate %type_buff 10 Offset 160
+// CHECK: OpMemberDecorate %type_buff 11 Offset 176
+// CHECK: OpMemberDecorate %type_buff 12 Offset 192
+
+// CHECK:  %S1 = OpTypeStruct %uint
+// CHECK:  %S2 = OpTypeStruct %uint
+// CHECK:  %S3 = OpTypeStruct %uint
+// CHECK:  %S4 = OpTypeStruct %uint
+// CHECK:  %S5 = OpTypeStruct %uint %uint
+// CHECK:  %S6 = OpTypeStruct %uint %int
+// CHECK:  %S7 = OpTypeStruct %uint %uint %uint
+// CHECK:  %S8 = OpTypeStruct %uint %uint
+// CHECK:  %S9 = OpTypeStruct %uint %uint
+// CHECK: %S10 = OpTypeStruct %uint %int %uint %int
+// CHECK: %S11 = OpTypeStruct %uint %uint
+// CHECK: %S12 = OpTypeStruct %uint %uint
+// CHECK: %S13 = OpTypeStruct %uint %v4float
+// CHECK: %S14 = OpTypeStruct %uint
+
+// Sanity check.
+struct S1 {
+    uint f1 : 1;
+};
+
+// Bitfield merging.
+struct S2 {
+    uint f2 : 1;
+    uint f1 : 1;
+};
+
+// Bitfield merging: limit.
+struct S3 {
+    uint f2 : 1;
+    uint f1 : 31;
+};
+struct S4 {
+    uint f2 : 31;
+    uint f1 : 1;
+};
+
+// Bitfield merging: overflow.
+struct S5 {
+    uint f1 : 30;
+    uint f2 : 3;
+};
+
+// Bitfield merging: type.
+struct S6 {
+    uint f2 : 1;
+     int f1 : 1;
+};
+
+// Bitfield merging: mix.
+struct S7 {
+    uint f1;
+    uint f2 : 1;
+    uint f3 : 1;
+    uint f4;
+};
+struct S8 {
+    uint f1 : 1;
+    uint f2 : 1;
+    uint f3;
+};
+struct S9 {
+    uint f1;
+    uint f2 : 1;
+    uint f3 : 1;
+};
+struct S10 {
+    uint f2 : 1;
+     int f1 : 1;
+    uint f3 : 1;
+     int f4 : 1;
+};
+
+// alignment.
+struct S11 {
+    uint f1;
+    uint f2 : 1;
+};
+struct S12 {
+    uint f2 : 1;
+    uint f1;
+};
+struct S13 {
+    uint   f1 : 1;
+    float4 f2;
+};
+
+struct S14 {
+    uint f1 : 1;
+    uint f2 : 3;
+    uint f3 : 8;
+};
+
+cbuffer buff : register(b0) {
+  S1 CB_s1;
+  S2 CB_s2;
+  S3 CB_s3;
+  S4 CB_s4;
+  S5 CB_s5;
+  S6 CB_s6;
+  S7 CB_s7;
+  S8 CB_s8;
+  S9 CB_s9;
+  S10 CB_s10;
+  S11 CB_s11;
+  S12 CB_s12;
+  S13 CB_s13;
+  S14 CB_s14;
+}
+
+uint main() : A {
+  return 0u;
+}

+ 18 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -432,10 +432,16 @@ TEST_F(FileTest, OpMatrixAccess1x1) {
 
 // For struct & array accessing operator
 TEST_F(FileTest, OpStructAccess) { runFileTest("op.struct.access.hlsl"); }
+TEST_F(FileTest, OpStructAccessBitfield) {
+  runFileTest("op.struct.access.bitfield.hlsl");
+}
 TEST_F(FileTest, OpArrayAccess) { runFileTest("op.array.access.hlsl"); }
 
 // For buffer accessing operator
 TEST_F(FileTest, OpBufferAccess) { runFileTest("op.buffer.access.hlsl"); }
+TEST_F(FileTest, OpBufferAccessBitfield) {
+  runFileTest("op.buffer.access.bitfield.hlsl");
+}
 TEST_F(FileTest, OpRWBufferAccess) { runFileTest("op.rwbuffer.access.hlsl"); }
 TEST_F(FileTest, OpCBufferAccess) { runFileTest("op.cbuffer.access.hlsl"); }
 TEST_F(FileTest, OpCBufferAccessMajorness) {
@@ -452,6 +458,9 @@ TEST_F(FileTest, OpTextureBufferAccess) {
 TEST_F(FileTest, OpStructuredBufferAccess) {
   runFileTest("op.structured-buffer.access.hlsl");
 }
+TEST_F(FileTest, OpStructuredBufferAccessBitfield) {
+  runFileTest("op.structured-buffer.access.bitfield.hlsl");
+}
 TEST_F(FileTest, OpRWStructuredBufferAccess) {
   runFileTest("op.rw-structured-buffer.access.hlsl");
 }
@@ -1410,6 +1419,9 @@ TEST_F(FileTest, IntrinsicsVkReadClock) {
 TEST_F(FileTest, IntrinsicsVkRawBufferLoad) {
   runFileTest("intrinsics.vkrawbufferload.hlsl");
 }
+TEST_F(FileTest, IntrinsicsVkRawBufferLoadBitfield) {
+  runFileTest("intrinsics.vkrawbufferload.bitfield.hlsl");
+}
 TEST_F(FileTest, IntrinsicsVkRawBufferStore) {
   runFileTest("intrinsics.vkrawbufferstore.hlsl");
 }
@@ -2306,6 +2318,12 @@ TEST_F(FileTest, VulkanLayoutStructRelaxedLayout) {
   // Checks VK_KHR_relaxed_block_layout on struct types
   runFileTest("vk.layout.struct.relaxed.hlsl");
 }
+TEST_F(FileTest, VulkanLayoutStructBitfield) {
+  runFileTest("vk.layout.struct.bitfield.hlsl");
+}
+TEST_F(FileTest, VulkanLayoutStructBitfieldAssignment) {
+  runFileTest("vk.layout.struct.bitfield.assignment.hlsl");
+}
 
 TEST_F(FileTest, VulkanLayoutVkOffsetAttr) {
   // Checks the behavior of [[vk::offset]]

+ 38 - 36
tools/clang/unittests/SPIRV/SpirvContextTest.cpp

@@ -418,13 +418,13 @@ TEST_F(SpirvContextTest, StructTypeUnique1) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   EXPECT_EQ(type1, type2);
@@ -437,13 +437,13 @@ TEST_F(SpirvContextTest, StructTypeUnique2) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct2", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   EXPECT_NE(type1, type2);
@@ -456,13 +456,13 @@ TEST_F(SpirvContextTest, StructTypeUnique3) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ true, StructInterfaceType::InternalStorage);
 
   EXPECT_NE(type1, type2);
@@ -475,13 +475,13 @@ TEST_F(SpirvContextTest, StructTypeUnique4) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
 
   EXPECT_NE(type1, type2);
@@ -494,13 +494,13 @@ TEST_F(SpirvContextTest, StructTypeUnique5) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field"),
-       StructType::FieldInfo(uint32, "field")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(uint32, "field"),
-       StructType::FieldInfo(int32, "field")},
+      {StructType::FieldInfo(uint32, /* fieldIndex */ 0, "field"),
+       StructType::FieldInfo(int32, /* fieldIndex */ 1, "field")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   EXPECT_NE(type1, type2);
@@ -513,13 +513,13 @@ TEST_F(SpirvContextTest, StructTypeUnique6) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "sine"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "sine"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "cosine"),
-       StructType::FieldInfo(uint32, "field2")},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "cosine"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2")},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   EXPECT_NE(type1, type2);
@@ -532,13 +532,15 @@ TEST_F(SpirvContextTest, StructTypeUnique7) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2", /*offset*/ 8)},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2",
+                             /*offset*/ 8)},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2", /*offset*/ 4)},
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2",
+                             /*offset*/ 4)},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   EXPECT_NE(type1, type2);
@@ -551,14 +553,14 @@ TEST_F(SpirvContextTest, StructTypeUnique8) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2", /*offset*/ 4,
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2", /*offset*/ 4,
                              /*matrixStride*/ 16)},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2", /*offset*/ 4,
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2", /*offset*/ 4,
                              /*matrixStride*/ 32)},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
@@ -572,14 +574,14 @@ TEST_F(SpirvContextTest, StructTypeUnique9) {
   const auto *uint32 = spvContext.getUIntType(32);
 
   const auto *type1 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2", /*offset*/ 4,
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2", /*offset*/ 4,
                              /*matrixStride*/ 16, /*isRowMajor*/ false)},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 
   const auto *type2 = spvContext.getStructType(
-      {StructType::FieldInfo(int32, "field1"),
-       StructType::FieldInfo(uint32, "field2", /*offset*/ 4,
+      {StructType::FieldInfo(int32, /* fieldIndex */ 0, "field1"),
+       StructType::FieldInfo(uint32, /* fieldIndex */ 1, "field2", /*offset*/ 4,
                              /*matrixStride*/ 16, /*isRowMajor*/ true)},
       "struct1", /*isReadOnly*/ false, StructInterfaceType::InternalStorage);
 

+ 3 - 2
tools/clang/unittests/SPIRV/SpirvTypeTest.cpp

@@ -126,8 +126,9 @@ TEST_F(SpirvTypeTest, StructType) {
   IntegerType int32(32, true);
   IntegerType uint32(32, false);
 
-  StructType::FieldInfo field0(&int32, "field1");
-  StructType::FieldInfo field1(&uint32, "field2", /*offset*/ 4,
+  StructType::FieldInfo field0(&int32, /* fieldIndex */ 0, "field1");
+  StructType::FieldInfo field1(&uint32, /* fieldIndex */ 1, "field2",
+                               /*offset*/ 4,
                                /*matrixStride*/ 16, /*isRowMajor*/ false);
 
   StructType s({field0, field1}, "some_struct", /*isReadOnly*/ true,