瀏覽代碼

[spirv] Emit decorations for types.

Ehsan 6 年之前
父節點
當前提交
39705bb0a0

+ 41 - 7
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -40,11 +40,12 @@ struct SpirvLayoutRuleDenseMapInfo {
 
 class EmitTypeHandler {
 public:
-  EmitTypeHandler(SpirvContext &c, std::vector<uint32_t> *decVec,
+  EmitTypeHandler(ASTContext &astCtx, SpirvContext &spvCtx,
+                  std::vector<uint32_t> *decVec,
                   std::vector<uint32_t> *typesVec,
                   const std::function<uint32_t()> &takeNextIdFn)
-      : context(c), annotationsBinary(decVec), typeConstantBinary(typesVec),
-        takeNextIdFunction(takeNextIdFn) {
+      : astContext(astCtx), spirvContext(spvCtx), annotationsBinary(decVec),
+        typeConstantBinary(typesVec), takeNextIdFunction(takeNextIdFn) {
     assert(decVec);
     assert(typesVec);
   }
@@ -53,6 +54,13 @@ public:
   EmitTypeHandler(const EmitTypeHandler &) = delete;
   EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
 
+  // Emits OpDecorate (or OpMemberDecorate if memberIndex is non-zero)
+  // targetting the given type. Uses the given decoration kind and its
+  // parameters.
+  void emitDecoration(uint32_t typeResultId, spv::Decoration,
+                      llvm::ArrayRef<uint32_t> decorationParams,
+                      uint32_t memberIndex = 0);
+
   // Emits the instruction for the given type into the typeConstantBinary and
   // returns the result-id for the type.
   uint32_t emitType(const SpirvType *, SpirvLayoutRule);
@@ -64,9 +72,34 @@ private:
   void initTypeInstruction(spv::Op op);
   void finalizeTypeInstruction();
 
+  // Methods associated with layout calculations ----
+
+  // TODO: This function should be merged into the Type class hierarchy.
+  std::pair<uint32_t, uint32_t> getAlignmentAndSize(const SpirvType *type,
+                                                    SpirvLayoutRule rule,
+                                                    uint32_t *stride);
+
+  void alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
+                                   uint32_t fieldSize, uint32_t fieldAlignment,
+                                   uint32_t *currentOffset);
+
+  void emitLayoutDecorations(const StructType *, SpirvLayoutRule);
+
+private:
+  /// Emits error to the diagnostic engine associated with this visitor.
+  template <unsigned N>
+  DiagnosticBuilder emitError(const char (&message)[N],
+                              SourceLocation loc = {}) {
+    const auto diagId = astContext.getDiagnostics().getCustomDiagID(
+        clang::DiagnosticsEngine::Error, message);
+    return astContext.getDiagnostics().Report(loc, diagId);
+  }
+
 private:
-  SpirvContext &context;
+  ASTContext &astContext;
+  SpirvContext &spirvContext;
   std::vector<uint32_t> curTypeInst;
+  std::vector<uint32_t> curDecorationInst;
   std::vector<uint32_t> *annotationsBinary;
   std::vector<uint32_t> *typeConstantBinary;
   std::function<uint32_t()> takeNextIdFunction;
@@ -82,9 +115,10 @@ private:
 /// representation.
 class EmitVisitor : public Visitor {
 public:
-  EmitVisitor(const SpirvCodeGenOptions &opts, SpirvContext &ctx)
-      : Visitor(opts, ctx), id(0),
-        typeHandler(ctx, &annotationsBinary, &typeConstantBinary,
+  EmitVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
+              const SpirvCodeGenOptions &opts)
+      : Visitor(opts, spvCtx), id(0),
+        typeHandler(astCtx, spvCtx, &annotationsBinary, &typeConstantBinary,
                     [this]() -> uint32_t { return takeNextId(); }) {}
 
   // Visit different SPIR-V constructs for emitting.

+ 5 - 3
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -188,9 +188,11 @@ public:
   const ArrayType *getArrayType(const SpirvType *elemType, uint32_t elemCount);
   const RuntimeArrayType *getRuntimeArrayType(const SpirvType *elemType);
 
-  const StructType *getStructType(
-      llvm::ArrayRef<const SpirvType *> fieldTypes, llvm::StringRef name,
-      llvm::ArrayRef<llvm::StringRef> fieldNames = {}, bool isReadOnly = false);
+  const StructType *
+  getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
+                llvm::StringRef name, bool isReadOnly = false,
+                StructType::InterfaceType interfaceType =
+                    StructType::InterfaceType::InternalStorage);
 
   const SpirvPointerType *getPointerType(const SpirvType *pointee,
                                          spv::StorageClass);

+ 54 - 27
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -14,10 +14,10 @@
 #include <vector>
 
 #include "spirv/unified1/spirv.hpp11"
+#include "clang/AST/Attr.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Casting.h"
 
 namespace clang {
 namespace spirv {
@@ -73,8 +73,12 @@ public:
     return t->getKind() == TK_Integer || t->getKind() == TK_Float;
   }
 
+  uint32_t getBitwidth() const { return bitwidth; }
+
 protected:
-  NumericalType(Kind k) : ScalarType(k) {}
+  NumericalType(Kind k, uint32_t width) : ScalarType(k), bitwidth(width) {}
+
+  uint32_t bitwidth;
 };
 
 class BoolType : public ScalarType {
@@ -87,28 +91,21 @@ public:
 class IntegerType : public NumericalType {
 public:
   IntegerType(uint32_t numBits, bool sign)
-      : NumericalType(TK_Integer), bitwidth(numBits), isSigned(sign) {}
+      : NumericalType(TK_Integer, numBits), isSigned(sign) {}
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Integer; }
 
-  uint32_t getBitwidth() const { return bitwidth; }
   bool isSignedInt() const { return isSigned; }
 
 private:
-  uint32_t bitwidth;
   bool isSigned;
 };
 
 class FloatType : public NumericalType {
 public:
-  FloatType(uint32_t numBits) : NumericalType(TK_Float), bitwidth(numBits) {}
+  FloatType(uint32_t numBits) : NumericalType(TK_Float, numBits) {}
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Float; }
-
-  uint32_t getBitwidth() const { return bitwidth; }
-
-private:
-  uint32_t bitwidth;
 };
 
 class VectorType : public SpirvType {
@@ -118,9 +115,7 @@ public:
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Vector; }
 
-  const SpirvType *getElementType() const {
-    return llvm::cast<SpirvType>(elementType);
-  }
+  const SpirvType *getElementType() const { return elementType; }
   uint32_t getElementCount() const { return elementCount; }
 
 private:
@@ -136,12 +131,16 @@ public:
 
   bool operator==(const MatrixType &that) const;
 
-  const SpirvType *getVecType() const {
-    return llvm::cast<SpirvType>(vectorType);
+  const SpirvType *getVecType() const { return vectorType; }
+  const SpirvType *getElementType() const {
+    return vectorType->getElementType();
   }
   uint32_t getVecCount() const { return vectorCount; }
   bool isRowMajorMat() const { return isRowMajor; }
 
+  uint32_t numCols() const { return vectorCount; }
+  uint32_t numRows() const { return vectorType->getElementCount(); }
+
 private:
   const VectorType *vectorType;
   uint32_t vectorCount;
@@ -174,9 +173,7 @@ public:
 
   bool operator==(const ImageType &that) const;
 
-  const SpirvType *getSampledType() const {
-    return llvm::cast<SpirvType>(sampledType);
-  }
+  const SpirvType *getSampledType() const { return sampledType; }
   spv::Dim getDimension() const { return dimension; }
   WithDepth getDepth() const { return imageDepth; }
   bool isArrayedImage() const { return isArrayed; }
@@ -210,7 +207,7 @@ public:
     return t->getKind() == TK_SampledImage;
   }
 
-  const ImageType *getImageType() const { return imageType; }
+  const SpirvType *getImageType() const { return imageType; }
 
 private:
   const ImageType *imageType;
@@ -248,16 +245,42 @@ private:
 
 class StructType : public SpirvType {
 public:
-  StructType(llvm::ArrayRef<const SpirvType *> memberTypes,
-             llvm::StringRef name, llvm::ArrayRef<llvm::StringRef> memberNames,
-             bool isReadOnly);
+  enum class InterfaceType : uint32_t {
+    InternalStorage = 0,
+    StorageBuffer = 1,
+    UniformBuffer = 2,
+  };
+
+  struct FieldInfo {
+  public:
+    FieldInfo(const SpirvType *type_, llvm::StringRef name_ = "",
+              clang::VKOffsetAttr *offset = nullptr,
+              hlsl::ConstantPacking *packOffset = nullptr)
+        : type(type_), name(name_), vkOffsetAttr(offset),
+          packOffsetAttr(packOffset) {}
+
+    bool operator==(const FieldInfo &that) const;
+
+    // The field's type.
+    const SpirvType *type;
+    // The field's name.
+    std::string name;
+    // vk::offset attributes associated with this field.
+    clang::VKOffsetAttr *vkOffsetAttr;
+    // :packoffset() annotations associated with this field.
+    hlsl::ConstantPacking *packOffsetAttr;
+  };
+
+  StructType(llvm::ArrayRef<FieldInfo> fields, llvm::StringRef name,
+             bool isReadOnly,
+             InterfaceType interfaceType = InterfaceType::InternalStorage);
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Struct; }
 
+  llvm::ArrayRef<FieldInfo> getFields() const { return fields; }
   bool isReadOnly() const { return readOnly; }
   std::string getStructName() const { return structName; }
-  llvm::ArrayRef<const SpirvType *> getFieldTypes() const { return fieldTypes; }
-  llvm::ArrayRef<std::string> getFieldNames() const { return fieldNames; }
+  InterfaceType getInterfaceType() const { return interfaceType; }
 
   bool operator==(const StructType &that) const;
 
@@ -266,10 +289,14 @@ private:
   // struct names and field names. That basically means we cannot ignore these
   // names when considering unification. Otherwise, reflection will be confused.
 
+  llvm::SmallVector<FieldInfo, 8> fields;
   std::string structName;
-  llvm::SmallVector<const SpirvType *, 8> fieldTypes;
-  llvm::SmallVector<std::string, 8> fieldNames;
   bool readOnly;
+  // Indicates the interface type of this structure. If this structure is a
+  // storage buffer shader-interface, it will be decorated with 'BufferBlock'.
+  // If this structure is a uniform buffer shader-interface, it will be
+  // decorated with 'Block'.
+  InterfaceType interfaceType;
 };
 
 class SpirvPointerType : public SpirvType {

+ 452 - 3
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -17,6 +17,38 @@
 #include "clang/SPIRV/String.h"
 
 namespace {
+
+/// The alignment for 4-component float vectors.
+constexpr uint32_t kStd140Vec4Alignment = 16u;
+
+/// Rounds the given value up to the given power of 2.
+inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) {
+  assert(pow2 != 0);
+  return (val + pow2 - 1) & ~(pow2 - 1);
+}
+
+/// Returns true if the given vector type (of the given size) crosses the
+/// 4-component vector boundary if placed at the given offset.
+bool improperStraddle(const clang::spirv::VectorType *type, int size,
+                      int offset) {
+  return size <= 16 ? offset / 16 != (offset + size - 1) / 16
+                    : offset % 16 != 0;
+}
+
+bool isAKindOfStructuredOrByteBuffer(const clang::spirv::SpirvType *type) {
+  // Strip outer arrayness first
+  while (llvm::isa<clang::spirv::ArrayType>(type))
+    type = llvm::cast<clang::spirv::ArrayType>(type)->getElementType();
+
+  // They are structures with the first member that is of RuntimeArray type.
+  if (auto *structType = llvm::dyn_cast<clang::spirv::StructType>(type))
+    return structType->getFields().size() == 1 &&
+           llvm::isa<clang::spirv::RuntimeArrayType>(
+               structType->getFields()[0].type);
+
+  return false;
+}
+
 uint32_t zeroExtendTo32Bits(uint16_t value) {
   // TODO: The ordering of the 2 words depends on the endian-ness of the host
   // machine. Assuming Little Endian at the moment.
@@ -956,7 +988,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   else if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
     // Emit the OpConstant instruction that is needed to get the result-id for
     // the array length.
-    auto *constant = context.getConstantUint32(arrayType->getElementCount());
+    SpirvConstant *constant =
+        spirvContext.getConstantUint32(arrayType->getElementCount());
     if (constant->getResultId() == 0) {
       constant->setResultId(takeNextIdFunction());
     }
@@ -975,6 +1008,16 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     curTypeInst.push_back(elemTypeId);
     curTypeInst.push_back(constant->getResultId());
     finalizeTypeInstruction();
+
+    // ArrayStride decoration is needed for array types, but we won't have
+    // stride information for structured/byte buffers since they contain runtime
+    // arrays.
+    if (rule != SpirvLayoutRule::Void &&
+        !isAKindOfStructuredOrByteBuffer(type)) {
+      uint32_t stride = 0;
+      (void)getAlignmentAndSize(type, rule, &stride);
+      emitDecoration(id, spv::Decoration::ArrayStride, {stride});
+    }
   }
   // RuntimeArray types
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
@@ -983,17 +1026,39 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
     finalizeTypeInstruction();
+
+    // ArrayStride decoration is needed for runtime array types.
+    if (rule != SpirvLayoutRule::Void) {
+      uint32_t stride = 0;
+      (void)getAlignmentAndSize(type, rule, &stride);
+      emitDecoration(id, spv::Decoration::ArrayStride, {stride});
+    }
   }
   // Structure types
   else if (const auto *structType = dyn_cast<StructType>(type)) {
     llvm::SmallVector<uint32_t, 4> fieldTypeIds;
-    for (auto *fieldType : structType->getFieldTypes())
-      fieldTypeIds.push_back(emitType(fieldType, rule));
+    for (auto &field : structType->getFields())
+      fieldTypeIds.push_back(emitType(field.type, rule));
     initTypeInstruction(spv::Op::OpTypeStruct);
     curTypeInst.push_back(id);
     for (auto fieldTypeId : fieldTypeIds)
       curTypeInst.push_back(fieldTypeId);
     finalizeTypeInstruction();
+
+    // Emit the layout decorations for the structure.
+    emitLayoutDecorations(structType, rule);
+
+    // Emit NonWritable decorations
+    if (structType->isReadOnly())
+      for (size_t i = 0; i < structType->getFields().size(); ++i)
+        emitDecoration(id, spv::Decoration::NonWritable, {}, i);
+
+    // Emit Block or BufferBlock decorations if necessary.
+    auto interfaceType = structType->getInterfaceType();
+    if (interfaceType == StructType::InterfaceType::StorageBuffer)
+      emitDecoration(id, spv::Decoration::BufferBlock, {});
+    else if (interfaceType == StructType::InterfaceType::UniformBuffer)
+      emitDecoration(id, spv::Decoration::Block, {});
   }
   // Pointer types
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
@@ -1026,5 +1091,389 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   return id;
 }
 
