Explorar o código

[spirv] Layout rules and their implications on types

and their implications on decorations.
Ehsan Nasiri %!s(int64=6) %!d(string=hai) anos
pai
achega
713bbbd5a4

+ 4 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -128,6 +128,10 @@ QualType getElementType(QualType type);
 QualType getTypeWithCustomBitwidth(const ASTContext &, QualType type,
                                    uint32_t bitwidth);
 
+/// Returns true if the given type is a matrix or an array of matrices.
+bool isMatrixOrArrayOfMatrix(const ASTContext &, QualType type);
+                             
+
 } // namespace spirv
 } // namespace clang
 

+ 9 - 4
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -106,7 +106,8 @@ public:
   //
   // If any decorations apply to the type, it also emits the decoration
   // instructions into the annotationsBinary.
-  uint32_t emitType(const SpirvType *, SpirvLayoutRule);
+  uint32_t emitType(const SpirvType *, SpirvLayoutRule,
+                    llvm::Optional<bool> isRowMajor = llvm::None);
 
   // Emits an OpConstant instruction with uint32 type and returns its result-id.
   // If such constant has already been emitted, just returns its resutl-id.
@@ -121,6 +122,7 @@ private:
   // Figures out the decorations that apply to the given type with the given
   // layout rule, and populates the given decoration set.
   void getDecorationsForType(const SpirvType *type, SpirvLayoutRule rule,
+                             llvm::Optional<bool> isRowMajor,
                              DecorationList *decorations);
 
   // Returns the result-id for the given type and decorations. If a type with
@@ -154,9 +156,12 @@ private:
 
   // ---- Methods associated with layout calculations ----
 
-  std::pair<uint32_t, uint32_t> getAlignmentAndSize(const SpirvType *type,
-                                                    SpirvLayoutRule rule,
-                                                    uint32_t *stride);
+  // If the given type is a matrix type inside a struct, its majorness should
+  // also be passed to this method in order to determine the correct alignment.
+  std::pair<uint32_t, uint32_t>
+  getAlignmentAndSize(const SpirvType *type, SpirvLayoutRule rule,
+                      uint32_t *stride,
+                      llvm::Optional<bool> isRowMajorStructMember = llvm::None);
 
   void alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
                                    uint32_t fieldSize, uint32_t fieldAlignment,

+ 15 - 12
tools/clang/include/clang/SPIRV/LowerTypeVisitor.h

@@ -73,22 +73,25 @@ private:
   /// This method will update internal bookkeeping regarding matrix majorness.
   QualType desugarType(QualType type);
 
-  /// Returns true if type is a HLSL row-major matrix, either with explicit
-  /// attribute or implicit command-line option.
-  bool isRowMajorMatrix(QualType type) const;
+  /// Returns true if type is an HLSL row-major matrix or array of matrices.
+  /// Returns false if type is an HLSL col-major matrix or array of matrices.
+  /// Returns llvm::None if the type is not a matrix or array of matrices.
+  /// It does so by checking the majorness of the HLSL matrix either with
+  /// explicit attribute or implicit command-line option.
+  llvm::Optional<bool> isHLSLRowMajorMatrix(QualType type) const;
+
+  /// Returns true if type is a SPIR-V row-major matrix or array of matrices.
+  /// Returns false if type is a SPIR-V col-major matrix or array of matrices.
+  /// Returns llvm::None if the type is not a matrix or array of matrices.
+  /// 
+  /// Note that HLSL matrices are conceptually row major, while SPIR-V matrices
+  /// are conceptually column major. We are mapping what HLSL semantically mean
+  /// a row into a column here.
+  llvm::Optional<bool> isRowMajorMatrix(QualType type) const;
 
 private:
   ASTContext &astContext;   /// AST context
   SpirvContext &spvContext; /// SPIR-V context
-
-  /// A place to keep the matrix majorness attributes so that we can retrieve
-  /// the information when really processing the desugared matrix type.
-  ///
-  /// This is needed because the majorness attribute is decorated on a
-  /// TypedefType (i.e., floatMxN) of the real matrix type (i.e., matrix<elem,
-  /// row, col>). When we reach the desugared matrix type, this information
-  /// is already gone.
-  llvm::Optional<AttributedType::Kind> typeMatMajorAttr;
 };
 
 } // end namespace spirv

+ 7 - 4
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -192,8 +192,9 @@ public:
   const FloatType *getFloatType(uint32_t bitwidth);
 
   const VectorType *getVectorType(const SpirvType *elemType, uint32_t count);
