浏览代码

[spirv] Separate type handling from EmitVisitor.

We used to use a map inside EmitVisitor to ensure type uniqueness. That
brought a lot of ugliness into the EmitVisitor and frankly, that's not
what EmitVisitor should be doing.

It also caused several issues:
1- Two struct types with different layout rules are not necessarily
different types: these different rules may result in the same offsets.

2- Other types (e.g. an access chain result type being a pointer into a
struct with layout rules) have to resolve into a unique type before
getting to EmitVisitor.

This change drastically improves test passing rate.
Ehsan Nasiri 6 年之前
父节点
当前提交
2d8080fc8f

+ 5 - 71
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -42,44 +42,6 @@ public:
     llvm::Optional<uint32_t> memberIndex;
   };
 
-  using DecorationList = llvm::SmallVector<DecorationInfo, 4>;
-
-  // Provides DenseMapInfo for SpirvLayoutRule so that we can use it as key to
-  // DenseMap.
-  struct DecorationSetDenseMapInfo {
-    static inline DecorationList getEmptyKey() {
-      return {DecorationInfo(spv::Decoration::Max)};
-    }
-    static inline DecorationList getTombstoneKey() {
-      return {DecorationInfo(spv::Decoration::Max)};
-    }
-    static unsigned getHashValue(const DecorationList &Val) {
-      unsigned hashValue = Val.size();
-      for (auto &decorationInfo : Val)
-        hashValue += static_cast<unsigned>(decorationInfo.decoration);
-
-      return hashValue;
-    }
-
-    static bool isEqual(const DecorationList &LHS, const DecorationList &RHS) {
-      // Must have the same number of decorations.
-      if (LHS.size() != RHS.size())
-        return false;
-
-      // Order of decorations does not matter.
-      for (auto &dec : LHS) {
-        auto found = std::find_if(
-            RHS.begin(), RHS.end(),
-            [&dec](const DecorationInfo &otherDec) { return dec == otherDec; });
-
-        if (found == RHS.end())
-          return false;
-      }
-
-      return true;
-    }
-  };
-
 public:
   EmitTypeHandler(ASTContext &astCtx, SpirvContext &spvContext,
                   std::vector<uint32_t> *debugVec,
@@ -106,8 +68,7 @@ public:
   //
   // If any decorations apply to the type, it also emits the decoration
   // instructions into the annotationsBinary.
-  uint32_t emitType(const SpirvType *, SpirvLayoutRule,
-                    llvm::Optional<bool> isRowMajor = llvm::None);
+  uint32_t emitType(const SpirvType *);
 
   // Emits an OpConstant instruction with uint32 type and returns its result-id.
   // If such constant has already been emitted, just returns its result-id.
@@ -119,17 +80,10 @@ private:
   void initTypeInstruction(spv::Op op);
   void finalizeTypeInstruction();
 
-  // Figures out the decorations that apply to the given type with the given
-  // layout rule, and populates the given decoration set.
-  void getDecorationsForType(const SpirvType *type, SpirvLayoutRule rule,
-                             llvm::Optional<bool> isRowMajor,
-                             DecorationList *decorations);
-
   // Returns the result-id for the given type and decorations. If a type with
   // the same decorations have already been used, it returns the existing
   // result-id. If not, creates a new result-id for such type and returns it.
-  uint32_t getResultIdForType(const SpirvType *, const DecorationList &,
-                              bool *alreadyExists);
+  uint32_t getResultIdForType(const SpirvType *, bool *alreadyExists);
 
   // Emits OpDecorate (or OpMemberDecorate if memberIndex is non-zero)
   // targeting the given type. Uses the given decoration kind and its
@@ -154,24 +108,6 @@ private:
     return obj->getResultId();
   }
 
-  // ---- Methods associated with layout calculations ----
-
-  // If the given type is a matrix type inside a struct, its majorness should
-  // also be passed to this method in order to determine the correct alignment.
-  std::pair<uint32_t, uint32_t>
-  getAlignmentAndSize(const SpirvType *type, SpirvLayoutRule rule,
-                      uint32_t *stride,
-                      llvm::Optional<bool> isRowMajorStructMember = llvm::None);
-
-  void alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
-                                   uint32_t fieldSize, uint32_t fieldAlignment,
-                                   uint32_t *currentOffset);
-
-  // Adds the layout decorations for the given type and layout rule to the given
-  // vector of decorations.
-  void getLayoutDecorations(const StructType *, SpirvLayoutRule,
-                            DecorationList *);
-
 private:
   /// Emits error to the diagnostic engine associated with this visitor.
   template <unsigned N>
@@ -197,11 +133,9 @@ private:
   // uint value to the result-id of the OpConstant for that value.
   llvm::DenseMap<uint32_t, uint32_t> UintConstantValueToResultIdMap;
 
-  // emittedTypes is a map that caches the result-id of types with a given list
-  // of decorations in order to avoid emitting an identical type multiple times.
-  using DecorationSetToTypeIdMap =
-      llvm::DenseMap<DecorationList, uint32_t, DecorationSetDenseMapInfo>;
-  llvm::DenseMap<const SpirvType *, DecorationSetToTypeIdMap> emittedTypes;
+  // emittedTypes is a map that caches the result-id of types in order to avoid
+  // emitting an identical type multiple times.
+  llvm::DenseMap<const SpirvType *, uint32_t> emittedTypes;
 };
 
 /// \brief The visitor class that emits the SPIR-V words from the in-memory

+ 46 - 5
tools/clang/include/clang/SPIRV/LowerTypeVisitor.h

@@ -40,7 +40,8 @@ public:
 private:
   /// Emits error to the diagnostic engine associated with this visitor.
   template <unsigned N>
-  DiagnosticBuilder emitError(const char (&message)[N], SourceLocation srcLoc) {
+  DiagnosticBuilder emitError(const char (&message)[N],
+                              SourceLocation srcLoc = {}) {
     const auto diagId = astContext.getDiagnostics().getCustomDiagID(
         clang::DiagnosticsEngine::Error, message);
     return astContext.getDiagnostics().Report(srcLoc, diagId);
@@ -75,23 +76,63 @@ private:
 
   /// Returns true if type is an HLSL row-major matrix or array of matrices.
   /// Returns false if type is an HLSL col-major matrix or array of matrices.
-  /// Returns llvm::None if the type is not a matrix or array of matrices.
   /// It does so by checking the majorness of the HLSL matrix either with
   /// explicit attribute or implicit command-line option.
-  llvm::Optional<bool> isHLSLRowMajorMatrix(QualType type) const;
+  bool isHLSLRowMajorMatrix(QualType type) const;
 
   /// Returns true if type is a SPIR-V row-major matrix or array of matrices.
   /// Returns false if type is a SPIR-V col-major matrix or array of matrices.
-  /// Returns llvm::None if the type is not a matrix or array of matrices.
+  /// It does so by checking the majorness of the HLSL matrix either with
+  /// explicit attribute or implicit command-line option.
   /// 
   /// Note that HLSL matrices are conceptually row major, while SPIR-V matrices
   /// are conceptually column major. We are mapping what HLSL semantically mean
   /// a row into a column here.
-  llvm::Optional<bool> isRowMajorMatrix(QualType type) const;
+  bool isRowMajorMatrix(QualType type) const;
+
+private:
+  /// Calculates all layout information needed for the given structure fields.
+  /// Returns the lowered field info vector.
+  /// In other words: lowers the HybridStructType field information to
+  /// StructType field information.
+  llvm::SmallVector<StructType::FieldInfo, 4>
+  populateLayoutInformation(llvm::ArrayRef<HybridStructType::FieldInfo> fields,
+                            SpirvLayoutRule rule);
+
+  /// \brief Aligns currentOffset properly to allow packing vectors in the HLSL
+  /// way: using the element type's alignment as the vector alignment, as long
+  /// as there is no improper straddle.
+  /// fieldSize and fieldAlignment are the original size and alignment
+  /// calculated without considering the HLSL vector relaxed rule.
+  void alignUsingHLSLRelaxedLayout(QualType fieldType, uint32_t fieldSize,
+                                   uint32_t fieldAlignment,
+                                   uint32_t *currentOffset);
+
+  /// \brief Returns the alignment and size in bytes for the given type
+  /// according to the given LayoutRule.
+  ///
+  /// If the type is an array/matrix type, writes the array/matrix stride to
+  /// stride.
+  ///
+  /// Note that the size returned is not exactly how many bytes the type
+  /// will occupy in memory; rather it is used in conjunction with alignment
+  /// to get the next available location (alignment + size), which means
+  /// size contains post-paddings required by the given type.
+  std::pair<uint32_t, uint32_t>
+  getAlignmentAndSize(QualType type, SpirvLayoutRule rule, uint32_t *stride);
 
 private:
   ASTContext &astContext;   /// AST context
   SpirvContext &spvContext; /// SPIR-V context
+
+  /// A place to keep the matrix majorness attributes so that we can retrieve
+  /// the information when really processing the desugared matrix type.
+  ///
+  /// This is needed because the majorness attribute is decorated on a
+  /// TypedefType (i.e., floatMxN) of the real matrix type (i.e., matrix<elem,
+  /// row, col>). When we reach the desugared matrix type, this information
+  /// is already gone.
+  llvm::Optional<AttributedType::Kind> typeMatMajorAttr;
 };
 
 } // end namespace spirv