+std::pair<uint32_t, uint32_t>
+EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
+                                     SpirvLayoutRule rule, uint32_t *stride) {
+  // std140 layout rules:
+
+  // 1. If the member is a scalar consuming N basic machine units, the base
+  //    alignment is N.
+  //
+  // 2. If the member is a two- or four-component vector with components
+  //    consuming N basic machine units, the base alignment is 2N or 4N,
+  //    respectively.
+  //
+  // 3. If the member is a three-component vector with components consuming N
+  //    basic machine units, the base alignment is 4N.
+  //
+  // 4. If the member is an array of scalars or vectors, the base alignment and
+  //    array stride are set to match the base alignment of a single array
+  //    element, according to rules (1), (2), and (3), and rounded up to the
+  //    base alignment of a vec4. The array may have padding at the end; the
+  //    base offset of the member following the array is rounded up to the next
+  //    multiple of the base alignment.
+  //
+  // 5. If the member is a column-major matrix with C columns and R rows, the
+  //    matrix is stored identically to an array of C column vectors with R
+  //    components each, according to rule (4).
+  //
+  // 6. If the member is an array of S column-major matrices with C columns and
+  //    R rows, the matrix is stored identically to a row of S X C column
+  //    vectors with R components each, according to rule (4).
+  //
+  // 7. If the member is a row-major matrix with C columns and R rows, the
+  //    matrix is stored identically to an array of R row vectors with C
+  //    components each, according to rule (4).
+  //
+  // 8. If the member is an array of S row-major matrices with C columns and R
+  //    rows, the matrix is stored identically to a row of S X R row vectors
+  //    with C components each, according to rule (4).
+  //
+  // 9. If the member is a structure, the base alignment of the structure is N,
+  //    where N is the largest base alignment value of any of its members, and
+  //    rounded up to the base alignment of a vec4. The individual members of
+  //    this substructure are then assigned offsets by applying this set of
+  //    rules recursively, where the base offset of the first member of the
+  //    sub-structure is equal to the aligned offset of the structure. The
+  //    structure may have padding at the end; the base offset of the member
+  //    following the sub-structure is rounded up to the next multiple of the
+  //    base alignment of the structure.
+  //
+  // 10. If the member is an array of S structures, the S elements of the array
+  //     are laid out in order, according to rule (9).
+  //
+  // This method supports multiple layout rules, all of them modifying the
+  // std140 rules listed above:
+  //
+  // std430:
+  // - Array base alignment and stride does not need to be rounded up to a
+  //   multiple of 16.
+  // - Struct base alignment does not need to be rounded up to a multiple of 16.
+  //
+  // Relaxed std140/std430:
+  // - Vector base alignment is set as its element type's base alignment.
+  //
+  // FxcCTBuffer:
+  // - Vector base alignment is set as its element type's base alignment.
+  // - Arrays/structs do not need to have padding at the end; arrays/structs do
+  //   not affect the base offset of the member following them.
+  //
+  // FxcSBuffer:
+  // - Vector/matrix/array base alignment is set as its element type's base
+  //   alignment.
+  // - Arrays/structs do not need to have padding at the end; arrays/structs do
+  //   not affect the base offset of the member following them.
+  // - Struct base alignment does not need to be rounded up to a multiple of 16.
+
+  { // Rule 1
+    if (isa<BoolType>(type))
+      return {4, 4};
+    // Integer and Float types are NumericalType
+    if (auto *numericType = dyn_cast<NumericalType>(type)) {
+      switch (numericType->getBitwidth()) {
+      case 64:
+        return {8, 8};
+      case 32:
+        return {4, 4};
+      case 16:
+        return {2, 2};
+      default:
+        emitError("alignment and size calculation unimplemented for type");
+        return {0, 0};
+      }
+    }
+  }
+
+  { // Rule 2 and 3
+    if (auto *vecType = dyn_cast<VectorType>(type)) {
+      uint32_t alignment = 0, size = 0;
+      uint32_t elemCount = vecType->getElementCount();
+      const SpirvType *elemType = vecType->getElementType();
+      std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+      // Use element alignment for fxc rules
+      if (rule != SpirvLayoutRule::FxcCTBuffer &&
+          rule != SpirvLayoutRule::FxcSBuffer)
+        alignment = (elemCount == 3 ? 4 : elemCount) * size;
+
+      return {alignment, elemCount * size};
+    }
+  }
+
+  { // Rule 5 and 7
+    if (auto *matType = dyn_cast<MatrixType>(type)) {
+      const SpirvType *elemType = matType->getElementType();
+      uint32_t rowCount = matType->numRows();
+      uint32_t colCount = matType->numCols();
+      uint32_t alignment = 0, size = 0;
+      std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+
+      // Matrices are treated as arrays of vectors:
+      // The base alignment and array stride are set to match the base alignment
+      // of a single array element, according to rules 1, 2, and 3, and rounded
+      // up to the base alignment of a vec4.
+      bool isRowMajor = matType->isRowMajorMat();
+
+      const uint32_t vecStorageSize = isRowMajor ? colCount : rowCount;
+
+      if (rule == SpirvLayoutRule::FxcSBuffer) {
+        *stride = vecStorageSize * size;
+        // Use element alignment for fxc structured buffers
+        return {alignment, rowCount * colCount * size};
+      }
+
+      alignment *= (vecStorageSize == 3 ? 4 : vecStorageSize);
+      if (rule == SpirvLayoutRule::GLSLStd140 ||
+          rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+          rule == SpirvLayoutRule::FxcCTBuffer) {
+        alignment = roundToPow2(alignment, kStd140Vec4Alignment);
+      }
+      *stride = alignment;
+      size = (isRowMajor ? rowCount : colCount) * alignment;
+
+      return {alignment, size};
+    }
+  }
+
+  // Rule 9
+  if (auto *structType = dyn_cast<StructType>(type)) {
+    // Special case for handling empty structs, whose size is 0 and has no
+    // requirement over alignment (thus 1).
+    if (structType->getFields().size() == 0)
+      return {1, 0};
+
+    uint32_t maxAlignment = 0;
+    uint32_t structSize = 0;
+
+    for (auto &field : structType->getFields()) {
+      uint32_t memberAlignment = 0, memberSize = 0;
+      std::tie(memberAlignment, memberSize) =
+          getAlignmentAndSize(field.type, rule, stride);
+
+      if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+          rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
+          rule == SpirvLayoutRule::FxcCTBuffer) {
+        alignUsingHLSLRelaxedLayout(field.type, memberSize, memberAlignment,
+                                    &structSize);
+      } else {
+        structSize = roundToPow2(structSize, memberAlignment);
+      }
+
+      // Reset the current offset to the one specified in the source code
+      // if exists. It's debatable whether we should do sanity check here.
+      // If the developers want manually control the layout, we leave
+      // everything to them.
+      if (field.vkOffsetAttr) {
+        structSize = field.vkOffsetAttr->getOffset();
+      }
+
+      // The base alignment of the structure is N, where N is the largest
+      // base alignment value of any of its members...
+      maxAlignment = std::max(maxAlignment, memberAlignment);
+      structSize += memberSize;
+    }
+
+    if (rule == SpirvLayoutRule::GLSLStd140 ||
+        rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+        rule == SpirvLayoutRule::FxcCTBuffer) {
+      // ... and rounded up to the base alignment of a vec4.
+      maxAlignment = roundToPow2(maxAlignment, kStd140Vec4Alignment);
+    }
+
+    if (rule != SpirvLayoutRule::FxcCTBuffer &&
+        rule != SpirvLayoutRule::FxcSBuffer) {
+      // The base offset of the member following the sub-structure is rounded up
+      // to the next multiple of the base alignment of the structure.
+      structSize = roundToPow2(structSize, maxAlignment);
+    }
+    return {maxAlignment, structSize};
+  }
+
+  // Rule 4, 6, 8, and 10
+  if (auto *arrayType = dyn_cast<ArrayType>(type)) {
+    const auto elemCount = arrayType->getElementCount();
+    uint32_t alignment = 0, size = 0;
+    std::tie(alignment, size) =
+        getAlignmentAndSize(arrayType->getElementType(), rule, stride);
+
+    if (rule == SpirvLayoutRule::FxcSBuffer) {
+      *stride = size;
+      // Use element alignment for fxc structured buffers
+      return {alignment, size * elemCount};
+    }
+
+    if (rule == SpirvLayoutRule::GLSLStd140 ||
+        rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+        rule == SpirvLayoutRule::FxcCTBuffer) {
+      // The base alignment and array stride are set to match the base alignment
+      // of a single array element, according to rules 1, 2, and 3, and rounded
+      // up to the base alignment of a vec4.
+      alignment = roundToPow2(alignment, kStd140Vec4Alignment);
+    }
+    if (rule == SpirvLayoutRule::FxcCTBuffer) {
+      // In fxc cbuffer/tbuffer packing rules, arrays does not affect the data
+      // packing after it. But we still need to make sure paddings are inserted
+      // internally if necessary.
+      *stride = roundToPow2(size, alignment);
+      size += *stride * (elemCount - 1);
+    } else {
+      // Need to round size up considering stride for scalar types
+      size = roundToPow2(size, alignment);
+      *stride = size; // Use size instead of alignment here for Rule 10
+      size *= elemCount;
+      // The base offset of the member following the array is rounded up to the
+      // next multiple of the base alignment.
+      size = roundToPow2(size, alignment);
+    }
+
+    return {alignment, size};
+  }
+
+  emitError("alignment and size calculation unimplemented for type");
+  return {0, 0};
+}
+
+void EmitTypeHandler::alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
+                                                  uint32_t fieldSize,
+                                                  uint32_t fieldAlignment,
+                                                  uint32_t *currentOffset) {
+  if (auto *vecType = dyn_cast<VectorType>(fieldType)) {
+    const SpirvType *elemType = vecType->getElementType();
+    // Adjust according to HLSL relaxed layout rules.
+    // Aligning vectors as their element types so that we can pack a float
+    // and a float3 tightly together.
+    uint32_t scalarAlignment = 0;
+    std::tie(scalarAlignment, std::ignore) =
+        getAlignmentAndSize(elemType, SpirvLayoutRule::Void, nullptr);
+    if (scalarAlignment <= 4)
+      fieldAlignment = scalarAlignment;
+
+    *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
+
+    // Adjust according to HLSL relaxed layout rules.
+    // Bump to 4-component vector alignment if there is a bad straddle
+    if (improperStraddle(vecType, fieldSize, *currentOffset)) {
+      fieldAlignment = kStd140Vec4Alignment;
+      *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
+    }
+  }
+  // Cases where the field is not a vector
+  else {
+    *currentOffset = roundToPow2(*currentOffset, fieldAlignment);
+  }
+}
+
+void EmitTypeHandler::emitLayoutDecorations(const StructType *structType,
+                                            SpirvLayoutRule rule) {
+  // Decorations for a type can be emitted after the type itself has been
+  // visited, because we need the result-id of the type as the target of the
+  // decoration.
+  bool visited = false;
+  const uint32_t typeResultId = getResultIdForType(structType, rule, &visited);
+  assert(visited);
+
+  uint32_t offset = 0, index = 0;
+  for (auto &field : structType->getFields()) {
+    const SpirvType *fieldType = field.type;
+    uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
+    std::tie(memberAlignment, memberSize) =
+        getAlignmentAndSize(fieldType, rule, &stride);
+
+    // The next avaiable location after laying out the previous members
+    const uint32_t nextLoc = offset;
+
+    if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
+        rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
+        rule == SpirvLayoutRule::FxcCTBuffer) {
+      alignUsingHLSLRelaxedLayout(fieldType, memberSize, memberAlignment,
+                                  &offset);
+    } else {
+      offset = roundToPow2(offset, memberAlignment);
+    }
+
+    // The vk::offset attribute takes precedence over all.
+    if (field.vkOffsetAttr) {
+      offset = field.vkOffsetAttr->getOffset();
+    }
+    // The :packoffset() annotation takes precedence over normal layout
+    // calculation.
+    else if (field.packOffsetAttr) {
+      const uint32_t packOffset = field.packOffsetAttr->Subcomponent * 16 +
+                                  field.packOffsetAttr->ComponentOffset * 4;
+      // Do minimal check to make sure the offset specified by packoffset does
+      // not cause overlap.
+      if (packOffset < nextLoc) {
+        emitError("packoffset caused overlap with previous members",
+                  field.packOffsetAttr->Loc);
+      } else {
+        offset = packOffset;
+      }
+    }
+
+    // Each structure-type member must have an Offset Decoration.
+    emitDecoration(typeResultId, spv::Decoration::Offset, {offset}, index);
+    offset += memberSize;
+
+    // Each structure-type member that is a matrix or array-of-matrices must be
+    // decorated with
+    // * A MatrixStride decoration, and
+    // * one of the RowMajor or ColMajor Decorations.
+    if (auto *arrayType = dyn_cast<ArrayType>(fieldType)) {
+      // We have an array of matrices as a field, we need to decorate
+      // MatrixStride on the field. So skip possible arrays here.
+      fieldType = arrayType->getElementType();
+    }
+
+    // Non-floating point matrices are represented as arrays of vectors, and
+    // therefore ColMajor and RowMajor decorations should not be applied to
+    // them.
+    if (auto *matType = dyn_cast<MatrixType>(fieldType)) {
+      if (isa<FloatType>(matType->getElementType())) {
+        memberAlignment = memberSize = stride = 0;
+        std::tie(memberAlignment, memberSize) =
+            getAlignmentAndSize(fieldType, rule, &stride);
+
+        emitDecoration(typeResultId, spv::Decoration::MatrixStride, {stride},
+                       index);
+
+        // We need to swap the RowMajor and ColMajor decorations since HLSL
+        // matrices are conceptually row-major while SPIR-V are conceptually
+        // column-major.
+        if (matType->isRowMajorMat()) {
+          emitDecoration(typeResultId, spv::Decoration::ColMajor, {}, index);
+        } else {
+          // If the source code has neither row_major nor column_major
+          // annotated, it should be treated as column_major since that's the
+          // default.
+          emitDecoration(typeResultId, spv::Decoration::RowMajor, {}, index);
+        }
+      }
+    }
+
+    ++index;
+  }
+}
+
+void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
+                                     spv::Decoration decoration,
+                                     llvm::ArrayRef<uint32_t> decorationParams,
+                                     uint32_t memberIndex) {
+
+  spv::Op op = memberIndex ? spv::Op::OpMemberDecorate : spv::Op::OpDecorate;
+  assert(curDecorationInst.empty());
+  curDecorationInst.push_back(static_cast<uint32_t>(op));
+  curDecorationInst.push_back(typeResultId);
+  if (memberIndex)
+    curDecorationInst.push_back(memberIndex);
+  curDecorationInst.push_back(static_cast<uint32_t>(decoration));
+  for (auto param : decorationParams)
+    curDecorationInst.push_back(param);
+  curDecorationInst[0] |= static_cast<uint32_t>(curDecorationInst.size()) << 16;
+
+  // Add to the full annotations list
+  annotationsBinary->insert(annotationsBinary->end(), curDecorationInst.begin(),
+                            curDecorationInst.end());
+  curDecorationInst.clear();
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 16 - 9
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -126,6 +126,13 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
     if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
       const auto *vecType =
           spvContext.getVectorType(lowerType(elemType, rule, srcLoc), colCount);
+
+      // Non-float matrices are represented as an array of vectors.
+      if (!elemType->isFloatingType()) {
+        // This return type is ArrayType
+        return spvContext.getArrayType(vecType, rowCount);
+      }
+
       // HLSL matrices are conceptually row major, while SPIR-V matrices are
       // conceptually column major. We are mapping what HLSL semantically mean
       // a row into a column here.
@@ -146,25 +153,24 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
     if (const auto *spvType = lowerResourceType(type, rule, srcLoc))
       return spvType;
 
-    // Collect all fields' types and names.
-    llvm::SmallVector<const SpirvType *, 8> fieldTypes;
-    llvm::SmallVector<llvm::StringRef, 8> fieldNames;
+    // Collect all fields' information.
+    llvm::SmallVector<StructType::FieldInfo, 8> fields;
 
     // If this struct is derived from some other struct, place an implicit
     // field at the very beginning for the base struct.
     if (const auto *cxxDecl = dyn_cast<CXXRecordDecl>(decl))
       for (const auto base : cxxDecl->bases()) {
-        fieldTypes.push_back(lowerType(base.getType(), rule, srcLoc));
-        fieldNames.push_back("");
+        fields.push_back(
+            StructType::FieldInfo(lowerType(base.getType(), rule, srcLoc)));
       }
 
     // Create fields for all members of this struct
     for (const auto *field : decl->fields()) {
-      fieldTypes.push_back(lowerType(field->getType(), rule, srcLoc));
-      fieldNames.push_back(field->getName());
+      const SpirvType *fieldType = lowerType(field->getType(), rule, srcLoc);
+      fields.push_back(StructType::FieldInfo(fieldType, field->getName()));
     }
 
-    return spvContext.getStructType(fieldTypes, decl->getName(), fieldNames);
+    return spvContext.getStructType(fields, decl->getName());
   }
 
   // Array type