-  const MatrixType *getMatrixType(const SpirvType *vecType, uint32_t vecCount,
-                                  bool isRowMajor);
+  // Note: In the case of non-floating-point matrices, this method returns an
+  // array of vectors.
+  const SpirvType *getMatrixType(const SpirvType *vecType, uint32_t vecCount);
 
   const ImageType *getImageType(const SpirvType *, spv::Dim,
                                 ImageType::WithDepth, bool arrayed, bool ms,
@@ -203,7 +204,8 @@ public:
   const SampledImageType *getSampledImageType(const ImageType *image);
   const HybridSampledImageType *getSampledImageType(QualType image);
 
-  const ArrayType *getArrayType(const SpirvType *elemType, uint32_t elemCount);
+  const ArrayType *getArrayType(const SpirvType *elemType, uint32_t elemCount,
+                                llvm::Optional<bool> rowMajorElem);
   const RuntimeArrayType *getRuntimeArrayType(const SpirvType *elemType);
 
   const StructType *getStructType(
@@ -272,7 +274,8 @@ private:
   llvm::DenseMap<QualType, const HybridSampledImageType *, QualTypeDenseMapInfo>
       hybridSampledImageTypes;
 
-  llvm::DenseMap<const SpirvType *, CountToArrayMap> arrayTypes;
+  //llvm::DenseMap<const SpirvType *, CountToArrayMap> arrayTypes;
+  llvm::SmallVector<const ArrayType *, 8> arrayTypes;
   llvm::DenseMap<const SpirvType *, const RuntimeArrayType *> runtimeArrayTypes;
 
   llvm::SmallVector<const StructType *, 8> structTypes;

+ 21 - 13
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -66,6 +66,7 @@ public:
   static bool isSubpassInputMS(const SpirvType *);
   static bool isResourceType(const SpirvType *);
   static bool isOrContains16BitType(const SpirvType *);
+  static bool isMatrixOrArrayOfMatrix(const SpirvType *);
 
 protected:
   SpirvType(Kind k, llvm::StringRef name = "") : kind(k), debugName(name) {}
@@ -148,7 +149,7 @@ private:
 
 class MatrixType : public SpirvType {
 public:
-  MatrixType(const VectorType *vecType, uint32_t vecCount, bool rowMajor);
+  MatrixType(const VectorType *vecType, uint32_t vecCount);
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Matrix; }
 
@@ -159,20 +160,12 @@ public:
     return vectorType->getElementType();
   }
   uint32_t getVecCount() const { return vectorCount; }
-  bool isRowMajorMat() const { return isRowMajor; }
-
   uint32_t numCols() const { return vectorCount; }
   uint32_t numRows() const { return vectorType->getElementCount(); }
 
 private:
   const VectorType *vectorType;
   uint32_t vectorCount;
-  // It's debatable whether we should put majorness as a field in the type
-  // itself. Majorness only matters at the time of emitting SPIR-V words since
-  // we need the layout decoration then. However, if we don't put it here,
-  // we will need to rediscover the majorness information from QualType at
-  // the time of emitting SPIR-V words.
-  bool isRowMajor;
 };
 
 class ImageType : public SpirvType {
@@ -241,17 +234,29 @@ private:
 
 class ArrayType : public SpirvType {
 public:
-  ArrayType(const SpirvType *elemType, uint32_t elemCount)
-      : SpirvType(TK_Array), elementType(elemType), elementCount(elemCount) {}
+  ArrayType(const SpirvType *elemType, uint32_t elemCount,
+            llvm::Optional<bool> hasRowMajorElem)
+      : SpirvType(TK_Array), elementType(elemType), elementCount(elemCount),
+        rowMajorElem(hasRowMajorElem) {}
 
   const SpirvType *getElementType() const { return elementType; }
   uint32_t getElementCount() const { return elementCount; }
+  llvm::Optional<bool> hasRowMajorElement() const { return rowMajorElem; }
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Array; }
 
+  bool operator==(const ArrayType &that) const;
+
 private:
   const SpirvType *elementType;
   uint32_t elementCount;
+  // Two arrays types with different ArrayStride decorations, are in fact two
+  // different array types. In general, the combination of element type and
+  // element count is enough to determine the array stride. However, in the case
+  // of arrays of matrices, we also need to know the majorness of the matrices.
+  // An array of 5 row_major 2x3 matrices is a different type from
+  // an array of 5 col_major 2x3 matrices.
+  llvm::Optional<bool> rowMajorElem;
 };
 
 class RuntimeArrayType : public SpirvType {
@@ -275,9 +280,10 @@ public:
   public:
     FieldInfo(const SpirvType *type_, llvm::StringRef name_ = "",
               clang::VKOffsetAttr *offset = nullptr,
-              hlsl::ConstantPacking *packOffset = nullptr)
+              hlsl::ConstantPacking *packOffset = nullptr,
+              llvm::Optional<bool> rowMajor = llvm::None)
         : type(type_), name(name_), vkOffsetAttr(offset),
-          packOffsetAttr(packOffset) {}
+          packOffsetAttr(packOffset), isRowMajor(rowMajor) {}
 
     bool operator==(const FieldInfo &that) const;
 
@@ -289,6 +295,8 @@ public:
     clang::VKOffsetAttr *vkOffsetAttr;
     // :packoffset() annotations associated with this field.
     hlsl::ConstantPacking *packOffsetAttr;
+    // The majorness of this field (if it is a matrix).
+    llvm::Optional<bool> isRowMajor;
   };
 
   StructType(

+ 11 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -589,5 +589,16 @@ QualType getTypeWithCustomBitwidth(const ASTContext &ctx, QualType type,
       "invalid type or bitwidth passed to getTypeWithCustomBitwidth");
 }
 
+bool isMatrixOrArrayOfMatrix(const ASTContext &context, QualType type) {
+  if (isMxNMatrix(type)) {
+    return true;
+  }
+
+  if (const auto *arrayType = context.getAsArrayType(type))
+    return isMatrixOrArrayOfMatrix(context, arrayType->getElementType());
+
+  return false;
+}
+
 } // namespace spirv
 } // namespace clang

+ 4 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -709,7 +709,10 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
 
   // Make an array if requested.
   if (arraySize > 0) {
-    resultType = spvContext.getArrayType(resultType, arraySize);
+    // The array element is a structure, therefore no majorness information is
+    // needed.
+    resultType = spvContext.getArrayType(resultType, arraySize,
+                                         /*rowMajorElement*/ llvm::None);
   } else if (arraySize == -1) {
     resultType = spvContext.getRuntimeArrayType(resultType);
   }