+ 5 - 5
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -205,8 +205,10 @@ public:
   const HybridSampledImageType *getSampledImageType(QualType image);
 
   const ArrayType *getArrayType(const SpirvType *elemType, uint32_t elemCount,
-                                llvm::Optional<bool> rowMajorElem);
-  const RuntimeArrayType *getRuntimeArrayType(const SpirvType *elemType);
+                                llvm::Optional<uint32_t> arrayStride);
+  const RuntimeArrayType *
+  getRuntimeArrayType(const SpirvType *elemType,
+                      llvm::Optional<uint32_t> arrayStride);
 
   const StructType *getStructType(
       llvm::ArrayRef<StructType::FieldInfo> fields, llvm::StringRef name,
@@ -254,7 +256,6 @@ private:
 
   using VectorTypeArray = std::array<const VectorType *, 5>;
   using MatrixTypeVector = std::vector<const MatrixType *>;
-  using CountToArrayMap = llvm::DenseMap<uint32_t, const ArrayType *>;
   using SCToPtrTyMap =
       llvm::DenseMap<spv::StorageClass, const SpirvPointerType *,
                      StorageClassDenseMapInfo>;
@@ -274,9 +275,8 @@ private:
   llvm::DenseMap<QualType, const HybridSampledImageType *, QualTypeDenseMapInfo>
       hybridSampledImageTypes;
 
-  //llvm::DenseMap<const SpirvType *, CountToArrayMap> arrayTypes;
   llvm::SmallVector<const ArrayType *, 8> arrayTypes;
-  llvm::DenseMap<const SpirvType *, const RuntimeArrayType *> runtimeArrayTypes;
+  llvm::SmallVector<const RuntimeArrayType *, 8> runtimeArrayTypes;
 
   llvm::SmallVector<const StructType *, 8> structTypes;
   llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;

+ 33 - 21
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -22,6 +22,8 @@
 namespace clang {
 namespace spirv {
 
+class HybridType;
+
 enum class StructInterfaceType : uint32_t {
   InternalStorage = 0,
   StorageBuffer = 1,
@@ -235,13 +237,13 @@ private:
 class ArrayType : public SpirvType {
 public:
   ArrayType(const SpirvType *elemType, uint32_t elemCount,
-            llvm::Optional<bool> hasRowMajorElem)
+            llvm::Optional<uint32_t> arrayStride)
       : SpirvType(TK_Array), elementType(elemType), elementCount(elemCount),
-        rowMajorElem(hasRowMajorElem) {}
+        stride(arrayStride) {}
 
   const SpirvType *getElementType() const { return elementType; }
   uint32_t getElementCount() const { return elementCount; }
-  llvm::Optional<bool> hasRowMajorElement() const { return rowMajorElem; }
+  llvm::Optional<uint32_t> getStride() const { return stride; }
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Array; }
 
@@ -251,39 +253,49 @@ private:
   const SpirvType *elementType;
   uint32_t elementCount;
   // Two arrays types with different ArrayStride decorations, are in fact two
-  // different array types. In general, the combination of element type and
-  // element count is enough to determine the array stride. However, in the case
-  // of arrays of matrices, we also need to know the majorness of the matrices.
-  // An array of 5 row_major 2x3 matrices is a different type from
-  // an array of 5 col_major 2x3 matrices.
-  llvm::Optional<bool> rowMajorElem;
+  // different array types. If no layout information is needed, use llvm::None.
+  llvm::Optional<uint32_t> stride;
 };
 
 class RuntimeArrayType : public SpirvType {
 public:
-  RuntimeArrayType(const SpirvType *elemType)
-      : SpirvType(TK_RuntimeArray), elementType(elemType) {}
+  RuntimeArrayType(const SpirvType *elemType,
+                   llvm::Optional<uint32_t> arrayStride)
+      : SpirvType(TK_RuntimeArray), elementType(elemType), stride(arrayStride) {
+  }
 
   static bool classof(const SpirvType *t) {
     return t->getKind() == TK_RuntimeArray;
   }
 
+  bool operator==(const RuntimeArrayType &that) const;
+
   const SpirvType *getElementType() const { return elementType; }
+  llvm::Optional<uint32_t> getStride() const { return stride; }
 
 private:
   const SpirvType *elementType;
+  // Two runtime arrays with different ArrayStride decorations, are in fact two
+  // different types. If no layout information is needed, use llvm::None.
+  llvm::Optional<uint32_t> stride;
 };
 
+// The StructType is the lowered type that best represents what a structure type
+// is in SPIR-V. Contains all necessary information for properly emitting a
+// SPIR-V structure type.
 class StructType : public SpirvType {
 public:
   struct FieldInfo {
   public:
     FieldInfo(const SpirvType *type_, llvm::StringRef name_ = "",
-              clang::VKOffsetAttr *offset = nullptr,
-              hlsl::ConstantPacking *packOffset = nullptr,
-              llvm::Optional<bool> rowMajor = llvm::None)
-        : type(type_), name(name_), vkOffsetAttr(offset),
-          packOffsetAttr(packOffset), isRowMajor(rowMajor) {}
+              llvm::Optional<uint32_t> offset_ = llvm::None,
+              llvm::Optional<uint32_t> matrixStride_ = llvm::None,
+              llvm::Optional<bool> isRowMajor_ = llvm::None)
+        : type(type_), name(name_), offset(offset_),
+          matrixStride(matrixStride_), isRowMajor(isRowMajor_) {
+      // A StructType may not contain any hybrid types.
+      assert(!isa<HybridType>(type_));
+    }
 
     bool operator==(const FieldInfo &that) const;
 
@@ -291,11 +303,11 @@ public:
     const SpirvType *type;
     // The field's name.
     std::string name;
-    // vk::offset attributes associated with this field.
-    clang::VKOffsetAttr *vkOffsetAttr;
-    // :packoffset() annotations associated with this field.
-    hlsl::ConstantPacking *packOffsetAttr;
-    // The majorness of this field (if it is a matrix).
+    // The integer offset for this field.
+    llvm::Optional<uint32_t> offset;
+    // The matrix stride for this field (if applicable).
+    llvm::Optional<uint32_t> matrixStride;
+    // The majorness of this field (if applicable).
     llvm::Optional<bool> isRowMajor;
   };
 

+ 3 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -720,12 +720,11 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
 
   // Make an array if requested.
   if (arraySize > 0) {
-    // The array element is a structure, therefore no majorness information is
-    // needed.
     resultType = spvContext.getArrayType(resultType, arraySize,
-                                         /*rowMajorElement*/ llvm::None);
+                                         /*ArrayStride*/ llvm::None);
   } else if (arraySize == -1) {
-    resultType = spvContext.getRuntimeArrayType(resultType);
+    resultType =
+        spvContext.getRuntimeArrayType(resultType, /*ArrayStride*/ llvm::None);
   }
 
   // Register the <type-id> for this decl

+ 61 - 500
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -48,37 +48,6 @@ void chopString(llvm::StringRef original,
 constexpr uint32_t kGeneratorNumber = 14;
 constexpr uint32_t kToolVersion = 0;
 
-/// The alignment for 4-component float vectors.
-constexpr uint32_t kStd140Vec4Alignment = 16u;
-
-/// Rounds the given value up to the given power of 2.
-inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) {
-  assert(pow2 != 0);
-  return (val + pow2 - 1) & ~(pow2 - 1);
-}
-
-/// Returns true if the given vector type (of the given size) crosses the
-/// 4-component vector boundary if placed at the given offset.
-bool improperStraddle(const clang::spirv::VectorType *type, int size,
-                      int offset) {
-  return size <= 16 ? offset / 16 != (offset + size - 1) / 16
-                    : offset % 16 != 0;
-}
-
-bool isAKindOfStructuredOrByteBuffer(const clang::spirv::SpirvType *type) {
-  // Strip outer arrayness first
-  while (llvm::isa<clang::spirv::ArrayType>(type))
-    type = llvm::cast<clang::spirv::ArrayType>(type)->getElementType();
-
-  // They are structures with the first member that is of RuntimeArray type.
-  if (auto *structType = llvm::dyn_cast<clang::spirv::StructType>(type))
-    return structType->getFields().size() == 1 &&
-           llvm::isa<clang::spirv::RuntimeArrayType>(
-               structType->getFields()[0].type);
-
-  return false;
-}
-
 uint32_t zeroExtendTo32Bits(uint16_t value) {
   // TODO: The ordering of the 2 words depends on the endian-ness of the host
   // machine. Assuming Little Endian at the moment.
@@ -147,8 +116,7 @@ void EmitVisitor::emitDebugNameForInstruction(uint32_t resultId,
 void EmitVisitor::initInstruction(SpirvInstruction *inst) {
   // Emit the result type if the instruction has a result type.
   if (inst->hasResultType()) {
-    const uint32_t resultTypeId =
-        typeHandler.emitType(inst->getResultType(), inst->getLayoutRule());
+    const uint32_t resultTypeId = typeHandler.emitType(inst->getResultType());
     inst->setResultTypeId(resultTypeId);
   }
 
@@ -242,10 +210,8 @@ bool EmitVisitor::visit(SpirvFunction *fn, Phase phase) {
 
   // Before emitting the function
   if (phase == Visitor::Phase::Init) {
-    const uint32_t returnTypeId =
-        typeHandler.emitType(fn->getReturnType(), SpirvLayoutRule::Void);
-    const uint32_t functionTypeId =
-        typeHandler.emitType(fn->getFunctionType(), SpirvLayoutRule::Void);
+    const uint32_t returnTypeId = typeHandler.emitType(fn->getReturnType());
+    const uint32_t functionTypeId = typeHandler.emitType(fn->getFunctionType());
     fn->setReturnTypeId(returnTypeId);
     fn->setFunctionTypeId(functionTypeId);
 
@@ -1098,21 +1064,17 @@ void EmitTypeHandler::finalizeTypeInstruction() {
 }
 
 uint32_t EmitTypeHandler::getResultIdForType(const SpirvType *type,
-                                             const DecorationList &decs,
                                              bool *alreadyExists) {
   assert(alreadyExists);
   auto foundType = emittedTypes.find(type);
   if (foundType != emittedTypes.end()) {
-    auto foundDecorationSet = foundType->second.find(decs);
-    if (foundDecorationSet != foundType->second.end()) {
-      *alreadyExists = true;
-      return foundDecorationSet->second;
-    }
+    *alreadyExists = true;
+    return foundType->second;
   }
 
   *alreadyExists = false;
   const uint32_t id = takeNextIdFunction();
-  emittedTypes[type][decs] = id;
+  emittedTypes[type] = id;
   return id;
 }
 
@@ -1124,7 +1086,7 @@ uint32_t EmitTypeHandler::getOrCreateConstantUint32(uint32_t value) {
 
   const uint32_t constantResultId = takeNextIdFunction();
   const SpirvType *uintType = context.getUIntType(32);
-  const uint32_t uint32TypeId = emitType(uintType, SpirvLayoutRule::Void);
+  const uint32_t uint32TypeId = emitType(uintType);
   initTypeInstruction(spv::Op::OpConstant);
   curTypeInst.push_back(uint32TypeId);
   curTypeInst.push_back(constantResultId);
@@ -1135,66 +1097,10 @@ uint32_t EmitTypeHandler::getOrCreateConstantUint32(uint32_t value) {
   return constantResultId;
 }
 
-void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
-                                            SpirvLayoutRule rule,
-                                            llvm::Optional<bool> isRowMajor,
-                                            DecorationList *decs) {
-  // Array types
-  if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
-    // ArrayStride decoration is needed for array types, but we won't have
-    // stride information for structured/byte buffers since they contain runtime
-    // arrays.
-    if (rule != SpirvLayoutRule::Void &&
-        !isAKindOfStructuredOrByteBuffer(type)) {
-      uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, &stride,
-                                arrayType->hasRowMajorElement());
-      decs->push_back(DecorationInfo(spv::Decoration::ArrayStride, {stride}));
-    }
-  }
-  // RuntimeArray types
-  else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
-    // ArrayStride decoration is needed for runtime arrays containing structures
-    // (StructuredBuffers).
-    if (rule != SpirvLayoutRule::Void &&
-        !isa<ImageType>(raType->getElementType())) {
-      uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, &stride, isRowMajor);
-      decs->push_back(DecorationInfo(spv::Decoration::ArrayStride, {stride}));
-    }
-  }
-  // Structure types
-  else if (const auto *structType = dyn_cast<StructType>(type)) {
-    llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
-    size_t numFields = fields.size();
-
-    // Emit the layout decorations for the structure.
-    getLayoutDecorations(structType, rule, decs);
-
-    // Emit NonWritable decorations
-    if (structType->isReadOnly())
-      for (size_t i = 0; i < numFields; ++i)
-        decs->push_back(DecorationInfo(spv::Decoration::NonWritable, {}, i));
-
-    // Emit Block or BufferBlock decorations if necessary.
-    auto interfaceType = structType->getInterfaceType();
-    if (interfaceType == StructInterfaceType::StorageBuffer)
-      decs->push_back(DecorationInfo(spv::Decoration::BufferBlock, {}));
-    else if (interfaceType == StructInterfaceType::UniformBuffer)
-      decs->push_back(DecorationInfo(spv::Decoration::Block, {}));
-  }
-
-  // We currently only have decorations for arrays, runtime arrays, and
-  // structure types.
-}
-
-uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
-                                   llvm::Optional<bool> isRowMajor) {
+uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
   // First get the decorations that would apply to this type.
   bool alreadyExists = false;
-  DecorationList decs;
-  getDecorationsForType(type, rule, isRowMajor, &decs);
-  const uint32_t id = getResultIdForType(type, decs, &alreadyExists);
+  const uint32_t id = getResultIdForType(type, &alreadyExists);
 
   // If the type has already been emitted, we just need to return its
   // <result-id>.
@@ -1232,8 +1138,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
   }
   // Vector types
   else if (const auto *vecType = dyn_cast<VectorType>(type)) {
-    const uint32_t elementTypeId =
-        emitType(vecType->getElementType(), rule, isRowMajor);
+    const uint32_t elementTypeId = emitType(vecType->getElementType());
     initTypeInstruction(spv::Op::OpTypeVector);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elementTypeId);
@@ -1242,8 +1147,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
   }
   // Matrix types
   else if (const auto *matType = dyn_cast<MatrixType>(type)) {
-    const uint32_t vecTypeId =
-        emitType(matType->getVecType(), rule, isRowMajor);
+    const uint32_t vecTypeId = emitType(matType->getVecType());
     initTypeInstruction(spv::Op::OpTypeMatrix);
     curTypeInst.push_back(id);
     curTypeInst.push_back(vecTypeId);
@@ -1254,8 +1158,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
   }
   // Image types
   else if (const auto *imageType = dyn_cast<ImageType>(type)) {
-    const uint32_t sampledTypeId =
-        emitType(imageType->getSampledType(), rule, isRowMajor);
+    const uint32_t sampledTypeId = emitType(imageType->getSampledType());
     initTypeInstruction(spv::Op::OpTypeImage);
     curTypeInst.push_back(id);
     curTypeInst.push_back(sampledTypeId);
@@ -1275,8 +1178,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
   }
   // SampledImage types
   else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