@@ -288,7 +294,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     const auto *raType = spvContext.getRuntimeArrayType(structType);
 
     const std::string typeName = "type." + name.str() + "." + structName;
-    const auto *valType = spvContext.getStructType(raType, typeName);
+    const auto *valType =
+        spvContext.getStructType({StructType::FieldInfo(raType)}, typeName);
 
     if (asAlias) {
       // All structured buffers are in the Uniform storage class.

+ 8 - 7
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -225,15 +225,16 @@ SpirvContext::getRuntimeArrayType(const SpirvType *elemType) {
   return runtimeArrayTypes[elemType] = new (this) RuntimeArrayType(elemType);
 }
 
-const StructType *SpirvContext::getStructType(
-    llvm::ArrayRef<const SpirvType *> fieldTypes, llvm::StringRef name,
-    llvm::ArrayRef<llvm::StringRef> fieldNames, bool isReadOnly) {
+const StructType *
+SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
+                            llvm::StringRef name, bool isReadOnly,
+                            StructType::InterfaceType interfaceType) {
   // We are creating a temporary struct type here for querying whether the
   // same type was already created. It is a little bit costly, but we can
   // avoid allocating directly from the bump pointer allocator, from which
   // then we are unable to reclaim until the allocator itself is destroyed.
 
-  StructType type(fieldTypes, name, fieldNames, isReadOnly);
+  StructType type(fields, name, isReadOnly, interfaceType);
 
   auto found = std::find_if(
       structTypes.begin(), structTypes.end(),
@@ -243,7 +244,7 @@ const StructType *SpirvContext::getStructType(
     return *found;
 
   structTypes.push_back(
-      new (this) StructType(fieldTypes, name, fieldNames, isReadOnly));
+      new (this) StructType(fields, name, isReadOnly, interfaceType));
 
   return structTypes.back();
 }
@@ -286,10 +287,10 @@ const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
   const auto *raType = getRuntimeArrayType(getUIntType(32));
 
   // Create a struct containing the runtime array as its only member.
-  return getStructType({raType},
+  return getStructType({StructType::FieldInfo(raType)},
                        isWritable ? "type.RWByteAddressBuffer"
                                   : "type.ByteAddressBuffer",
-                       {}, !isWritable);
+                       !isWritable);
 }
 
 SpirvConstant *SpirvContext::getConstantUint32(uint32_t value) {

+ 13 - 9
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -50,18 +50,22 @@ bool ImageType::operator==(const ImageType &that) const {
          isSampled == that.isSampled && imageFormat == that.imageFormat;
 }
 
-StructType::StructType(llvm::ArrayRef<const SpirvType *> memberTypes,
-                       llvm::StringRef name,
-                       llvm::ArrayRef<llvm::StringRef> memberNames,
-                       bool isReadOnly)
-    : SpirvType(TK_Struct), structName(name),
-      fieldTypes(memberTypes.begin(), memberTypes.end()),
-      fieldNames(memberNames.begin(), memberNames.end()), readOnly(isReadOnly) {
+StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
+                       llvm::StringRef name, bool isReadOnly,
+                       StructType::InterfaceType iface)
+    : SpirvType(TK_Struct), fields(fieldsVec.begin(), fieldsVec.end()),
+      structName(name), readOnly(isReadOnly), interfaceType(iface) {}
+
+bool StructType::FieldInfo::
+operator==(const StructType::FieldInfo &that) const {
+  return type == that.type && name == that.name &&
+         vkOffsetAttr == that.vkOffsetAttr &&
+         packOffsetAttr == that.packOffsetAttr;
 }
 
 bool StructType::operator==(const StructType &that) const {
-  return structName == that.structName && fieldTypes == that.fieldTypes &&
-         fieldNames == that.fieldNames && readOnly == that.readOnly;
+  return fields == that.fields && structName == that.structName &&
+         readOnly == that.readOnly;
 }
 
 } // namespace spirv