+ 53 - 31
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1054,7 +1054,8 @@ bool EmitVisitor::visit(SpirvArrayLength *inst) {
   initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
-  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getStructure()));
+  curInst.push_back(
+      getOrAssignResultId<SpirvInstruction>(inst->getStructure()));
   curInst.push_back(inst->getArrayMember());
   finalizeInstruction();
   emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
@@ -1115,6 +1116,7 @@ uint32_t EmitTypeHandler::getOrCreateConstantUint32(uint32_t value) {
 
 void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
                                             SpirvLayoutRule rule,
+                                            llvm::Optional<bool> isRowMajor,
                                             DecorationList *decs) {
   // Array types
   if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
@@ -1124,7 +1126,8 @@ void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
     if (rule != SpirvLayoutRule::Void &&
         !isAKindOfStructuredOrByteBuffer(type)) {
       uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, &stride);
+      (void)getAlignmentAndSize(type, rule, &stride,
+                                arrayType->hasRowMajorElement());
       decs->push_back(DecorationInfo(spv::Decoration::ArrayStride, {stride}));
     }
   }
@@ -1135,7 +1138,7 @@ void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
     if (rule != SpirvLayoutRule::Void &&
         !isa<ImageType>(raType->getElementType())) {
       uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, &stride);
+      (void)getAlignmentAndSize(type, rule, &stride, isRowMajor);
       decs->push_back(DecorationInfo(spv::Decoration::ArrayStride, {stride}));
     }
   }
@@ -1164,12 +1167,12 @@ void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
   // structure types.
 }
 
-uint32_t EmitTypeHandler::emitType(const SpirvType *type,
-                                   SpirvLayoutRule rule) {
+uint32_t EmitTypeHandler::emitType(const SpirvType *type, SpirvLayoutRule rule,
+                                   llvm::Optional<bool> isRowMajor) {
   // First get the decorations that would apply to this type.
   bool alreadyExists = false;
   DecorationList decs;
-  getDecorationsForType(type, rule, &decs);
+  getDecorationsForType(type, rule, isRowMajor, &decs);
   const uint32_t id = getResultIdForType(type, decs, &alreadyExists);
 
   // If the type has already been emitted, we just need to return its
@@ -1208,7 +1211,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   }
   // Vector types
   else if (const auto *vecType = dyn_cast<VectorType>(type)) {
-    const uint32_t elementTypeId = emitType(vecType->getElementType(), rule);
+    const uint32_t elementTypeId =
+        emitType(vecType->getElementType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypeVector);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elementTypeId);
@@ -1217,7 +1221,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   }
   // Matrix types
   else if (const auto *matType = dyn_cast<MatrixType>(type)) {
-    const uint32_t vecTypeId = emitType(matType->getVecType(), rule);
+    const uint32_t vecTypeId =
+        emitType(matType->getVecType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypeMatrix);
     curTypeInst.push_back(id);
     curTypeInst.push_back(vecTypeId);
@@ -1228,7 +1233,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   }
   // Image types
   else if (const auto *imageType = dyn_cast<ImageType>(type)) {
-    const uint32_t sampledTypeId = emitType(imageType->getSampledType(), rule);
+    const uint32_t sampledTypeId =
+        emitType(imageType->getSampledType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypeImage);
     curTypeInst.push_back(id);
     curTypeInst.push_back(sampledTypeId);
@@ -1249,7 +1255,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   // SampledImage types
   else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
     const uint32_t imageTypeId =
-        emitType(sampledImageType->getImageType(), rule);
+        emitType(sampledImageType->getImageType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypeSampledImage);
     curTypeInst.push_back(id);
     curTypeInst.push_back(imageTypeId);
@@ -1262,7 +1268,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     const auto length = getOrCreateConstantUint32(arrayType->getElementCount());
 
     // Emit the OpTypeArray instruction
-    const uint32_t elemTypeId = emitType(arrayType->getElementType(), rule);
+    const uint32_t elemTypeId =
+        emitType(arrayType->getElementType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypeArray);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
@@ -1271,7 +1278,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   }
   // RuntimeArray types
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
-    const uint32_t elemTypeId = emitType(raType->getElementType(), rule);
+    const uint32_t elemTypeId =
+        emitType(raType->getElementType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypeRuntimeArray);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
@@ -1288,7 +1296,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
 
     llvm::SmallVector<uint32_t, 4> fieldTypeIds;
     for (auto &field : fields)
-      fieldTypeIds.push_back(emitType(field.type, rule));
+      fieldTypeIds.push_back(emitType(field.type, rule, field.isRowMajor));
     initTypeInstruction(spv::Op::OpTypeStruct);
     curTypeInst.push_back(id);
     for (auto fieldTypeId : fieldTypeIds)
@@ -1297,7 +1305,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   }
   // Pointer types
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
-    const uint32_t pointeeType = emitType(ptrType->getPointeeType(), rule);
+    const uint32_t pointeeType =
+        emitType(ptrType->getPointeeType(), rule, isRowMajor);
     initTypeInstruction(spv::Op::OpTypePointer);
     curTypeInst.push_back(id);
     curTypeInst.push_back(static_cast<uint32_t>(ptrType->getStorageClass()));
@@ -1306,10 +1315,11 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   }
   // Function types
   else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
-    const uint32_t retTypeId = emitType(fnType->getReturnType(), rule);
+    const uint32_t retTypeId =
+        emitType(fnType->getReturnType(), rule, isRowMajor);
     llvm::SmallVector<uint32_t, 4> paramTypeIds;
     for (auto *paramType : fnType->getParamTypes())