-    const uint32_t imageTypeId =
-        emitType(sampledImageType->getImageType(), rule, isRowMajor);
+    const uint32_t imageTypeId = emitType(sampledImageType->getImageType());
     initTypeInstruction(spv::Op::OpTypeSampledImage);
     curTypeInst.push_back(id);
     curTypeInst.push_back(imageTypeId);
@@ -1289,22 +1191,28 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
     const auto length = getOrCreateConstantUint32(arrayType->getElementCount());
 
     // Emit the OpTypeArray instruction
-    const uint32_t elemTypeId =
-        emitType(arrayType->getElementType(), rule, isRowMajor);
+    const uint32_t elemTypeId = emitType(arrayType->getElementType());
     initTypeInstruction(spv::Op::OpTypeArray);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
     curTypeInst.push_back(length);
     finalizeTypeInstruction();
+
+    auto stride = arrayType->getStride();
+    if (stride.hasValue())
+      emitDecoration(id, spv::Decoration::ArrayStride, {stride.getValue()});
   }
   // RuntimeArray types
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
-    const uint32_t elemTypeId =
-        emitType(raType->getElementType(), rule, isRowMajor);
+    const uint32_t elemTypeId = emitType(raType->getElementType());
     initTypeInstruction(spv::Op::OpTypeRuntimeArray);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
     finalizeTypeInstruction();
+
+    auto stride = raType->getStride();
+    if (stride.hasValue())
+      emitDecoration(id, spv::Decoration::ArrayStride, {stride.getValue()});
   }
   // Structure types
   else if (const auto *structType = dyn_cast<StructType>(type)) {
@@ -1316,8 +1224,41 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
       emitNameForType(fields[i].name, id, i);
 
     llvm::SmallVector<uint32_t, 4> fieldTypeIds;
-    for (auto &field : fields)
-      fieldTypeIds.push_back(emitType(field.type, rule, field.isRowMajor));
+    for (auto &field : fields) {
+      fieldTypeIds.push_back(emitType(field.type));
+    }
+
+    for (size_t i = 0; i < numFields; ++i) {
+      auto &field = fields[i];
+      // Offset decorations
+      if (field.offset.hasValue())
+        emitDecoration(id, spv::Decoration::Offset, {field.offset.getValue()},
+                       i);
+
+      // MatrixStride decorations
+      if (field.matrixStride.hasValue())
+        emitDecoration(id, spv::Decoration::MatrixStride,
+                       {field.matrixStride.getValue()}, i);
+
+      // RowMajor/ColMajor decorations
+      if (field.isRowMajor.hasValue())
+        emitDecoration(id,
+                       field.isRowMajor.getValue() ? spv::Decoration::RowMajor
+                                                   : spv::Decoration::ColMajor,
+                       {}, i);
+
+      // NonWritable decorations
+      if (structType->isReadOnly())
+        emitDecoration(id, spv::Decoration::NonWritable, {}, i);
+    }
+
+    // Emit Block or BufferBlock decorations if necessary.
+    auto interfaceType = structType->getInterfaceType();
+    if (interfaceType == StructInterfaceType::StorageBuffer)
+      emitDecoration(id, spv::Decoration::BufferBlock, {});
+    else if (interfaceType == StructInterfaceType::UniformBuffer)
+      emitDecoration(id, spv::Decoration::Block, {});
+
     initTypeInstruction(spv::Op::OpTypeStruct);
     curTypeInst.push_back(id);
     for (auto fieldTypeId : fieldTypeIds)
@@ -1326,8 +1267,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
   }
   // Pointer types
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
-    const uint32_t pointeeType =
-        emitType(ptrType->getPointeeType(), rule, isRowMajor);
+    const uint32_t pointeeType = emitType(ptrType->getPointeeType());
     initTypeInstruction(spv::Op::OpTypePointer);
     curTypeInst.push_back(id);
     curTypeInst.push_back(static_cast<uint32_t>(ptrType->getStorageClass()));
@@ -1336,11 +1276,10 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
   }
   // Function types
   else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
-    const uint32_t retTypeId =
-        emitType(fnType->getReturnType(), rule, isRowMajor);
+    const uint32_t retTypeId = emitType(fnType->getReturnType());
     llvm::SmallVector<uint32_t, 4> paramTypeIds;
     for (auto *paramType : fnType->getParamTypes())
-      paramTypeIds.push_back(emitType(paramType, rule, isRowMajor));
+      paramTypeIds.push_back(emitType(paramType));
 
     initTypeInstruction(spv::Op::OpTypeFunction);
     curTypeInst.push_back(id);
@@ -1361,387 +1300,9 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
     llvm_unreachable("unhandled type in emitType");
   }
 
-  // Finally, emit decorations for the type into the annotationsBinary.
-  for (auto &decorInfo : decs)
-    emitDecoration(id, decorInfo.decoration, decorInfo.decorationParams,
-                   decorInfo.memberIndex);
-
   return id;
 }
 