-      paramTypeIds.push_back(emitType(paramType, rule));
+      paramTypeIds.push_back(emitType(paramType, rule, isRowMajor));
 
     initTypeInstruction(spv::Op::OpTypeFunction);
     curTypeInst.push_back(id);
@@ -1340,7 +1350,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
 
 std::pair<uint32_t, uint32_t>
 EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
-                                     SpirvLayoutRule rule, uint32_t *stride) {
+                                     SpirvLayoutRule rule, uint32_t *stride,
+                                     llvm::Optional<bool> isRowMajor) {
   // std140 layout rules:
 
   // 1. If the member is a scalar consuming N basic machine units, the base
@@ -1436,7 +1447,8 @@ EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
       uint32_t alignment = 0, size = 0;
       uint32_t elemCount = vecType->getElementCount();
       const SpirvType *elemType = vecType->getElementType();
-      std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+      std::tie(alignment, size) =
+          getAlignmentAndSize(elemType, rule, stride, isRowMajor);
       // Use element alignment for fxc rules
       if (rule != SpirvLayoutRule::FxcCTBuffer &&
           rule != SpirvLayoutRule::FxcSBuffer)
@@ -1452,15 +1464,16 @@ EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
       uint32_t rowCount = matType->numRows();
       uint32_t colCount = matType->numCols();
       uint32_t alignment = 0, size = 0;
-      std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+      std::tie(alignment, size) =
+          getAlignmentAndSize(elemType, rule, stride, isRowMajor);
 
       // Matrices are treated as arrays of vectors:
       // The base alignment and array stride are set to match the base alignment
       // of a single array element, according to rules 1, 2, and 3, and rounded
       // up to the base alignment of a vec4.
-      bool isRowMajor = matType->isRowMajorMat();
-
-      const uint32_t vecStorageSize = isRowMajor ? colCount : rowCount;
+      assert(isRowMajor.hasValue());
+      bool rowMajor = isRowMajor.getValue();
+      const uint32_t vecStorageSize = rowMajor ? colCount : rowCount;
 
       if (rule == SpirvLayoutRule::FxcSBuffer) {
         *stride = vecStorageSize * size;
@@ -1475,7 +1488,7 @@ EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
         alignment = roundToPow2(alignment, kStd140Vec4Alignment);
       }
       *stride = alignment;
-      size = (isRowMajor ? rowCount : colCount) * alignment;
+      size = (rowMajor ? rowCount : colCount) * alignment;
 
       return {alignment, size};
     }
@@ -1494,7 +1507,7 @@ EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
     for (auto &field : structType->getFields()) {
       uint32_t memberAlignment = 0, memberSize = 0;
       std::tie(memberAlignment, memberSize) =
-          getAlignmentAndSize(field.type, rule, stride);
+          getAlignmentAndSize(field.type, rule, stride, field.isRowMajor);
 
       if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
           rule == SpirvLayoutRule::RelaxedGLSLStd430 ||
@@ -1538,6 +1551,8 @@ EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
   // Rule 4, 6, 8, and 10
   auto *arrayType = dyn_cast<ArrayType>(type);
   auto *raType = dyn_cast<RuntimeArrayType>(type);
+  if (arrayType)
+    isRowMajor = arrayType->hasRowMajorElement();
   if (arrayType || raType) {
     // Some exaplanation about runtime arrays:
     // The number of elements in a runtime array is unknown at compile time. As
@@ -1551,7 +1566,8 @@ EmitTypeHandler::getAlignmentAndSize(const SpirvType *type,
         arrayType ? arrayType->getElementType() : raType->getElementType();
 
     uint32_t alignment = 0, size = 0;
-    std::tie(alignment, size) = getAlignmentAndSize(elemType, rule, stride);
+    std::tie(alignment, size) =
+        getAlignmentAndSize(elemType, rule, stride, isRowMajor);
 
     if (rule == SpirvLayoutRule::FxcSBuffer) {
       *stride = size;
@@ -1630,7 +1646,7 @@ void EmitTypeHandler::getLayoutDecorations(const StructType *structType,
     const SpirvType *fieldType = field.type;
     uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
     std::tie(memberAlignment, memberSize) =
-        getAlignmentAndSize(fieldType, rule, &stride);
+        getAlignmentAndSize(fieldType, rule, &stride, field.isRowMajor);
 
     // The next avaiable location after laying out the previous members
     const uint32_t nextLoc = offset;
@@ -1684,15 +1700,21 @@ void EmitTypeHandler::getLayoutDecorations(const StructType *structType,
       if (isa<FloatType>(matType->getElementType())) {
         memberAlignment = memberSize = stride = 0;
         std::tie(memberAlignment, memberSize) =
-            getAlignmentAndSize(fieldType, rule, &stride);
+            getAlignmentAndSize(fieldType, rule, &stride, field.isRowMajor);
 
         decs->push_back(
             DecorationInfo(spv::Decoration::MatrixStride, {stride}, index));
 
-        if (matType->isRowMajorMat())
-          decs->push_back(DecorationInfo(spv::Decoration::RowMajor, {}, index));
-        else
-          decs->push_back(DecorationInfo(spv::Decoration::ColMajor, {}, index));
+        if (field.isRowMajor.hasValue()) {
+          if (field.isRowMajor.getValue()) {
+            decs->push_back(
+                DecorationInfo(spv::Decoration::RowMajor, {}, index));
+          } else {
+            decs->push_back(
+                DecorationInfo(spv::Decoration::ColMajor, {}, index));
+          }
+        }
+
       }
     }
 

+ 11 - 7
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -308,13 +308,15 @@ void GlPerVertex::calculateClipCullDistanceArraySize() {
 
 SpirvVariable *GlPerVertex::createClipCullDistanceVar(bool asInput, bool isClip,
                                                       uint32_t arraySize) {
-  const ArrayType *type =
-      spvContext.getArrayType(spvContext.getFloatType(32), arraySize);
+  const ArrayType *type = spvContext.getArrayType(
+      spvContext.getFloatType(32), arraySize, /*rowMajorElem*/ llvm::None);
 
   if (asInput && inArraySize != 0) {
-    type = spvContext.getArrayType(type, inArraySize);
+    type =
+        spvContext.getArrayType(type, inArraySize, /*rowMajorElem*/ llvm::None);
   } else if (!asInput && outArraySize != 0) {
-    type = spvContext.getArrayType(type, outArraySize);
+    type = spvContext.getArrayType(type, outArraySize,
+                                   /*rowMajorElem*/ llvm::None);
   }
 
   spv::StorageClass sc =
@@ -429,7 +431,8 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(bool isClip,
   const ArrayType *arrayType = nullptr;
 
   if (isScalarType(asType)) {
-    arrayType = spvContext.getArrayType(f32Type, inArraySize);
+    arrayType = spvContext.getArrayType(f32Type, inArraySize,
+                                        /*rowMajorElem*/ llvm::None);
     for (uint32_t i = 0; i < inArraySize; ++i) {
       auto *ptr = spvBuilder.createAccessChain(
           ptrType, clipCullVar,
@@ -438,8 +441,9 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(bool isClip,
       arrayElements.push_back(spvBuilder.createLoad(astContext.FloatTy, ptr));
     }
   } else if (isVectorType(asType, &elemType, &count)) {
-    arrayType = spvContext.getArrayType(
-        spvContext.getVectorType(f32Type, count), inArraySize);
+    arrayType =
+        spvContext.getArrayType(spvContext.getVectorType(f32Type, count),
+                                inArraySize, /*rowMajorElem*/ llvm::None);
 
     for (uint32_t i = 0; i < inArraySize; ++i) {
       // For each gl_PerVertex block, we need to read a vector from it.

+ 47 - 42
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -130,9 +130,10 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
     std::vector<StructType::FieldInfo> structFields;
     for (auto field : hybridStruct->getFields()) {
       const SpirvType *fieldSpirvType = lowerType(field.astType, rule, loc);
-      structFields.push_back(StructType::FieldInfo(fieldSpirvType, field.name,
-                                                   field.vkOffsetAttr,
-                                                   field.packOffsetAttr));
+      llvm::Optional<bool> isRowMajor = isRowMajorMatrix(field.astType);
+      structFields.push_back(
+          StructType::FieldInfo(fieldSpirvType, field.name, field.vkOffsetAttr,
+                                field.packOffsetAttr, isRowMajor));
     }
     return spvContext.getStructType(structFields, hybridStruct->getStructName(),
                                     hybridStruct->isReadOnly(),
@@ -165,7 +166,8 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
     // If array didn't contain any hybrid types, return itself.
     if (arrType->getElementType() == loweredElemType)
       return arrType;
-    return spvContext.getArrayType(loweredElemType, arrType->getElementCount());
+    return spvContext.getArrayType(loweredElemType, arrType->getElementCount(),
+                                   arrType->hasRowMajorElement());
   }
   // Runtime arrays could contain a hybrid type
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
@@ -185,9 +187,9 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
       const auto *loweredFieldType = lowerType(field.type, rule, loc);
       if (loweredFieldType != field.type) {
         wasLowered = true;
-        loweredFields.push_back(
-            StructType::FieldInfo(loweredFieldType, field.name,
-                                  field.vkOffsetAttr, field.packOffsetAttr));
+        loweredFields.push_back(StructType::FieldInfo(
+            loweredFieldType, field.name, field.vkOffsetAttr,
+            field.packOffsetAttr, field.isRowMajor));
       } else {
         loweredFields.push_back(field);
       }
@@ -242,11 +244,6 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
 
   if (desugaredType != type) {
     const auto *spvType = lowerType(desugaredType, rule, srcLoc);
-    // Clear matrix majorness potentially set by previous desugarType() calls.
-    // This field will only be set when we were saying a matrix type. And the
-    // above lowerType() call already takes the majorness into consideration.
-    // So should be fine to clear now.
-    typeMatMajorAttr = llvm::None;
     return spvType;
   }
 
@@ -348,15 +345,12 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
       // Non-float matrices are represented as an array of vectors.
       if (!elemType->isFloatingType()) {
         // This return type is ArrayType
-        return spvContext.getArrayType(vecType, rowCount);
+        // This is an array of vectors. No majorness information needed.
+        return spvContext.getArrayType(vecType, rowCount,
+                                       /*rowMajorElem*/ llvm::None);
       }
 
-      // HLSL matrices are conceptually row major, while SPIR-V matrices are
-      // conceptually column major. We are mapping what HLSL semantically mean
-      // a row into a column here.
-      const bool isSpirvRowMajor = !isRowMajorMatrix(type);
-
-      return spvContext.getMatrixType(vecType, rowCount, isSpirvRowMajor);
+      return spvContext.getMatrixType(vecType, rowCount);
     }
   }
 
@@ -385,7 +379,10 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
     // Create fields for all members of this struct
     for (const auto *field : decl->fields()) {
       const SpirvType *fieldType = lowerType(field->getType(), rule, srcLoc);
-      fields.push_back(StructType::FieldInfo(fieldType, field->getName()));
+      llvm::Optional<bool> isRowMajor = isRowMajorMatrix(field->getType());
+      fields.push_back(StructType::FieldInfo(
+          fieldType, field->getName(), /*vkoffset*/ nullptr,
+          /*packoffset*/ nullptr, isRowMajor));
     }
 
     return spvContext.getStructType(fields, decl->getName());
@@ -397,7 +394,8 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
 
     if (const auto *caType = astContext.getAsConstantArrayType(type)) {
       const auto size = static_cast<uint32_t>(caType->getSize().getZExtValue());
-      return spvContext.getArrayType(elemType, size);
+      llvm::Optional<bool> isRowMajor = isRowMajorMatrix(type);
+      return spvContext.getArrayType(elemType, size, isRowMajor);
     }
 
     assert(type->isIncompleteArrayType());
@@ -572,15 +570,17 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
   if (name == "InputPatch") {
     const auto elemType = hlsl::GetHLSLInputPatchElementType(type);
     const auto elemCount = hlsl::GetHLSLInputPatchCount(type);
-    return spvContext.getArrayType(lowerType(elemType, rule, srcLoc),
-                                   elemCount);
+    llvm::Optional<bool> isRowMajor = isRowMajorMatrix(type);
+    return spvContext.getArrayType(lowerType(elemType, rule, srcLoc), elemCount,
+                                   isRowMajor);
   }
   // OutputPatch
   if (name == "OutputPatch") {
     const auto elemType = hlsl::GetHLSLOutputPatchElementType(type);
     const auto elemCount = hlsl::GetHLSLOutputPatchCount(type);
+    llvm::Optional<bool> isRowMajor = isRowMajorMatrix(type);
     return spvContext.getArrayType(lowerType(elemType, rule, srcLoc),
-                                   elemCount);
+                                   elemCount, isRowMajor);
   }
   // Output stream objects (TriangleStream, LineStream, and PointStream)
   if (name == "TriangleStream" || name == "LineStream" ||
@@ -639,16 +639,6 @@ LowerTypeVisitor::translateSampledTypeToImageFormat(QualType sampledType,
 
 QualType LowerTypeVisitor::desugarType(QualType type) {
   if (const auto *attrType = type->getAs<AttributedType>()) {
-    switch (auto kind = attrType->getAttrKind()) {
-    case AttributedType::attr_hlsl_row_major:
-    case AttributedType::attr_hlsl_column_major:
-      typeMatMajorAttr = kind;
-      break;
-    default:
-      // We only need to update internal bookkeeping for matrix majorness.
-      break;
-    }
-
     return desugarType(
         attrType->getLocallyUnqualifiedSingleStepDesugaredType());
   }
@@ -660,24 +650,39 @@ QualType LowerTypeVisitor::desugarType(QualType type) {
   return type;
 }
 
-bool LowerTypeVisitor::isRowMajorMatrix(QualType type) const {
-  assert(isMxNMatrix(type));
-
-  // Use the majorness info we recorded before.
-  if (typeMatMajorAttr.hasValue()) {
-    switch (typeMatMajorAttr.getValue()) {
+llvm::Optional<bool>
+LowerTypeVisitor::isHLSLRowMajorMatrix(QualType type) const {
+  if (const auto *attrType = type->getAs<AttributedType>()) {
+    switch (auto kind = attrType->getAttrKind()) {
     case AttributedType::attr_hlsl_row_major:
       return true;
     case AttributedType::attr_hlsl_column_major:
       return false;
+      break;
     default:
-      // Only oriented matrices are relevant.
       break;
     }
   }
-
+  if (const auto *typedefType = type->getAs<TypedefType>()) {
+    return isHLSLRowMajorMatrix(typedefType->desugar());
+  }
+  if (const auto *arrayType = astContext.getAsArrayType(type)) {
+    return isHLSLRowMajorMatrix(arrayType->getElementType());
+  }
   return getCodeGenOptions().defaultRowMajor;
 }
 
+llvm::Optional<bool> LowerTypeVisitor::isRowMajorMatrix(QualType type) const {
+  // Row/Col majorness only applies to matrices or array of matrices.
+  if (!isMatrixOrArrayOfMatrix(astContext, type))
+    return llvm::None;
+
+  const auto hlslRowMajor = isHLSLRowMajorMatrix(type);
+  if (!hlslRowMajor.hasValue())
+    return hlslRowMajor;
+
+  return !hlslRowMajor.getValue();
+}
+
 } // namespace spirv
 } // namespace clang

+ 25 - 5
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -137,26 +137,32 @@ const VectorType *SpirvContext::getVectorType(const SpirvType *elemType,
   return vecTypes[scalarType][count] = new (this) VectorType(scalarType, count);
 }
 
-const MatrixType *SpirvContext::getMatrixType(const SpirvType *elemType,
-                                              uint32_t count, bool isRowMajor) {
+const SpirvType *SpirvContext::getMatrixType(const SpirvType *elemType,
+                                             uint32_t count) {
   // We are certain this should be a vector type. Otherwise, cast causes an
   // assertion failure.
   const VectorType *vecType = cast<VectorType>(elemType);
   assert(count == 2 || count == 3 || count == 4);
 
+  // In the case of non-floating-point matrices, we represent them as array of
+  // vectors.
+  if (!isa<FloatType>(vecType->getElementType())) {
+    return getArrayType(elemType, count, llvm::None);
+  }
+
   auto foundVec = matTypes.find(vecType);
 
   if (foundVec != matTypes.end()) {
     const auto &matVector = foundVec->second;
     // Create a temporary object for finding in the vector.
-    MatrixType type(vecType, count, isRowMajor);
+    MatrixType type(vecType, count);
 
     for (const auto *cachedType : matVector)
       if (type == *cachedType)
         return cachedType;
   }
 
-  const auto *ptr = new (this) MatrixType(vecType, count, isRowMajor);
+  const auto *ptr = new (this) MatrixType(vecType, count);
 
   matTypes[vecType].push_back(ptr);
 
@@ -211,7 +217,20 @@ SpirvContext::getSampledImageType(QualType image) {
 }
 
 const ArrayType *SpirvContext::getArrayType(const SpirvType *elemType,
-                                            uint32_t elemCount) {
+                                            uint32_t elemCount,
+                                            llvm::Optional<bool> rowMajorElem) {
+  ArrayType type(elemType, elemCount, rowMajorElem);
+  auto found = std::find_if(
+      arrayTypes.begin(), arrayTypes.end(),
+      [&type](const ArrayType *cachedType) { return type == *cachedType; });
+
+  if (found != arrayTypes.end())
+    return *found;
+
+  arrayTypes.push_back(new (this) ArrayType(elemType, elemCount, rowMajorElem));
+  return arrayTypes.back();
+
+  /*
   auto foundElemType = arrayTypes.find(elemType);
 
   if (foundElemType != arrayTypes.end()) {
@@ -224,6 +243,7 @@ const ArrayType *SpirvContext::getArrayType(const SpirvType *elemType,
 
   return arrayTypes[elemType][elemCount] =
              new (this) ArrayType(elemType, elemCount);
+  */
 }
 
 const RuntimeArrayType *

+ 5 - 4
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -912,10 +912,9 @@ SpirvInstruction *SPIRVEmitter::loadIfGLValue(const Expr *expr,
             astContext.getExtVectorType(uintType, numCols);
         const auto boolRowQualType =
             astContext.getExtVectorType(boolType, numCols);
-        // TODO(ehsan): Verify the isRowMajor argument.
         const SpirvType *resultType = spvContext.getMatrixType(
             spvContext.getVectorType(spvContext.getBoolType(), numCols),
-            numRows, /*isRowMajor*/ false);
+            numRows);
 
         llvm::SmallVector<SpirvInstruction *, 4> rows;
         for (uint32_t i = 0; i < numRows; ++i) {
@@ -4467,7 +4466,8 @@ SPIRVEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
       indices.push_back(col);
 
     if (baseExpr->isGLValue()) {
-      llvm::SmallVector<SpirvInstruction *, 2> indexInstructions;
+      llvm::SmallVector<SpirvInstruction *, 2> indexInstructions(indices.size(),
+                                                                 nullptr);
       for (uint32_t i = 0; i < indices.size(); ++i)
         indexInstructions[i] = spvBuilder.getConstantInt32(indices[i]);
 
@@ -5776,7 +5776,6 @@ SPIRVEmitter::tryToAssignToMatrixElements(const Expr *lhs,
     accessor.GetPosition(i, &row, &col);
 
     llvm::SmallVector<uint32_t, 2> indices;
-    llvm::SmallVector<SpirvInstruction *, 2> indexInstructions;
     // If the matrix only have one row/column, we are indexing into a vector
     // then. Only one index is needed for such cases.
     if (rowCount > 1)
@@ -5784,6 +5783,8 @@ SPIRVEmitter::tryToAssignToMatrixElements(const Expr *lhs,
     if (colCount > 1)
       indices.push_back(col);
 
+    llvm::SmallVector<SpirvInstruction *, 2> indexInstructions(indices.size(),
+                                                               nullptr);
     for (uint32_t i = 0; i < indices.size(); ++i)
       indexInstructions[i] = spvBuilder.getConstantInt32(indices[i]);
 

+ 8 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -126,6 +126,9 @@ SpirvComposite *SpirvBuilder::createCompositeConstruct(
   auto *instruction =
       new (context) SpirvComposite(resultType, /*id*/ 0, loc, constituents);
   insertPoint->addInstruction(instruction);
+  if (!constituents.empty()) {
+    instruction->setLayoutRule(constituents[0]->getLayoutRule());
+  }
   return instruction;
 }
 
@@ -136,6 +139,9 @@ SpirvComposite *SpirvBuilder::createCompositeConstruct(
   auto *instruction = new (context)
       SpirvComposite(/*QualType*/ {}, /*id*/ 0, loc, constituents);
   instruction->setResultType(resultType);
+  if (!constituents.empty()) {
+    instruction->setLayoutRule(constituents[0]->getLayoutRule());
+  }
   insertPoint->addInstruction(instruction);
   return instruction;
 }
@@ -181,6 +187,7 @@ SpirvLoad *SpirvBuilder::createLoad(QualType resultType,
       new (context) SpirvLoad(resultType, /*id*/ 0, loc, pointer);
   instruction->setStorageClass(pointer->getStorageClass());
   instruction->setRValue();
+  instruction->setLayoutRule(pointer->getLayoutRule());
   insertPoint->addInstruction(instruction);
   return instruction;
 }
@@ -194,6 +201,7 @@ SpirvLoad *SpirvBuilder::createLoad(const SpirvType *resultType,
   instruction->setResultType(resultType);
   instruction->setStorageClass(pointer->getStorageClass());
   instruction->setRValue();
+  instruction->setLayoutRule(pointer->getLayoutRule());
   insertPoint->addInstruction(instruction);
   return instruction;
 }

+ 23 - 7
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -128,14 +128,20 @@ bool SpirvType::isOrContains16BitType(const SpirvType *type) {
   return false;
 }
 
-MatrixType::MatrixType(const VectorType *vecType, uint32_t vecCount,
-                       bool rowMajor)
-    : SpirvType(TK_Matrix), vectorType(vecType), vectorCount(vecCount),
-      isRowMajor(rowMajor) {}
+bool SpirvType::isMatrixOrArrayOfMatrix(const SpirvType *type) {
+  if (isa<MatrixType>(type))
+    return true;
+  if (const auto *arrayType = dyn_cast<ArrayType>(type))
+    return isMatrixOrArrayOfMatrix(arrayType->getElementType());
+
+  return false;
+}
+
+MatrixType::MatrixType(const VectorType *vecType, uint32_t vecCount)
+    : SpirvType(TK_Matrix), vectorType(vecType), vectorCount(vecCount) {}
 
 bool MatrixType::operator==(const MatrixType &that) const {
-  return vectorType == that.vectorType && vectorCount == that.vectorCount &&
-         isRowMajor == that.isRowMajor;
+  return vectorType == that.vectorType && vectorCount == that.vectorCount;
 }
 
 ImageType::ImageType(const NumericalType *type, spv::Dim dim, WithDepth depth,
@@ -183,6 +189,13 @@ bool ImageType::operator==(const ImageType &that) const {
          isSampled == that.isSampled && imageFormat == that.imageFormat;
 }
 
+bool ArrayType::operator==(const ArrayType &that) const {
+  return elementType == that.elementType && elementCount == that.elementCount &&
+         rowMajorElem.hasValue() == that.rowMajorElem.hasValue() &&
+         (!rowMajorElem.hasValue() ||
+          rowMajorElem.getValue() == that.rowMajorElem.getValue());
+}
+
 StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
                        llvm::StringRef name, bool isReadOnly,
                        StructInterfaceType iface)
@@ -192,7 +205,10 @@ StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
 bool StructType::FieldInfo::
 operator==(const StructType::FieldInfo &that) const {
   return type == that.type && vkOffsetAttr == that.vkOffsetAttr &&
-         packOffsetAttr == that.packOffsetAttr;
+         packOffsetAttr == that.packOffsetAttr &&
+         isRowMajor.hasValue() == that.isRowMajor.hasValue() &&
+         (!isRowMajor.hasValue() ||
+          isRowMajor.getValue() == that.isRowMajor.getValue());
 }
 
 bool StructType::operator==(const StructType &that) const {

+ 3 - 3
tools/clang/test/CodeGenSPIRV/vk.layout.cbuffer.boolean.hlsl

@@ -27,13 +27,13 @@ cbuffer CONSTANTS
   FrameConstants frameConstants;
 };
 
-// CHECK: [[v3uint0:%\d+]] = OpConstantComposite %v3uint %uint_0 %uint_0 %uint_0
-// CHECK: [[v2uint0:%\d+]] = OpConstantComposite %v2uint %uint_0 %uint_0
-
 // These are the types that hold SPIR-V booleans, rather than Uints.
 // CHECK:              %T_0 = OpTypeStruct %_arr_bool_uint_1
 // CHECK: %FrameConstants_0 = OpTypeStruct %bool %v3bool %_arr_v3bool_uint_2 %T_0
 
+// CHECK: [[v3uint0:%\d+]] = OpConstantComposite %v3uint %uint_0 %uint_0 %uint_0
+// CHECK: [[v2uint0:%\d+]] = OpConstantComposite %v2uint %uint_0 %uint_0
+
 float4 main(in float4 texcoords : TEXCOORD0) : SV_TARGET
 {
 // CHECK:      [[FrameConstants:%\d+]] = OpAccessChain %_ptr_Uniform_FrameConstants %CONSTANTS %int_0

+ 3 - 3
tools/clang/unittests/SPIRV/SpirvConstantTest.cpp

@@ -158,7 +158,7 @@ TEST(SpirvConstant, CheckOperatorEqualOnComposite) {
   SpirvContext ctx;
   const FloatType *f32Type = ctx.getFloatType(32);
   const VectorType *vecType = ctx.getVectorType(f32Type, 4);
-  const ArrayType *arrType = ctx.getArrayType(vecType, 2);
+  const ArrayType *arrType = ctx.getArrayType(vecType, 2, llvm::None);
   SpirvConstantFloat f1(f32Type, 3.14);
   SpirvConstantFloat f2(f32Type, 5.f);
   SpirvConstantFloat f3(f32Type, -1.f);
@@ -177,7 +177,7 @@ TEST(SpirvConstant, CheckOperatorEqualOnComposite2) {
   SpirvContext ctx;
   const FloatType *f32Type = ctx.getFloatType(32);
   const VectorType *vecType = ctx.getVectorType(f32Type, 4);
-  const ArrayType *arrType = ctx.getArrayType(vecType, 1);
+  const ArrayType *arrType = ctx.getArrayType(vecType, 1, llvm::None);
   SpirvConstantFloat f1(f32Type, 3.14);
   SpirvConstantFloat f2(f32Type, 5.f);
   SpirvConstantFloat f3(f32Type, -1.f);
@@ -221,7 +221,7 @@ TEST(SpirvConstant, CompositeConstNotEqualSpecConstComposite) {
   SpirvContext ctx;
   const FloatType *f32Type = ctx.getFloatType(32);
   const VectorType *vecType = ctx.getVectorType(f32Type, 4);
-  const ArrayType *arrType = ctx.getArrayType(vecType, 2);
+  const ArrayType *arrType = ctx.getArrayType(vecType, 2, llvm::None);
   SpirvConstantFloat f1(f32Type, 3.14);
   SpirvConstantFloat f2(f32Type, 5.f);
   SpirvConstantFloat f3(f32Type, -1.f);