-std::pair<uint32_t, uint32_t>
-EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
-                                     SpirvLayoutRule rule, uint32_t *stride,
-                                     llvm::Optional<bool> isRowMajor) {
-  // std140 layout rules:
-
-  // 1. If the member is a scalar consuming N basic machine units, the base
-  //    alignment is N.
-  //
-  // 2. If the member is a two- or four-component vector with components
-  //    consuming N basic machine units, the base alignment is 2N or 4N,
-  //    respectively.
-  //
-  // 3. If the member is a three-component vector with components consuming N
-  //    basic machine units, the base alignment is 4N.
-  //
-  // 4. If the member is an array of scalars or vectors, the base alignment and
-  //    array stride are set to match the base alignment of a single array
-  //    element, according to rules (1), (2), and (3), and rounded up to the
-  //    base alignment of a vec4. The array may have padding at the end; the
-  //    base offset of the member following the array is rounded up to the next
-  //    multiple of the base alignment.
-  //
-  // 5. If the member is a column-major matrix with C columns and R rows, the
-  //    matrix is stored identically to an array of C column vectors with R
-  //    components each, according to rule (4).
-  //
-  // 6. If the member is an array of S column-major matrices with C columns and
-  //    R rows, the matrix is stored identically to a row of S X C column
-  //    vectors with R components each, according to rule (4).
-  //
-  // 7. If the member is a row-major matrix with C columns and R rows, the
-  //    matrix is stored identically to an array of R row vectors with C
-  //    components each, according to rule (4).
-  //
-  // 8. If the member is an array of S row-major matrices with C columns and R
-  //    rows, the matrix is stored identically to a row of S X R row vectors
-  //    with C components each, according to rule (4).
-  //
-  // 9. If the member is a structure, the base alignment of the structure is N,
-  //    where N is the largest base alignment value of any of its members, and
-  //    rounded up to the base alignment of a vec4. The individual members of
-  //    this substructure are then assigned offsets by applying this set of
-  //    rules recursively, where the base offset of the first member of the
-  //    sub-structure is equal to the aligned offset of the structure. The
-  //    structure may have padding at the end; the base offset of the member
-  //    following the sub-structure is rounded up to the next multiple of the
-  //    base alignment of the structure.
-  //
-  // 10. If the member is an array of S structures, the S elements of the array
-  //     are laid out in order, according to rule (9).
-  //
-  // This method supports multiple layout rules, all of them modifying the
-  // std140 rules listed above:
-  //
-  // std430:
-  // - Array base alignment and stride does not need to be rounded up to a
-  //   multiple of 16.
-  // - Struct base alignment does not need to be rounded up to a multiple of 16.
-  //
-  // Relaxed std140/std430:
-  // - Vector base alignment is set as its element type's base alignment.
-  //
-  // FxcCTBuffer:
-  // - Vector base alignment is set as its element type's base alignment.
-  // - Arrays/structs do not need to have padding at the end; arrays/structs do
-  //   not affect the base offset of the member following them.
-  //
-  // FxcSBuffer:
-  // - Vector/matrix/array base alignment is set as its element type's base
-  //   alignment.
-  // - Arrays/structs do not need to have padding at the end; arrays/structs do
-  //   not affect the base offset of the member following them.
-  // - Struct base alignment does not need to be rounded up to a multiple of 16.
-
-  { // Rule 1
-    if (isa<BoolType>(type))
-      return {4, 4};
-    // Integer and Float types are NumericalType
-    if (auto *numericType = dyn_cast<NumericalType>(type)) {
-      switch (numericType->getBitwidth()) {
-      case 64:
-        return {8, 8};
-      case 32:
-        return {4, 4};
-      case 16:
-        return {2, 2};
-      default:
-        emitError("alignment and size calculation unimplemented for type");
-        return {0, 0};
-      }
-    }
-  }
-
-  { // Rule 2 and 3
-    if (auto *vecType = dyn_cast<VectorType>(type)) {
-      uint32_t alignment = 0, size = 0;
-      uint32_t elemCount = vecType->getElementCount();
-      const SpirvType *elemType = vecType->getElementType();
-      std::tie(alignment, size) =
-          getAlignmentAndSize(elemType, rule, stride, isRowMajor);
-      // Use element alignment for fxc rules
-      if (rule != SpirvLayoutRule::FxcCTBuffer &&
-          rule != SpirvLayoutRule::FxcSBuffer)
-        alignment = (elemCount == 3 ? 4 : elemCount) * size;
-
-      return {alignment, elemCount * size};
-    }
-  }
-
-  { // Rule 5 and 7
-    if (auto *matType = dyn_cast<MatrixType>(type)) {
-      const SpirvType *elemType = matType->getElementType();
-      uint32_t rowCount = matType->numRows();
-      uint32_t colCount = matType->numCols();
-      uint32_t alignment = 0, size = 0;
-      std::tie(alignment, size) =
-          getAlignmentAndSize(elemType, rule, stride, isRowMajor);
-
-      // Matrices are treated as arrays of vectors:
-      // The base alignment and array stride are set to match the base alignment
-      // of a single array element, according to rules 1, 2, and 3, and rounded
-      // up to the base alignment of a vec4.
-      assert(isRowMajor.hasValue());
-      bool rowMajor = isRowMajor.getValue();
-      const uint32_t vecStorageSize = rowMajor ? colCount : rowCount;
-
-      if (rule == SpirvLayoutRule::FxcSBuffer) {
-        *stride = vecStorageSize * size;
-        // Use element alignment for fxc structured buffers
-        return {alignment, rowCount * colCount * size};
-      }
-
-      alignment *= (vecStorageSize == 3 ? 4 : vecStorageSize);
-      if (rule == SpirvLayoutRule::GLSLStd140 ||
-          rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
-          rule == SpirvLayoutRule::FxcCTBuffer) {
-        alignment = roundToPow2(alignment, kStd140Vec4Alignment);
-      }
-      *stride = alignment;
-      size = (rowMajor ? rowCount : colCount) * alignment;
-
-      return {alignment, size};
-    }
-  }
-
-  // Rule 9
-  if (auto *structType = dyn_cast<StructType>(type)) {
-    // Special case for handling empty structs, whose size is 0 and has no
-    // requirement over alignment (thus 1).
-    if (structType->getFields().size() == 0)
-      return {1, 0};
-
-    uint32_t maxAlignment = 0;
-    uint32_t structSize = 0;
-
-    for (auto &field : structType->getFields()) {
-      uint32_t memberAlignment = 0, memberSize = 0;
-      std::tie(memberAlignment, memberSize) =
-          getAlignmentAndSize(field.type, rule, stride, field.isRowMajor);
-
-      if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
-          rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
-          rule == SpirvLayoutRule::FxcCTBuffer) {
-        alignUsingHLSLRelaxedLayout(field.type, memberSize, memberAlignment,
-                                    &structSize);
-      } else {
-        structSize = roundToPow2(structSize, memberAlignment);
-      }
-
-      // Reset the current offset to the one specified in the source code
-      // if exists. It's debatable whether we should do sanity check here.
-      // If the developers want manually control the layout, we leave
-      // everything to them.
-      if (field.vkOffsetAttr) {
-        structSize = field.vkOffsetAttr->getOffset();
-      }
-
-      // The base alignment of the structure is N, where N is the largest
-      // base alignment value of any of its members...
-      maxAlignment = std::max(maxAlignment, memberAlignment);
-      structSize += memberSize;
-    }
-
-    if (rule == SpirvLayoutRule::GLSLStd140 ||
-        rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
-        rule == SpirvLayoutRule::FxcCTBuffer) {
-      // ... and rounded up to the base alignment of a vec4.
-      maxAlignment = roundToPow2(maxAlignment, kStd140Vec4Alignment);
-    }
-
-    if (rule != SpirvLayoutRule::FxcCTBuffer &&
-        rule != SpirvLayoutRule::FxcSBuffer) {
-      // The base offset of the member following the sub-structure is rounded up
-      // to the next multiple of the base alignment of the structure.
-      structSize = roundToPow2(structSize, maxAlignment);
-    }
-    return {maxAlignment, structSize};
-  }
-
-  // Rule 4, 6, 8, and 10
-  auto *arrayType = dyn_cast<ArrayType>(type);
-  auto *raType = dyn_cast<RuntimeArrayType>(type);
-  if (arrayType)
-    isRowMajor = arrayType->hasRowMajorElement();
-  if (arrayType || raType) {
-    // Some exaplanation about runtime arrays:
-    // The number of elements in a runtime array is unknown at compile time. As
-    // a result, it would in fact be illegal to have a runtime array in a
-    // structure *unless* it is the *only* member in the structure. In such a
-    // case, we don't care about size and stride, and only care about alignment.
-    // Therefore, to re-use the logic of array types, we'll consider a runtime
-    // array as an array of size 1.
-    const auto elemCount = arrayType ? arrayType->getElementCount() : 1;
-    const auto *elemType =
-        arrayType ? arrayType->getElementType() : raType->getElementType();
-
-    uint32_t alignment = 0, size = 0;
-    std::tie(alignment, size) =
-        getAlignmentAndSize(elemType, rule, stride, isRowMajor);
-
-    if (rule == SpirvLayoutRule::FxcSBuffer) {
-      *stride = size;
-      // Use element alignment for fxc structured buffers
-      return {alignment, size * elemCount};
-    }
-
-    if (rule == SpirvLayoutRule::GLSLStd140 ||
-        rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
-        rule == SpirvLayoutRule::FxcCTBuffer) {
-      // The base alignment and array stride are set to match the base alignment
-      // of a single array element, according to rules 1, 2, and 3, and rounded
-      // up to the base alignment of a vec4.
-      alignment = roundToPow2(alignment, kStd140Vec4Alignment);
-    }
-    if (rule == SpirvLayoutRule::FxcCTBuffer) {
-      // In fxc cbuffer/tbuffer packing rules, arrays does not affect the data
-      // packing after it. But we still need to make sure paddings are inserted
-      // internally if necessary.
-      *stride = roundToPow2(size, alignment);
-      size += *stride * (elemCount - 1);
-    } else {
-      // Need to round size up considering stride for scalar types
-      size = roundToPow2(size, alignment);
-      *stride = size; // Use size instead of alignment here for Rule 10
-      size *= elemCount;
-      // The base offset of the member following the array is rounded up to the
-      // next multiple of the base alignment.
-      size = roundToPow2(size, alignment);
-    }
-
-    return {alignment, size};
-  }
-
-  emitError("alignment and size calculation unimplemented for type");
-  return {0, 0};
-}
-
-void EmitTypeHandler::alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
-                                                  uint32_t fieldSize,
-                                                  uint32_t fieldAlignment,
-                                                  uint32_t *currentOffset) {
-  if (auto *vecType = dyn_cast<VectorType>(fieldType)) {
-    const SpirvType *elemType = vecType->getElementType();
-    // Adjust according to HLSL relaxed layout rules.
-    // Aligning vectors as their element types so that we can pack a float
-    // and a float3 tightly together.
-    uint32_t scalarAlignment = 0;
-    std::tie(scalarAlignment, std::ignore) =
-        getAlignmentAndSize(elemType, SpirvLayoutRule::Void, nullptr);
-    if (scalarAlignment <= 4)
-      fieldAlignment = scalarAlignment;
-
-    *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
-
-    // Adjust according to HLSL relaxed layout rules.
-    // Bump to 4-component vector alignment if there is a bad straddle
-    if (improperStraddle(vecType, fieldSize, *currentOffset)) {
-      fieldAlignment = kStd140Vec4Alignment;
-      *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
-    }
-  }
-  // Cases where the field is not a vector
-  else {
-    *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
-  }
-}
-
-void EmitTypeHandler::getLayoutDecorations(const StructType *structType,
-                                           SpirvLayoutRule rule,
-                                           DecorationList *decs) {
-  assert(decs);
-
-  uint32_t offset = 0, index = 0;
-  for (auto &field : structType->getFields()) {
-    const SpirvType *fieldType = field.type;
-    uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
-    std::tie(memberAlignment, memberSize) =
-        getAlignmentAndSize(fieldType, rule, &stride, field.isRowMajor);
-
-    // The next avaiable location after laying out the previous members
-    const uint32_t nextLoc = offset;
-
-    if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
-        rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
-        rule == SpirvLayoutRule::FxcCTBuffer) {
-      alignUsingHLSLRelaxedLayout(fieldType, memberSize, memberAlignment,
-                                  &offset);
-    } else {
-      offset = roundToPow2(offset, memberAlignment);
-    }
-
-    // The vk::offset attribute takes precedence over all.
-    if (field.vkOffsetAttr) {
-      offset = field.vkOffsetAttr->getOffset();
-    }
-    // The :packoffset() annotation takes precedence over normal layout
-    // calculation.
-    else if (field.packOffsetAttr) {
-      const uint32_t packOffset = field.packOffsetAttr->Subcomponent * 16 +
-                                  field.packOffsetAttr->ComponentOffset * 4;
-      // Do minimal check to make sure the offset specified by packoffset does
-      // not cause overlap.
-      if (packOffset < nextLoc) {
-        emitError("packoffset caused overlap with previous members",
-                  field.packOffsetAttr->Loc);
-      } else {
-        offset = packOffset;
-      }
-    }
-
-    // Each structure-type member must have an Offset Decoration.
-    decs->push_back(DecorationInfo(spv::Decoration::Offset, {offset}, index));
-    offset += memberSize;
-
-    // Each structure-type member that is a matrix or array-of-matrices must be
-    // decorated with
-    // * A MatrixStride decoration, and
-    // * one of the RowMajor or ColMajor Decorations.
-    if (auto *arrayType = dyn_cast<ArrayType>(fieldType)) {
-      // We have an array of matrices as a field, we need to decorate
-      // MatrixStride on the field. So skip possible arrays here.
-      fieldType = arrayType->getElementType();
-    }
-
-    // Non-floating point matrices are represented as arrays of vectors, and
-    // therefore ColMajor and RowMajor decorations should not be applied to
-    // them.
-    if (auto *matType = dyn_cast<MatrixType>(fieldType)) {
-      if (isa<FloatType>(matType->getElementType())) {
-        memberAlignment = memberSize = stride = 0;
-        std::tie(memberAlignment, memberSize) =
-            getAlignmentAndSize(fieldType, rule, &stride, field.isRowMajor);
-
-        decs->push_back(
-            DecorationInfo(spv::Decoration::MatrixStride, {stride}, index));
-
-        if (field.isRowMajor.hasValue()) {
-          if (field.isRowMajor.getValue()) {
-            decs->push_back(
-                DecorationInfo(spv::Decoration::RowMajor, {}, index));
-          } else {
-            decs->push_back(
-                DecorationInfo(spv::Decoration::ColMajor, {}, index));
-          }
-        }
-      }
-    }
-
-    ++index;
-  }
-}
-
 void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
                                      spv::Decoration decoration,
                                      llvm::ArrayRef<uint32_t> decorationParams,

+ 5 - 5
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -309,14 +309,14 @@ void GlPerVertex::calculateClipCullDistanceArraySize() {
 SpirvVariable *GlPerVertex::createClipCullDistanceVar(bool asInput, bool isClip,
                                                       uint32_t arraySize) {
   const ArrayType *type = spvContext.getArrayType(
-      spvContext.getFloatType(32), arraySize, /*rowMajorElem*/ llvm::None);
+      spvContext.getFloatType(32), arraySize, /*ArrayStride*/ llvm::None);
 
   if (asInput && inArraySize != 0) {
     type =
-        spvContext.getArrayType(type, inArraySize, /*rowMajorElem*/ llvm::None);
+        spvContext.getArrayType(type, inArraySize, /*ArrayStride*/ llvm::None);
   } else if (!asInput && outArraySize != 0) {
     type = spvContext.getArrayType(type, outArraySize,
-                                   /*rowMajorElem*/ llvm::None);
+                                   /*ArrayStride*/ llvm::None);
   }
 
   spv::StorageClass sc =
@@ -432,7 +432,7 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(bool isClip,
 
   if (isScalarType(asType)) {
     arrayType = spvContext.getArrayType(f32Type, inArraySize,
-                                        /*rowMajorElem*/ llvm::None);
+                                        /*ArrayStride*/ llvm::None);
     for (uint32_t i = 0; i < inArraySize; ++i) {
       auto *ptr = spvBuilder.createAccessChain(
           ptrType, clipCullVar,
@@ -443,7 +443,7 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(bool isClip,
   } else if (isVectorType(asType, &elemType, &count)) {
     arrayType =
         spvContext.getArrayType(spvContext.getVectorType(f32Type, count),
-                                inArraySize, /*rowMajorElem*/ llvm::None);
+                                inArraySize, /*ArrayStride*/ llvm::None);
 
     for (uint32_t i = 0; i < inArraySize; ++i) {
       // For each gl_PerVertex block, we need to read a vector from it.

+ 506 - 79
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -24,6 +24,37 @@ hlsl::ConstantPacking *getPackOffset(const clang::NamedDecl *decl) {
   return nullptr;
 }
 
+/// The alignment for 4-component float vectors.
+constexpr uint32_t kStd140Vec4Alignment = 16u;
+
+/// Rounds val up to the nearest multiple of pow2 and returns the result; a val
+/// that is already a multiple of pow2 is returned unchanged.
+///
+/// The mask arithmetic below only yields a multiple of pow2 when pow2 is a
+/// power of two. Only pow2 != 0 is asserted, so callers must guarantee the
+/// power-of-two property themselves.
+inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) {
+  assert(pow2 != 0);
+  // Adding pow2 - 1 carries val into the next multiple unless it is already on
+  // one; masking off the low bits then drops the remainder.
+  return (val + pow2 - 1) & ~(pow2 - 1);
+}
+
+/// Returns true if a vector of the given size, placed at the given offset,
+/// improperly straddles a 16-unit slot boundary (16 being the same value as
+/// kStd140Vec4Alignment, the alignment of a 4-component float vector).
+///
+/// \param type   the vector type being placed; it is only used to assert that
+///               it really is a vector type.
+/// \param size   size of the vector (presumably in bytes -- TODO confirm
+///               against callers).
+/// \param offset candidate offset of the vector, in the same units as size.
+bool improperStraddle(clang::QualType type, int size, int offset) {
+  assert(clang::spirv::isVectorType(type));
+  // A vector that fits in one slot (size <= 16) straddles when its first and
+  // last unit land in different slots. A larger vector cannot avoid crossing
+  // a boundary, so it is only accepted when it starts exactly on one.
+  return size <= 16 ? offset / 16 != (offset + size - 1) / 16
+                    : offset % 16 != 0;
+}
+
+/// Returns true if the given SPIR-V type, after stripping any enclosing
+/// ArrayType layers, is a struct with exactly one field and that field is a
+/// RuntimeArrayType. This function treats that shape as the signature of a
+/// structured or byte buffer; every other type yields false.
+bool isAKindOfStructuredOrByteBuffer(const clang::spirv::SpirvType *type) {
+  // Strip outer arrayness first, so arrays (of arrays) of buffers are
+  // classified by their innermost element type.
+  while (llvm::isa<clang::spirv::ArrayType>(type))
+    type = llvm::cast<clang::spirv::ArrayType>(type)->getElementType();
+
+  // They are structures whose only member is of RuntimeArray type.
+  if (auto *structType = llvm::dyn_cast<clang::spirv::StructType>(type))
+    return structType->getFields().size() == 1 &&
+           llvm::isa<clang::spirv::RuntimeArrayType>(
+               structType->getFields()[0].type);
+
+  return false;
+}
+
 } // end anonymous namespace
 
 namespace clang {
@@ -147,17 +178,11 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
     return spvContext.getFunctionType(spirvReturnType, paramTypes);
   } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(type)) {
     // lower all fields of the struct.
-    std::vector<StructType::FieldInfo> structFields;
-    for (auto field : hybridStruct->getFields()) {
-      const SpirvType *fieldSpirvType = lowerType(field.astType, rule, loc);
-      llvm::Optional<bool> isRowMajor = isRowMajorMatrix(field.astType);
-      structFields.push_back(
-          StructType::FieldInfo(fieldSpirvType, field.name, field.vkOffsetAttr,
-                                field.packOffsetAttr, isRowMajor));
-    }
-    return spvContext.getStructType(structFields, hybridStruct->getStructName(),
-                                    hybridStruct->isReadOnly(),
-                                    hybridStruct->getInterfaceType());
+    auto loweredFields =
+        populateLayoutInformation(hybridStruct->getFields(), rule);
+    return spvContext.getStructType(
+        loweredFields, hybridStruct->getStructName(),
+        hybridStruct->isReadOnly(), hybridStruct->getInterfaceType());
   }
   // Void, bool, int, float cannot be further lowered.
   // Matrices cannot contain hybrid types. Only matrices of scalars are valid.
@@ -186,8 +211,9 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
     // If array didn't contain any hybrid types, return itself.
     if (arrType->getElementType() == loweredElemType)
       return arrType;
+
     return spvContext.getArrayType(loweredElemType, arrType->getElementCount(),
-                                   arrType->hasRowMajorElement());
+                                   arrType->getStride());
   }
   // Runtime arrays could contain a hybrid type
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
@@ -195,32 +221,13 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
         lowerType(raType->getElementType(), rule, loc);
     // If runtime array didn't contain any hybrid types, return itself.
     if (raType->getElementType() == loweredElemType)
-      return arrType;
-    return spvContext.getRuntimeArrayType(loweredElemType);
+      return raType;
+    return spvContext.getRuntimeArrayType(loweredElemType, raType->getStride());
   }
   // Struct types could contain a hybrid type
   else if (const auto *structType = dyn_cast<StructType>(type)) {
-    const auto &fields = structType->getFields();
-    llvm::SmallVector<StructType::FieldInfo, 4> loweredFields;
-    bool wasLowered = false;
-    for (auto &field : fields) {
-      const auto *loweredFieldType = lowerType(field.type, rule, loc);
-      if (loweredFieldType != field.type) {
-        wasLowered = true;
-        loweredFields.push_back(StructType::FieldInfo(
-            loweredFieldType, field.name, field.vkOffsetAttr,
-            field.packOffsetAttr, field.isRowMajor));
-      } else {
-        loweredFields.push_back(field);
-      }
-    }
-    // If the struct didn't contain any hybrid types, return itself.
-    if (!wasLowered)
-      return structType;
-
-    return spvContext.getStructType(loweredFields, structType->getStructName(),
-                                    structType->isReadOnly(),
-                                    structType->getInterfaceType());
+    // Struct types can not contain hybrid types.
+    return structType;
   }
   // Pointer types could point to a hybrid type.
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
@@ -264,6 +271,11 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
 
   if (desugaredType != type) {
     const auto *spvType = lowerType(desugaredType, rule, srcLoc);
+    // Clear matrix majorness potentially set by previous desugarType() calls.
+    // This field will only be set when we were seeing a matrix type. And the
+    // above lowerType() call already takes the majorness into consideration.
+    // So should be fine to clear now.
+    typeMatMajorAttr = llvm::None;
     return spvType;
   }
 
@@ -364,10 +376,16 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
 
       // Non-float matrices are represented as an array of vectors.
       if (!elemType->isFloatingType()) {
-        // This return type is ArrayType
-        // This is an array of vectors. No majorness information needed.
-        return spvContext.getArrayType(vecType, rowCount,
-                                       /*rowMajorElem*/ llvm::None);
+        llvm::Optional<uint32_t> arrayStride = llvm::None;
+        // If there is a layout rule, we need array stride information.
+        if (rule != SpirvLayoutRule::Void) {
+          uint32_t stride = 0;
+          (void)getAlignmentAndSize(type, rule, &stride);
+          arrayStride = stride;
+        }
+
+        // This return type is ArrayType.
+        return spvContext.getArrayType(vecType, rowCount, arrayStride);
       }
 
       return spvContext.getMatrixType(vecType, rowCount);
@@ -386,41 +404,52 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
       return spvType;
 
     // Collect all fields' information.
-    llvm::SmallVector<StructType::FieldInfo, 8> fields;
+    llvm::SmallVector<HybridStructType::FieldInfo, 8> fields;
 
     // If this struct is derived from some other struct, place an implicit
     // field at the very beginning for the base struct.
-    if (const auto *cxxDecl = dyn_cast<CXXRecordDecl>(decl))
+    if (const auto *cxxDecl = dyn_cast<CXXRecordDecl>(decl)) {
       for (const auto base : cxxDecl->bases()) {
-        fields.push_back(
-            StructType::FieldInfo(lowerType(base.getType(), rule, srcLoc)));
+        fields.push_back(HybridStructType::FieldInfo(base.getType()));
       }
+    }
 
     // Create fields for all members of this struct
     for (const auto *field : decl->fields()) {
-      const SpirvType *fieldType = lowerType(field->getType(), rule, srcLoc);
-      llvm::Optional<bool> isRowMajor = isRowMajorMatrix(field->getType());
-      fields.push_back(StructType::FieldInfo(
-          fieldType, field->getName(),
+      fields.push_back(HybridStructType::FieldInfo(
+          field->getType(), field->getName(),
           /*vkoffset*/ field->getAttr<VKOffsetAttr>(),
-          /*packoffset*/ getPackOffset(field), isRowMajor));
+          /*packoffset*/ getPackOffset(field)));
     }
 
-    return spvContext.getStructType(fields, decl->getName());
+    auto loweredFields = populateLayoutInformation(fields, rule);
+
+    return spvContext.getStructType(loweredFields, decl->getName());
   }
 
   // Array type
   if (const auto *arrayType = astContext.getAsArrayType(type)) {
-    const auto *elemType = lowerType(arrayType->getElementType(), rule, srcLoc);
+    const auto elemType = arrayType->getElementType();
+    const auto *loweredElemType =
+        lowerType(arrayType->getElementType(), rule, srcLoc);
+    llvm::Optional<uint32_t> arrayStride = llvm::None;
+
+    if (rule != SpirvLayoutRule::Void &&
+        // We won't have stride information for structured/byte buffers since
+        // they contain runtime arrays.
+        !isAKindOfStructuredOrByteBuffer(elemType)) {
+      uint32_t stride = 0;
+      (void)getAlignmentAndSize(type, rule, &stride);
+      arrayStride = stride;
+    }
 
     if (const auto *caType = astContext.getAsConstantArrayType(type)) {
       const auto size = static_cast<uint32_t>(caType->getSize().getZExtValue());
-      llvm::Optional<bool> isRowMajor = isRowMajorMatrix(type);
-      return spvContext.getArrayType(elemType, size, isRowMajor);
+      return spvContext.getArrayType(loweredElemType, size, arrayStride);
     }
 
     assert(type->isIncompleteArrayType());
-    return spvContext.getRuntimeArrayType(elemType);
+    return spvContext.getRuntimeArrayType(loweredElemType, arrayStride);
   }
 
   // Reference types
@@ -528,13 +557,23 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     else
       structName = getAstTypeName(innerType);
 
-    const auto *raType = spvContext.getRuntimeArrayType(structType);
+    uint32_t size = 0, stride = 0;
+    std::tie(std::ignore, size) = getAlignmentAndSize(s, rule, &stride);
+
+    // We have a runtime array of structures. So:
+    // The stride of the runtime array is the size of the struct.
+    const auto *raType = spvContext.getRuntimeArrayType(structType, size);
     const bool isReadOnly = (name == "StructuredBuffer");
 
+    // Attach majorness decoration if this is a *StructuredBuffer<matrix>.
+    llvm::Optional<bool> isRowMajor =
+        isMxNMatrix(s) ? llvm::Optional<bool>(isRowMajorMatrix(s)) : llvm::None;
+
     const std::string typeName = "type." + name.str() + "." + structName;
     const auto *valType = spvContext.getStructType(
-        {StructType::FieldInfo(raType)}, typeName, isReadOnly,
-        StructInterfaceType::StorageBuffer);
+        {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0,
+                               /*matrixStride*/ llvm::None, isRowMajor)},
+        typeName, isReadOnly, StructInterfaceType::StorageBuffer);
 
     if (asAlias) {
       // All structured buffers are in the Uniform storage class.
@@ -591,17 +630,15 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
   if (name == "InputPatch") {
     const auto elemType = hlsl::GetHLSLInputPatchElementType(type);
     const auto elemCount = hlsl::GetHLSLInputPatchCount(type);
-    llvm::Optional<bool> isRowMajor = isRowMajorMatrix(type);
     return spvContext.getArrayType(lowerType(elemType, rule, srcLoc), elemCount,
-                                   isRowMajor);
+                                   /*ArrayStride*/ llvm::None);
   }
   // OutputPatch
   if (name == "OutputPatch") {
     const auto elemType = hlsl::GetHLSLOutputPatchElementType(type);
     const auto elemCount = hlsl::GetHLSLOutputPatchCount(type);
-    llvm::Optional<bool> isRowMajor = isRowMajorMatrix(type);
-    return spvContext.getArrayType(lowerType(elemType, rule, srcLoc),
-                                   elemCount, isRowMajor);
+    return spvContext.getArrayType(lowerType(elemType, rule, srcLoc), elemCount,
+                                   /*ArrayStride*/ llvm::None);
   }
   // Output stream objects (TriangleStream, LineStream, and PointStream)
   if (name == "TriangleStream" || name == "LineStream" ||
@@ -660,6 +697,15 @@ LowerTypeVisitor::translateSampledTypeToImageFormat(QualType sampledType,
 
 QualType LowerTypeVisitor::desugarType(QualType type) {
   if (const auto *attrType = type->getAs<AttributedType>()) {
+    switch (auto kind = attrType->getAttrKind()) {
+    case AttributedType::attr_hlsl_row_major:
+    case AttributedType::attr_hlsl_column_major:
+      typeMatMajorAttr = kind;
+      break;
+    default:
+      // Only matrices should apply to typeMatMajorAttr.
+      break;
+    }
     return desugarType(
         attrType->getLocallyUnqualifiedSingleStepDesugaredType());
   }
@@ -671,38 +717,419 @@ QualType LowerTypeVisitor::desugarType(QualType type) {
   return type;
 }
 
-llvm::Optional<bool>
-LowerTypeVisitor::isHLSLRowMajorMatrix(QualType type) const {
-  if (const auto *attrType = type->getAs<AttributedType>()) {
-    switch (auto kind = attrType->getAttrKind()) {
+bool LowerTypeVisitor::isHLSLRowMajorMatrix(QualType type) const {
+  // The type passed in may not be desugared. Check attributes on itself first.
+  bool attrRowMajor = false;
+  if (hlsl::HasHLSLMatOrientation(type, &attrRowMajor))
+    return attrRowMajor;
+
+  // Use the majorness info we recorded before.
+  if (typeMatMajorAttr.hasValue()) {
+    switch (typeMatMajorAttr.getValue()) {
     case AttributedType::attr_hlsl_row_major:
       return true;
     case AttributedType::attr_hlsl_column_major:
       return false;
-      break;
     default:
+      // Only oriented matrices are relevant.
       break;
     }
   }
-  if (const auto *typedefType = type->getAs<TypedefType>()) {
-    return isHLSLRowMajorMatrix(typedefType->desugar());
+  return spvOptions.defaultRowMajor;
+}
+
+bool LowerTypeVisitor::isRowMajorMatrix(QualType type) const {
+  return !isHLSLRowMajorMatrix(type);
+}
+
+llvm::SmallVector<StructType::FieldInfo, 4>
+LowerTypeVisitor::populateLayoutInformation(
+    llvm::ArrayRef<HybridStructType::FieldInfo> fields, SpirvLayoutRule rule) {
+
+  // The resulting vector of fields with proper layout information.
+  llvm::SmallVector<StructType::FieldInfo, 4> loweredFields;
+
+  uint32_t offset = 0;
+  for (const auto field : fields) {
+    // The field can only be FieldDecl (for normal structs) or VarDecl (for
+    // HLSLBufferDecls).
+    auto fieldType = field.astType;
+
+    // Lower the field type first. This call will populate proper matrix
+    // majorness information.
+    StructType::FieldInfo loweredField(lowerType(fieldType, rule, {}),
+                                       field.name);
+
+    // We only need layout information for structures with non-void layout rule.
+    if (rule == SpirvLayoutRule::Void) {
+      loweredFields.push_back(loweredField);
+      continue;
+    }
+
+    uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
+    std::tie(memberAlignment, memberSize) =
+        getAlignmentAndSize(fieldType, rule, &stride);
+
+    // The next available location after laying out the previous members
+    const uint32_t nextLoc = offset;
+
+    if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+        rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
+        rule == SpirvLayoutRule::FxcCTBuffer) {
+      alignUsingHLSLRelaxedLayout(fieldType, memberSize, memberAlignment,
+                                  &offset);
+    } else {
+      offset = roundToPow2(offset, memberAlignment);
+    }
+
+    // The vk::offset attribute takes precedence over all.
+    if (field.vkOffsetAttr) {
+      offset = field.vkOffsetAttr->getOffset();
+    }
+    // The :packoffset() annotation takes precedence over normal layout
+    // calculation.
+    else if (field.packOffsetAttr) {
+      const uint32_t packOffset = field.packOffsetAttr->Subcomponent * 16 +
+                                  field.packOffsetAttr->ComponentOffset * 4;
+      // Do minimal check to make sure the offset specified by packoffset does
+      // not cause overlap.
+      if (packOffset < nextLoc) {
+        emitError("packoffset caused overlap with previous members",
+                  field.packOffsetAttr->Loc);
+      } else {
+        offset = packOffset;
+      }
+    }
+
+    // Each structure-type member must have an Offset Decoration.
+    loweredField.offset = offset;
+    offset += memberSize;
+
+    // Each structure-type member that is a matrix or array-of-matrices must be
+    // decorated with
+    // * A MatrixStride decoration, and
+    // * one of the RowMajor or ColMajor Decorations.
+    if (const auto *arrayType = astContext.getAsConstantArrayType(fieldType)) {
+      // We have an array of matrices as a field, we need to decorate
+      // MatrixStride on the field. So skip possible arrays here.
+      fieldType = arrayType->getElementType();
+    }
+
+    // Non-floating point matrices are represented as arrays of vectors, and
+    // therefore ColMajor and RowMajor decorations should not be applied to
+    // them.
+    QualType elemType = {};
+    if (isMxNMatrix(fieldType, &elemType) && elemType->isFloatingType()) {
+      memberAlignment = memberSize = stride = 0;
+      std::tie(memberAlignment, memberSize) =
+          getAlignmentAndSize(fieldType, rule, &stride);
+
+      loweredField.matrixStride = stride;
+      loweredField.isRowMajor = isRowMajorMatrix(fieldType);
+    }
+
+    loweredFields.push_back(loweredField);
   }
-  if (const auto *arrayType = astContext.getAsArrayType(type)) {
-    return isHLSLRowMajorMatrix(arrayType->getElementType());
+
+  return loweredFields;
+}
+
+void LowerTypeVisitor::alignUsingHLSLRelaxedLayout(QualType fieldType,
+                                                   uint32_t fieldSize,
+                                                   uint32_t fieldAlignment,
+                                                   uint32_t *currentOffset) {
+  QualType vecElemType = {};
+  const bool fieldIsVecType = isVectorType(fieldType, &vecElemType);
+
+  // Adjust according to HLSL relaxed layout rules.
+  // Aligning vectors as their element types so that we can pack a float
+  // and a float3 tightly together.
+  if (fieldIsVecType) {
+    uint32_t scalarAlignment = 0;
+    std::tie(scalarAlignment, std::ignore) =
+        getAlignmentAndSize(vecElemType, SpirvLayoutRule::Void, nullptr);
+    if (scalarAlignment <= 4)
+      fieldAlignment = scalarAlignment;
+  }
+
+  *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
+
+  // Adjust according to HLSL relaxed layout rules.
+  // Bump to 4-component vector alignment if there is a bad straddle
+  if (fieldIsVecType &&
+      improperStraddle(fieldType, fieldSize, *currentOffset)) {
+    fieldAlignment = kStd140Vec4Alignment;
+    *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
   }
-  return getCodeGenOptions().defaultRowMajor;
 }
 
-llvm::Optional<bool> LowerTypeVisitor::isRowMajorMatrix(QualType type) const {
-  // Row/Col majorness only applies to matrices or array of matrices.
-  if (!isMatrixOrArrayOfMatrix(astContext, type))
-    return llvm::None;
+std::pair<uint32_t, uint32_t>
+LowerTypeVisitor::getAlignmentAndSize(QualType type, SpirvLayoutRule rule,
+                                      uint32_t *stride) {
+  // std140 layout rules:
+
+  // 1. If the member is a scalar consuming N basic machine units, the base
+  //    alignment is N.
+  //
+  // 2. If the member is a two- or four-component vector with components
+  //    consuming N basic machine units, the base alignment is 2N or 4N,
+  //    respectively.
+  //
+  // 3. If the member is a three-component vector with components consuming N
+  //    basic machine units, the base alignment is 4N.
+  //
+  // 4. If the member is an array of scalars or vectors, the base alignment and
+  //    array stride are set to match the base alignment of a single array
+  //    element, according to rules (1), (2), and (3), and rounded up to the
+  //    base alignment of a vec4. The array may have padding at the end; the
+  //    base offset of the member following the array is rounded up to the next
+  //    multiple of the base alignment.
+  //
+  // 5. If the member is a column-major matrix with C columns and R rows, the
+  //    matrix is stored identically to an array of C column vectors with R
+  //    components each, according to rule (4).
+  //
+  // 6. If the member is an array of S column-major matrices with C columns and
+  //    R rows, the matrix is stored identically to a row of S X C column
+  //    vectors with R components each, according to rule (4).
+  //
+  // 7. If the member is a row-major matrix with C columns and R rows, the
+  //    matrix is stored identically to an array of R row vectors with C
+  //    components each, according to rule (4).
+  //
+  // 8. If the member is an array of S row-major matrices with C columns and R
+  //    rows, the matrix is stored identically to a row of S X R row vectors
+  //    with C components each, according to rule (4).
+  //
+  // 9. If the member is a structure, the base alignment of the structure is N,
+  //    where N is the largest base alignment value of any of its members, and
+  //    rounded up to the base alignment of a vec4. The individual members of
+  //    this substructure are then assigned offsets by applying this set of
+  //    rules recursively, where the base offset of the first member of the
+  //    sub-structure is equal to the aligned offset of the structure. The
+  //    structure may have padding at the end; the base offset of the member
+  //    following the sub-structure is rounded up to the next multiple of the
+  //    base alignment of the structure.
+  //
+  // 10. If the member is an array of S structures, the S elements of the array
+  //     are laid out in order, according to rule (9).
+  //
+  // This method supports multiple layout rules, all of them modifying the
+  // std140 rules listed above:
+  //
+  // std430:
+  // - Array base alignment and stride do not need to be rounded up to a
+  //   multiple of 16.
+  // - Struct base alignment does not need to be rounded up to a multiple of 16.
+  //
+  // Relaxed std140/std430:
+  // - Vector base alignment is set as its element type's base alignment.
+  //
+  // FxcCTBuffer:
+  // - Vector base alignment is set as its element type's base alignment.
+  // - Arrays/structs do not need to have padding at the end; arrays/structs do
+  //   not affect the base offset of the member following them.
+  //
+  // FxcSBuffer:
+  // - Vector/matrix/array base alignment is set as its element type's base
+  //   alignment.
+  // - Arrays/structs do not need to have padding at the end; arrays/structs do
+  //   not affect the base offset of the member following them.
+  // - Struct base alignment does not need to be rounded up to a multiple of 16.
+
+  const auto desugaredType = desugarType(type);
+  if (desugaredType != type) {
+    auto result = getAlignmentAndSize(desugaredType, rule, stride);
+    // Clear potentially set matrix majorness info
+    typeMatMajorAttr = llvm::None;
+    return result;
+  }
+
+  { // Rule 1
+    QualType ty = {};
+    if (isScalarType(type, &ty))
+      if (const auto *builtinType = ty->getAs<BuiltinType>())
+        switch (builtinType->getKind()) {
+        case BuiltinType::Bool:
+        case BuiltinType::Int:
+        case BuiltinType::UInt:
+        case BuiltinType::Float:
+          return {4, 4};
+        case BuiltinType::Double:
+        case BuiltinType::LongLong:
+        case BuiltinType::ULongLong:
+          return {8, 8};
+        case BuiltinType::Min12Int:
+        case BuiltinType::Min16Int:
+        case BuiltinType::Min16UInt:
+        case BuiltinType::Min16Float:
+        case BuiltinType::Min10Float: {
+          if (spvOptions.enable16BitTypes)
+            return {2, 2};
+          else
+            return {4, 4};
+        }
+        // the 'Half' enum always represents 16-bit floats.
+        // int16_t and uint16_t map to Short and UShort.
+        case BuiltinType::Short:
+        case BuiltinType::UShort:
+        case BuiltinType::Half:
+          return {2, 2};
+        // 'HalfFloat' always represents 32-bit floats.
+        case BuiltinType::HalfFloat:
+          return {4, 4};
+        default:
+          emitError("alignment and size calculation for type %0 unimplemented")
+              << type;
+          return {0, 0};
+        }
+  }
+
+  { // Rule 2 and 3
+    QualType elemType = {};
+    uint32_t elemCount = {};
+    if (isVectorType(type, &elemType, &elemCount)) {
+      uint32_t alignment = 0, size = 0;
+      std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+      // Use element alignment for fxc rules
+      if (rule != SpirvLayoutRule::FxcCTBuffer &&
+          rule != SpirvLayoutRule::FxcSBuffer)
+        alignment = (elemCount == 3 ? 4 : elemCount) * size;
+
+      return {alignment, elemCount * size};
+    }
+  }
+
+  { // Rule 5 and 7
+    QualType elemType = {};
+    uint32_t rowCount = 0, colCount = 0;
+    if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
+      uint32_t alignment = 0, size = 0;
+      std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+
+      // Matrices are treated as arrays of vectors:
+      // The base alignment and array stride are set to match the base alignment
+      // of a single array element, according to rules 1, 2, and 3, and rounded
+      // up to the base alignment of a vec4.
+      bool isRowMajor = isRowMajorMatrix(type);
+
+      const uint32_t vecStorageSize = isRowMajor ? rowCount : colCount;
+
+      if (rule == SpirvLayoutRule::FxcSBuffer) {
+        *stride = vecStorageSize * size;
+        // Use element alignment for fxc structured buffers
+        return {alignment, rowCount * colCount * size};
+      }
 
-  const auto hlslRowMajor = isHLSLRowMajorMatrix(type);
-  if (!hlslRowMajor.hasValue())
-    return hlslRowMajor;
+      alignment *= (vecStorageSize == 3 ? 4 : vecStorageSize);
+      if (rule == SpirvLayoutRule::GLSLStd140 ||
+          rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+          rule == SpirvLayoutRule::FxcCTBuffer) {
+        alignment = roundToPow2(alignment, kStd140Vec4Alignment);
+      }
+      *stride = alignment;
+      size = (isRowMajor ? colCount : rowCount) * alignment;
+
+      return {alignment, size};
+    }
+  }
+
+  // Rule 9
+  if (const auto *structType = type->getAs<RecordType>()) {
+    // Special case for handling empty structs, whose size is 0 and has no
+    // requirement over alignment (thus 1).
+    if (structType->getDecl()->field_empty())
+      return {1, 0};
+
+    uint32_t maxAlignment = 0;
+    uint32_t structSize = 0;
+
+    for (const auto *field : structType->getDecl()->fields()) {
+      uint32_t memberAlignment = 0, memberSize = 0;
+      std::tie(memberAlignment, memberSize) =
+          getAlignmentAndSize(field->getType(), rule, stride);
+
+      if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+          rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
+          rule == SpirvLayoutRule::FxcCTBuffer) {
+        alignUsingHLSLRelaxedLayout(field->getType(), memberSize,
+                                    memberAlignment, &structSize);
+      } else {
+        structSize = roundToPow2(structSize, memberAlignment);
+      }
+
+      // Reset the current offset to the one specified in the source code
+      // if exists. It's debatable whether we should do sanity check here.
+      // If the developers want to manually control the layout, we leave
+      // everything to them.
+      if (const auto *offsetAttr = field->getAttr<VKOffsetAttr>()) {
+        structSize = offsetAttr->getOffset();
+      }
+
+      // The base alignment of the structure is N, where N is the largest
+      // base alignment value of any of its members...
+      maxAlignment = std::max(maxAlignment, memberAlignment);
+      structSize += memberSize;
+    }
+
+    if (rule == SpirvLayoutRule::GLSLStd140 ||
+        rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+        rule == SpirvLayoutRule::FxcCTBuffer) {
+      // ... and rounded up to the base alignment of a vec4.
+      maxAlignment = roundToPow2(maxAlignment, kStd140Vec4Alignment);
+    }
+
+    if (rule != SpirvLayoutRule::FxcCTBuffer &&
+        rule != SpirvLayoutRule::FxcSBuffer) {
+      // The base offset of the member following the sub-structure is rounded up
+      // to the next multiple of the base alignment of the structure.
+      structSize = roundToPow2(structSize, maxAlignment);
+    }
+    return {maxAlignment, structSize};
+  }
+
+  // Rule 4, 6, 8, and 10
+  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
+    const auto elemCount = arrayType->getSize().getZExtValue();
+    uint32_t alignment = 0, size = 0;
+    std::tie(alignment, size) =
+        getAlignmentAndSize(arrayType->getElementType(), rule, stride);
+
+    if (rule == SpirvLayoutRule::FxcSBuffer) {
+      *stride = size;
+      // Use element alignment for fxc structured buffers
+      return {alignment, size * elemCount};
+    }
+
+    if (rule == SpirvLayoutRule::GLSLStd140 ||
+        rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+        rule == SpirvLayoutRule::FxcCTBuffer) {
+      // The base alignment and array stride are set to match the base alignment
+      // of a single array element, according to rules 1, 2, and 3, and rounded
+      // up to the base alignment of a vec4.
+      alignment = roundToPow2(alignment, kStd140Vec4Alignment);
+    }
+    if (rule == SpirvLayoutRule::FxcCTBuffer) {
+      // In fxc cbuffer/tbuffer packing rules, arrays do not affect the data
+      // packing after them. But we still need to make sure padding is inserted
+      // internally if necessary.
+      *stride = roundToPow2(size, alignment);
+      size += *stride * (elemCount - 1);
+    } else {
+      // Need to round size up considering stride for scalar types
+      size = roundToPow2(size, alignment);
+      *stride = size; // Use size instead of alignment here for Rule 10
+      size *= elemCount;
+      // The base offset of the member following the array is rounded up to the
+      // next multiple of the base alignment.
+      size = roundToPow2(size, alignment);
+    }
+
+    return {alignment, size};
+  }
 
-  return !hlslRowMajor.getValue();
+  emitError("alignment and size calculation for type %0 unimplemented") << type;
+  return {0, 0};
 }
 
 } // namespace spirv

+ 24 - 30
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -218,8 +218,8 @@ SpirvContext::getSampledImageType(QualType image) {
 
 const ArrayType *SpirvContext::getArrayType(const SpirvType *elemType,
                                             uint32_t elemCount,
-                                            llvm::Optional<bool> rowMajorElem) {
-  ArrayType type(elemType, elemCount, rowMajorElem);
+                                            llvm::Optional<uint32_t> arrayStride) {
+  ArrayType type(elemType, elemCount, arrayStride);
   auto found = std::find_if(
       arrayTypes.begin(), arrayTypes.end(),
       [&type](const ArrayType *cachedType) { return type == *cachedType; });
@@ -227,33 +227,25 @@ const ArrayType *SpirvContext::getArrayType(const SpirvType *elemType,
   if (found != arrayTypes.end())
     return *found;
 
-  arrayTypes.push_back(new (this) ArrayType(elemType, elemCount, rowMajorElem));
+  arrayTypes.push_back(new (this) ArrayType(elemType, elemCount, arrayStride));
   return arrayTypes.back();
-
-  /*
-  auto foundElemType = arrayTypes.find(elemType);
-
-  if (foundElemType != arrayTypes.end()) {
-    auto &elemTypeMap = foundElemType->second;
-    auto foundCount = elemTypeMap.find(elemCount);
-
-    if (foundCount != elemTypeMap.end())
-      return foundCount->second;
-  }
-
-  return arrayTypes[elemType][elemCount] =
-             new (this) ArrayType(elemType, elemCount);
-  */
 }
 
 const RuntimeArrayType *
-SpirvContext::getRuntimeArrayType(const SpirvType *elemType) {
-  auto found = runtimeArrayTypes.find(elemType);
+SpirvContext::getRuntimeArrayType(const SpirvType *elemType,
+                                  llvm::Optional<uint32_t> arrayStride) {
+  RuntimeArrayType type(elemType, arrayStride);
+  auto found = std::find_if(runtimeArrayTypes.begin(), runtimeArrayTypes.end(),
+                            [&type](const RuntimeArrayType *cachedType) {
+                              return type == *cachedType;
+                            });
 
   if (found != runtimeArrayTypes.end())
-    return found->second;
+    return *found;
 
-  return runtimeArrayTypes[elemType] = new (this) RuntimeArrayType(elemType);
+  runtimeArrayTypes.push_back(new (this)
+                                  RuntimeArrayType(elemType, arrayStride));
+  return runtimeArrayTypes.back();
 }
 
 const StructType *
@@ -375,13 +367,14 @@ SpirvContext::getFunctionType(QualType ret,
 
 const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
   // Create a uint RuntimeArray.
-  const auto *raType = getRuntimeArrayType(getUIntType(32));
+  const auto *raType =
+      getRuntimeArrayType(getUIntType(32), /* ArrayStride */ 4);
 
   // Create a struct containing the runtime array as its only member.
-  return getStructType({StructType::FieldInfo(raType)},
-                       isWritable ? "type.RWByteAddressBuffer"
-                                  : "type.ByteAddressBuffer",
-                       !isWritable, StructInterfaceType::StorageBuffer);
+  return getStructType(
+      {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0)},
+      isWritable ? "type.RWByteAddressBuffer" : "type.ByteAddressBuffer",
+      !isWritable, StructInterfaceType::StorageBuffer);
 }
 
 const StructType *SpirvContext::getACSBufferCounterType() {
@@ -389,9 +382,10 @@ const StructType *SpirvContext::getACSBufferCounterType() {
   const auto *int32Type = getSIntType(32);
 
   // Create a struct containing the integer counter as its only member.
-  const StructType *type = getStructType(
-      {StructType::FieldInfo(int32Type, "counter")}, "type.ACSBuffer.counter",
-      /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
+  const StructType *type =
+      getStructType({StructType::FieldInfo(int32Type, "counter", /*offset*/ 0)},
+                    "type.ACSBuffer.counter",
+                    /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
 
   return type;
 }

+ 16 - 5
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -189,9 +189,14 @@ bool ImageType::operator==(const ImageType &that) const {
 
 bool ArrayType::operator==(const ArrayType &that) const {
   return elementType == that.elementType && elementCount == that.elementCount &&
-         rowMajorElem.hasValue() == that.rowMajorElem.hasValue() &&
-         (!rowMajorElem.hasValue() ||
-          rowMajorElem.getValue() == that.rowMajorElem.getValue());
+         stride.hasValue() == that.stride.hasValue() &&
+         (!stride.hasValue() || stride.getValue() == that.stride.getValue());
+}
+
+bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const {
+  return elementType == that.elementType &&
+         stride.hasValue() == that.stride.hasValue() &&
+         (!stride.hasValue() || stride.getValue() == that.stride.getValue());
 }
 
 StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
@@ -202,9 +207,15 @@ StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
 
 bool StructType::FieldInfo::
 operator==(const StructType::FieldInfo &that) const {
-  return type == that.type && vkOffsetAttr == that.vkOffsetAttr &&
-         packOffsetAttr == that.packOffsetAttr &&
+  return type == that.type && offset.hasValue() == that.offset.hasValue() &&
+         matrixStride.hasValue() == that.matrixStride.hasValue() &&
          isRowMajor.hasValue() == that.isRowMajor.hasValue() &&
+         // Either both have no offset, or both have the same offset
+         (!offset.hasValue() || offset.getValue() == that.offset.getValue()) &&
+         // Either both have no matrix stride, or both have the same stride
+         (!matrixStride.hasValue() ||
+          matrixStride.getValue() == that.matrixStride.getValue()) &&
+         // Either both have no row-major flag, or both have the same flag
          (!isRowMajor.hasValue() ||
           isRowMajor.getValue() == that.isRowMajor.getValue());
 }

+ 0 - 13
tools/clang/test/CodeGenSPIRV/bezier.domain.hlsl2spv

@@ -109,19 +109,6 @@ DS_OUTPUT BezierEvalDS( HS_CONSTANT_DATA_OUTPUT input,
 //                OpDecorate %out_var_TEXCOORD Location 1
 //                OpDecorate %out_var_TANGENT Location 2
 //                OpDecorate %out_var_BITANGENT Location 3
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 0 Offset 0
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 1 Offset 16
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 2 Offset 32
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 3 Offset 96
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 4 Offset 128
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 5 Offset 192
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 6 Offset 256
-//                OpMemberDecorate %BEZIER_CONTROL_POINT 0 Offset 0
-//                OpMemberDecorate %DS_OUTPUT 0 Offset 0
-//                OpMemberDecorate %DS_OUTPUT 1 Offset 16
-//                OpMemberDecorate %DS_OUTPUT 2 Offset 32
-//                OpMemberDecorate %DS_OUTPUT 3 Offset 48
-//                OpMemberDecorate %DS_OUTPUT 4 Offset 64
 //        %uint = OpTypeInt 32 0
 //      %uint_4 = OpConstant %uint 4
 //       %float = OpTypeFloat 32

+ 0 - 11
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -126,17 +126,6 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 //                OpDecorate %out_var_TANVCORNER Location 9
 //                OpDecorate %out_var_TANWEIGHTS Location 13
 //                OpDecorate %out_var_TEXCOORD Location 14
-//                OpMemberDecorate %VS_CONTROL_POINT_OUTPUT 0 Offset 0
-//                OpMemberDecorate %VS_CONTROL_POINT_OUTPUT 1 Offset 16
-//                OpMemberDecorate %VS_CONTROL_POINT_OUTPUT 2 Offset 32
-//                OpMemberDecorate %BEZIER_CONTROL_POINT 0 Offset 0
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 0 Offset 0
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 1 Offset 16
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 2 Offset 32
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 3 Offset 96
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 4 Offset 128
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 5 Offset 192
-//                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 6 Offset 256
 //        %uint = OpTypeInt 32 0
 //      %uint_0 = OpConstant %uint 0
 //       %float = OpTypeFloat 32

+ 0 - 2
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -41,8 +41,6 @@ PSInput main(float4 position: POSITION, float4 color: COLOR) {
 //                OpDecorate %in_var_POSITION Location 0
 //                OpDecorate %in_var_COLOR Location 1
 //                OpDecorate %out_var_COLOR Location 0
-//                OpMemberDecorate %PSInput 0 Offset 0
-//                OpMemberDecorate %PSInput 1 Offset 16
 //         %int = OpTypeInt 32 1
 //       %int_0 = OpConstant %int 0
 //       %int_1 = OpConstant %int 1

+ 2 - 2
tools/clang/test/CodeGenSPIRV/vk.layout.16bit-types.tbuffer.hlsl

@@ -5,12 +5,12 @@
 // CHECK: OpExtension "SPV_KHR_16bit_storage"
 
 // CHECK: OpMemberDecorate %type_MyTBuffer 0 Offset 0
+// CHECK: OpMemberDecorate %type_MyTBuffer 0 NonWritable
 // CHECK: OpMemberDecorate %type_MyTBuffer 1 Offset 4
+// CHECK: OpMemberDecorate %type_MyTBuffer 1 NonWritable
 // CHECK: OpMemberDecorate %type_MyTBuffer 2 Offset 8
 // CHECK: OpMemberDecorate %type_MyTBuffer 2 MatrixStride 8
 // CHECK: OpMemberDecorate %type_MyTBuffer 2 RowMajor
-// CHECK: OpMemberDecorate %type_MyTBuffer 0 NonWritable
-// CHECK: OpMemberDecorate %type_MyTBuffer 1 NonWritable
 // CHECK: OpMemberDecorate %type_MyTBuffer 2 NonWritable
 // CHECK: OpDecorate %type_MyTBuffer BufferBlock
 

+ 1 - 2
tools/clang/test/CodeGenSPIRV/vk.layout.tbuffer.std430.hlsl

@@ -24,9 +24,8 @@
 // CHECK: OpDecorate %_arr_S_uint_2 ArrayStride 288
 
 // CHECK: OpMemberDecorate %type_myTbuffer 0 Offset 0
-// CHECK: OpMemberDecorate %type_myTbuffer 1 Offset 576
-
 // CHECK: OpMemberDecorate %type_myTbuffer 0 NonWritable
+// CHECK: OpMemberDecorate %type_myTbuffer 1 Offset 576
 // CHECK: OpMemberDecorate %type_myTbuffer 1 NonWritable
 
 // CHECK: OpDecorate %type_myTbuffer BufferBlock

+ 1 - 2
tools/clang/test/CodeGenSPIRV/vk.layout.texture-buffer.std430.hlsl

@@ -24,9 +24,8 @@
 // CHECK: OpDecorate %_arr_S_uint_2 ArrayStride 288
 
 // CHECK: OpMemberDecorate %type_TextureBuffer_T 0 Offset 0
-// CHECK: OpMemberDecorate %type_TextureBuffer_T 1 Offset 576
-
 // CHECK: OpMemberDecorate %type_TextureBuffer_T 0 NonWritable
+// CHECK: OpMemberDecorate %type_TextureBuffer_T 1 Offset 576
 // CHECK: OpMemberDecorate %type_TextureBuffer_T 1 NonWritable
 
 // CHECK: OpDecorate %type_TextureBuffer_T BufferBlock