浏览代码

[spirv] Revamp matrix majorness handling (#1144)

Previously we use an additional parameter to translateType()
to convey the matrix majorness info. It causes lots of type
inconsistency issue.

Now the majorness info is queried directly from the QualType.
This should be more robust.
Lei Zhang 7 年之前
父节点
当前提交
5d557ccb34

+ 6 - 13
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -304,7 +304,7 @@ SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
           cast<VarDecl>(decl)->getType(),
           cast<VarDecl>(decl)->getType(),
           // We need to set decorateLayout here to avoid creating SPIR-V
           // We need to set decorateLayout here to avoid creating SPIR-V
           // instructions for the current type without decorations.
           // instructions for the current type without decorations.
-          info->info.getLayoutRule(), info->info.isRowMajor());
+          info->info.getLayoutRule());
 
 
       const uint32_t elemId = theBuilder.createAccessChain(
       const uint32_t elemId = theBuilder.createAccessChain(
           theBuilder.getPointerType(varType, info->info.getStorageClass()),
           theBuilder.getPointerType(varType, info->info.getStorageClass()),
@@ -452,10 +452,10 @@ uint32_t DeclResultIdMapper::getMatrixStructType(const VarDecl *matVar,
 
 
   auto &context = *theBuilder.getSPIRVContext();
   auto &context = *theBuilder.getSPIRVContext();
   llvm::SmallVector<const Decoration *, 4> decorations;
   llvm::SmallVector<const Decoration *, 4> decorations;
-  const bool isRowMajor = typeTranslator.isRowMajorMatrix(matType, matVar);
+  const bool isRowMajor = typeTranslator.isRowMajorMatrix(matType);
 
 
   uint32_t stride;
   uint32_t stride;
-  (void)typeTranslator.getAlignmentAndSize(matType, rule, isRowMajor, &stride);
+  (void)typeTranslator.getAlignmentAndSize(matType, rule, &stride);
   decorations.push_back(Decoration::getOffset(context, 0, 0));
   decorations.push_back(Decoration::getOffset(context, 0, 0));
   decorations.push_back(Decoration::getMatrixStride(context, stride, 0));
   decorations.push_back(Decoration::getMatrixStride(context, stride, 0));
   decorations.push_back(isRowMajor ? Decoration::getColMajor(context, 0)
   decorations.push_back(isRowMajor ? Decoration::getColMajor(context, 0)
@@ -521,9 +521,7 @@ uint32_t DeclResultIdMapper::createVarOfExplicitLayoutStruct(
     auto varType = declDecl->getType();
     auto varType = declDecl->getType();
     varType.removeLocalConst();
     varType.removeLocalConst();
 
 
-    const bool isRowMajor = typeTranslator.isRowMajorMatrix(varType, declDecl);
-    fieldTypes.push_back(
-        typeTranslator.translateType(varType, layoutRule, isRowMajor));
+    fieldTypes.push_back(typeTranslator.translateType(varType, layoutRule));
     fieldNames.push_back(declDecl->getName());
     fieldNames.push_back(declDecl->getName());
 
 
     // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate
     // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate
@@ -570,14 +568,11 @@ uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
       continue;
       continue;
 
 
     const auto *varDecl = cast<VarDecl>(subDecl);
     const auto *varDecl = cast<VarDecl>(subDecl);
-    const bool isRowMajor =
-        typeTranslator.isRowMajorMatrix(varDecl->getType(), varDecl);
     astDecls[varDecl] =
     astDecls[varDecl] =
         SpirvEvalInfo(bufferVar)
         SpirvEvalInfo(bufferVar)
             .setStorageClass(spv::StorageClass::Uniform)
             .setStorageClass(spv::StorageClass::Uniform)
             .setLayoutRule(decl->isCBuffer() ? LayoutRule::GLSLStd140
             .setLayoutRule(decl->isCBuffer() ? LayoutRule::GLSLStd140
-                                             : LayoutRule::GLSLStd430)
-            .setRowMajor(isRowMajor);
+                                             : LayoutRule::GLSLStd430);
     astDecls[varDecl].indexInCTBuffer = index++;
     astDecls[varDecl].indexInCTBuffer = index++;
   }
   }
   resourceVars.emplace_back(
   resourceVars.emplace_back(
@@ -664,9 +659,7 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
 
 
       astDecls[varDecl] = SpirvEvalInfo(globals)
       astDecls[varDecl] = SpirvEvalInfo(globals)
                               .setStorageClass(spv::StorageClass::Uniform)
                               .setStorageClass(spv::StorageClass::Uniform)
-                              .setLayoutRule(LayoutRule::GLSLStd140)
-                              .setRowMajor(typeTranslator.isRowMajorMatrix(
-                                  varDecl->getType(), varDecl));
+                              .setLayoutRule(LayoutRule::GLSLStd140);
       astDecls[varDecl].indexInCTBuffer = index++;
       astDecls[varDecl].indexInCTBuffer = index++;
     }
     }
 }
 }

+ 6 - 8
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -825,8 +825,8 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
   if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
   if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
     valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
     valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
   } else {
   } else {
-    valType = typeTranslator.translateType(
-        expr->getType(), info.getLayoutRule(), info.isRowMajor());
+    valType =
+        typeTranslator.translateType(expr->getType(), info.getLayoutRule());
   }
   }
   return info.setResultId(theBuilder.createLoad(valType, info)).setRValue();
   return info.setResultId(theBuilder.createLoad(valType, info)).setRValue();
 }
 }
@@ -2551,7 +2551,7 @@ uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
     // size of the struct) must also be written to the second argument.
     // size of the struct) must also be written to the second argument.
     uint32_t size = 0, stride = 0;
     uint32_t size = 0, stride = 0;
     std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize(
     std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize(
-        type, LayoutRule::GLSLStd430, /*isRowMajor*/ false, &stride);
+        type, LayoutRule::GLSLStd430, &stride);
     const auto sizeId = theBuilder.getConstantUint32(size);
     const auto sizeId = theBuilder.getConstantUint32(size);
     theBuilder.createStore(doExpr(expr->getArg(1)), sizeId);
     theBuilder.createStore(doExpr(expr->getArg(1)), sizeId);
   }
   }
@@ -4657,15 +4657,13 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
   } else if (const auto *recordType = lhsValType->getAs<RecordType>()) {
   } else if (const auto *recordType = lhsValType->getAs<RecordType>()) {
     uint32_t index = 0;
     uint32_t index = 0;
     for (const auto *field : recordType->getDecl()->fields()) {
     for (const auto *field : recordType->getDecl()->fields()) {
-      bool isRowMajor =
-          typeTranslator.isRowMajorMatrix(field->getType(), field);
       const auto subRhsValType = typeTranslator.translateType(
       const auto subRhsValType = typeTranslator.translateType(
-          field->getType(), rhsVal.getLayoutRule(), isRowMajor);
+          field->getType(), rhsVal.getLayoutRule());
       const auto subRhsVal =
       const auto subRhsVal =
           theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
           theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
       const auto subLhsPtrType = theBuilder.getPointerType(
       const auto subLhsPtrType = theBuilder.getPointerType(
-          typeTranslator.translateType(field->getType(), lhsPtr.getLayoutRule(),
-                                       isRowMajor),
+          typeTranslator.translateType(field->getType(),
+                                       lhsPtr.getLayoutRule()),
           lhsPtr.getStorageClass());
           lhsPtr.getStorageClass());
       const auto subLhsPtr = theBuilder.createAccessChain(
       const auto subLhsPtr = theBuilder.createAccessChain(
           subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});
           subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});

+ 1 - 10
tools/clang/lib/SPIRV/SpirvEvalInfo.h

@@ -100,9 +100,6 @@ public:
   inline SpirvEvalInfo &setRelaxedPrecision();
   inline SpirvEvalInfo &setRelaxedPrecision();
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
 
 
-  inline SpirvEvalInfo &setRowMajor(bool);
-  bool isRowMajor() const { return isRowMajor_; }
-
 private:
 private:
   uint32_t resultId;
   uint32_t resultId;
   /// Indicates whether this evaluation result contains alias variables
   /// Indicates whether this evaluation result contains alias variables
@@ -122,14 +119,13 @@ private:
   bool isConstant_;
   bool isConstant_;
   bool isSpecConstant_;
   bool isSpecConstant_;
   bool isRelaxedPrecision_;
   bool isRelaxedPrecision_;
-  bool isRowMajor_;
 };
 };
 
 
 SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
 SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
     : resultId(id), containsAlias(false),
     : resultId(id), containsAlias(false),
       storageClass(spv::StorageClass::Function), layoutRule(LayoutRule::Void),
       storageClass(spv::StorageClass::Function), layoutRule(LayoutRule::Void),
       isRValue_(false), isConstant_(false), isSpecConstant_(false),
       isRValue_(false), isConstant_(false), isSpecConstant_(false),
-      isRelaxedPrecision_(false), isRowMajor_(false) {}
+      isRelaxedPrecision_(false) {}
 
 
 SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) {
 SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) {
   resultId = id;
   resultId = id;
@@ -178,11 +174,6 @@ SpirvEvalInfo &SpirvEvalInfo::setRelaxedPrecision() {
   return *this;
   return *this;
 }
 }
 
 
-SpirvEvalInfo &SpirvEvalInfo::setRowMajor(bool rm) {
-  isRowMajor_ = rm;
-  return *this;
-}
-
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang
 
 

+ 88 - 56
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -41,6 +41,28 @@ bool improperStraddle(QualType type, int size, int offset) {
                     : offset % 16 != 0;
                     : offset % 16 != 0;
 }
 }
 
 
+// From https://github.com/Microsoft/DirectXShaderCompiler/pull/1032.
+// TODO: use that after it is landed.
+bool hasHLSLMatOrientation(QualType type, bool *pIsRowMajor) {
+  const AttributedType *AT = type->getAs<AttributedType>();
+  while (AT) {
+    AttributedType::Kind kind = AT->getAttrKind();
+    switch (kind) {
+    case AttributedType::attr_hlsl_row_major:
+      if (pIsRowMajor)
+        *pIsRowMajor = true;
+      return true;
+    case AttributedType::attr_hlsl_column_major:
+      if (pIsRowMajor)
+        *pIsRowMajor = false;
+      return true;
+    }
+    AT = AT->getLocallyUnqualifiedSingleStepDesugaredType()
+             ->getAs<AttributedType>();
+  }
+  return false;
+}
+
 } // anonymous namespace
 } // anonymous namespace
 
 
 bool TypeTranslator::isRelaxedPrecisionType(QualType type,
 bool TypeTranslator::isRelaxedPrecisionType(QualType type,
@@ -418,18 +440,14 @@ uint32_t TypeTranslator::getElementSpirvBitwidth(QualType type) {
   llvm_unreachable("invalid type passed to getElementSpirvBitwidth");
   llvm_unreachable("invalid type passed to getElementSpirvBitwidth");
 }
 }
 
 
-uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
-                                       bool isRowMajor) {
-  // We can only apply row_major to matrices or arrays of matrices.
-  // isRowMajor will be ignored for scalar and vector types.
-  if (isRowMajor)
-    assert(type->isScalarType() || type->isArrayType() ||
-           hlsl::IsHLSLVecMatType(type));
-
-  // Try to translate the canonical type first
-  const auto canonicalType = type.getCanonicalType();
-  if (canonicalType != type)
-    return translateType(canonicalType, rule, isRowMajor);
+uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule) {
+  const auto desugaredType = desugarType(type);
+  if (desugaredType != type) {
+    const auto id = translateType(desugaredType, rule);
+    // Clear potentially set matrix majorness info
+    typeMatMajorAttr = llvm::None;
+    return id;
+  }
 
 
   // Primitive types
   // Primitive types
   {
   {
@@ -475,10 +493,6 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
     }
     }
   }
   }
 
 
-  // Typedefs
-  if (const auto *typedefType = type->getAs<TypedefType>())
-    return translateType(typedefType->desugar(), rule, isRowMajor);
-
   // Reference types
   // Reference types
   if (const auto *refType = type->getAs<ReferenceType>()) {
   if (const auto *refType = type->getAs<ReferenceType>()) {
     // Note: Pointer/reference types are disallowed in HLSL source code.
     // Note: Pointer/reference types are disallowed in HLSL source code.
@@ -487,13 +501,13 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
     // We already pass function arguments via pointers to tempoary local
     // We already pass function arguments via pointers to tempoary local
     // variables. So it should be fine to drop the pointer type and treat it
     // variables. So it should be fine to drop the pointer type and treat it
     // as the underlying pointee type here.
     // as the underlying pointee type here.
-    return translateType(refType->getPointeeType(), rule, isRowMajor);
+    return translateType(refType->getPointeeType(), rule);
   }
   }
 
 
   // Pointer types
   // Pointer types
   if (const auto *ptrType = type->getAs<PointerType>()) {
   if (const auto *ptrType = type->getAs<PointerType>()) {
     // The this object in a struct member function is of pointer type.
     // The this object in a struct member function is of pointer type.
-    return translateType(ptrType->getPointeeType(), rule, isRowMajor);
+    return translateType(ptrType->getPointeeType(), rule);
   }
   }
 
 
   // In AST, vector/matrix types are TypedefType of TemplateSpecializationType.
   // In AST, vector/matrix types are TypedefType of TemplateSpecializationType.
@@ -522,7 +536,7 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
       llvm::SmallVector<const Decoration *, 4> decorations;
       llvm::SmallVector<const Decoration *, 4> decorations;
       if (!elemType->isFloatingType() && rule != LayoutRule::Void) {
       if (!elemType->isFloatingType() && rule != LayoutRule::Void) {
         uint32_t stride = 0;
         uint32_t stride = 0;
-        (void)getAlignmentAndSize(type, rule, isRowMajor, &stride);
+        (void)getAlignmentAndSize(type, rule, &stride);
         decorations.push_back(
         decorations.push_back(
             Decoration::getArrayStride(*theBuilder.getSPIRVContext(), stride));
             Decoration::getArrayStride(*theBuilder.getSPIRVContext(), stride));
       }
       }
@@ -556,8 +570,7 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
 
 
     // Create fields for all members of this struct
     // Create fields for all members of this struct
     for (const auto *field : decl->fields()) {
     for (const auto *field : decl->fields()) {
-      fieldTypes.push_back(translateType(
-          field->getType(), rule, isRowMajorMatrix(field->getType(), field)));
+      fieldTypes.push_back(translateType(field->getType(), rule));
       fieldNames.push_back(field->getName());
       fieldNames.push_back(field->getName());
     }
     }
 
 
@@ -571,8 +584,7 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
   }
   }
 
 
   if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
   if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
-    const uint32_t elemType =
-        translateType(arrayType->getElementType(), rule, isRowMajor);
+    const uint32_t elemType = translateType(arrayType->getElementType(), rule);
     // TODO: handle extra large array size?
     // TODO: handle extra large array size?
     const auto size =
     const auto size =
         static_cast<uint32_t>(arrayType->getSize().getZExtValue());
         static_cast<uint32_t>(arrayType->getSize().getZExtValue());
@@ -580,7 +592,7 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
     llvm::SmallVector<const Decoration *, 4> decorations;
     llvm::SmallVector<const Decoration *, 4> decorations;
     if (rule != LayoutRule::Void) {
     if (rule != LayoutRule::Void) {
       uint32_t stride = 0;
       uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, isRowMajor, &stride);
+      (void)getAlignmentAndSize(type, rule, &stride);
       decorations.push_back(
       decorations.push_back(
           Decoration::getArrayStride(*theBuilder.getSPIRVContext(), stride));
           Decoration::getArrayStride(*theBuilder.getSPIRVContext(), stride));
     }
     }
@@ -964,19 +976,23 @@ bool TypeTranslator::isResourceType(const ValueDecl *decl) {
   return hlsl::IsHLSLResourceType(declType);
   return hlsl::IsHLSLResourceType(declType);
 }
 }
 
 
-bool TypeTranslator::isRowMajorMatrix(QualType type, const Decl *decl) const {
-  if (!isMxNMatrix(type) && !type->isArrayType())
-    return false;
+bool TypeTranslator::isRowMajorMatrix(QualType type) const {
+  // The type passed in may not be desugared. Check attributes on itself first.
+  bool attrRowMajor = false;
+  if (hasHLSLMatOrientation(type, &attrRowMajor))
+    return attrRowMajor;
 
 
-  if (const auto *arrayType = astContext.getAsConstantArrayType(type))
-    if (!isMxNMatrix(arrayType->getElementType()))
+  // Use the majorness info we recorded before.
+  if (typeMatMajorAttr.hasValue()) {
+    switch (typeMatMajorAttr.getValue()) {
+    case AttributedType::attr_hlsl_row_major:
+      return true;
+    case AttributedType::attr_hlsl_column_major:
       return false;
       return false;
+    }
+  }
 
 
-  if (!decl)
-    return spirvOptions.defaultRowMajor;
-
-  return decl->hasAttr<HLSLRowMajorAttr>() ||
-         !decl->hasAttr<HLSLColumnMajorAttr>() && spirvOptions.defaultRowMajor;
+  return spirvOptions.defaultRowMajor;
 }
 }
 
 
 bool TypeTranslator::canTreatAsSameScalarType(QualType type1, QualType type2) {
 bool TypeTranslator::canTreatAsSameScalarType(QualType type1, QualType type2) {
@@ -1111,11 +1127,9 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule,
         (!declDecl->hasExternalFormalLinkage() || isResourceType(declDecl)))
         (!declDecl->hasExternalFormalLinkage() || isResourceType(declDecl)))
       continue;
       continue;
 
 
-    const bool isRowMajor = isRowMajorMatrix(fieldType, field);
-
     uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
     uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
     std::tie(memberAlignment, memberSize) =
     std::tie(memberAlignment, memberSize) =
-        getAlignmentAndSize(fieldType, rule, isRowMajor, &stride);
+        getAlignmentAndSize(fieldType, rule, &stride);
 
 
     alignUsingHLSLRelaxedLayout(fieldType, memberSize, &memberAlignment,
     alignUsingHLSLRelaxedLayout(fieldType, memberSize, &memberAlignment,
                                 &offset);
                                 &offset);
@@ -1144,7 +1158,7 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule,
     if (isMxNMatrix(fieldType, &elemType) && elemType->isFloatingType()) {
     if (isMxNMatrix(fieldType, &elemType) && elemType->isFloatingType()) {
       memberAlignment = memberSize = stride = 0;
       memberAlignment = memberSize = stride = 0;
       std::tie(memberAlignment, memberSize) =
       std::tie(memberAlignment, memberSize) =
-          getAlignmentAndSize(fieldType, rule, isRowMajor, &stride);
+          getAlignmentAndSize(fieldType, rule, &stride);
 
 
       decorations.push_back(
       decorations.push_back(
           Decoration::getMatrixStride(*spirvContext, stride, index));
           Decoration::getMatrixStride(*spirvContext, stride, index));
@@ -1152,7 +1166,7 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule,
       // We need to swap the RowMajor and ColMajor decorations since HLSL
       // We need to swap the RowMajor and ColMajor decorations since HLSL
       // matrices are conceptually row-major while SPIR-V are conceptually
       // matrices are conceptually row-major while SPIR-V are conceptually
       // column-major.
       // column-major.
-      if (isRowMajor) {
+      if (isRowMajorMatrix(fieldType)) {
         decorations.push_back(Decoration::getColMajor(*spirvContext, index));
         decorations.push_back(Decoration::getColMajor(*spirvContext, index));
       } else {
       } else {
         // If the source code has neither row_major nor column_major annotated,
         // If the source code has neither row_major nor column_major annotated,
@@ -1249,8 +1263,7 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) {
 
 
     // The stride for the runtime array is the size of S.
     // The stride for the runtime array is the size of S.
     uint32_t size = 0, stride = 0;
     uint32_t size = 0, stride = 0;
-    std::tie(std::ignore, size) =
-        getAlignmentAndSize(s, rule, isRowMajor, &stride);
+    std::tie(std::ignore, size) = getAlignmentAndSize(s, rule, &stride);
     decorations.push_back(Decoration::getArrayStride(context, size));
     decorations.push_back(Decoration::getArrayStride(context, size));
     const uint32_t raType =
     const uint32_t raType =
         theBuilder.getRuntimeArrayType(structType, decorations);
         theBuilder.getRuntimeArrayType(structType, decorations);
@@ -1384,7 +1397,7 @@ void TypeTranslator::alignUsingHLSLRelaxedLayout(QualType fieldType,
     if (fieldIsVecType = isVectorType(fieldType, &vecElemType)) {
     if (fieldIsVecType = isVectorType(fieldType, &vecElemType)) {
       uint32_t scalarAlignment = 0;
       uint32_t scalarAlignment = 0;
       std::tie(scalarAlignment, std::ignore) =
       std::tie(scalarAlignment, std::ignore) =
-          getAlignmentAndSize(vecElemType, LayoutRule::Void, false, nullptr);
+          getAlignmentAndSize(vecElemType, LayoutRule::Void, nullptr);
       if (scalarAlignment <= 4)
       if (scalarAlignment <= 4)
         *fieldAlignment = scalarAlignment;
         *fieldAlignment = scalarAlignment;
     }
     }
@@ -1403,7 +1416,7 @@ void TypeTranslator::alignUsingHLSLRelaxedLayout(QualType fieldType,
 
 
 std::pair<uint32_t, uint32_t>
 std::pair<uint32_t, uint32_t>
 TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
 TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
-                                    const bool isRowMajor, uint32_t *stride) {
+                                    uint32_t *stride) {
   // std140 layout rules:
   // std140 layout rules:
 
 
   // 1. If the member is a scalar consuming N basic machine units, the base
   // 1. If the member is a scalar consuming N basic machine units, the base
@@ -1451,13 +1464,14 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
   //
   //
   // 10. If the member is an array of S structures, the S elements of the array
   // 10. If the member is an array of S structures, the S elements of the array
   //     are laid out in order, according to rule (9).
   //     are laid out in order, according to rule (9).
-  const auto canonicalType = type.getCanonicalType();
-  if (canonicalType != type)
-    return getAlignmentAndSize(canonicalType, rule, isRowMajor, stride);
 
 
-  if (const auto *typedefType = type->getAs<TypedefType>())
-    return getAlignmentAndSize(typedefType->desugar(), rule, isRowMajor,
-                               stride);
+  const auto desugaredType = desugarType(type);
+  if (desugaredType != type) {
+    const auto id = getAlignmentAndSize(desugaredType, rule, stride);
+    // Clear potentially set matrix majorness info
+    typeMatMajorAttr = llvm::None;
+    return id;
+  }
 
 
   { // Rule 1
   { // Rule 1
     QualType ty = {};
     QualType ty = {};
@@ -1487,8 +1501,7 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
     uint32_t elemCount = {};
     uint32_t elemCount = {};
     if (isVectorType(type, &elemType, &elemCount)) {
     if (isVectorType(type, &elemType, &elemCount)) {
       uint32_t size = 0;
       uint32_t size = 0;
-      std::tie(std::ignore, size) =
-          getAlignmentAndSize(elemType, rule, isRowMajor, stride);
+      std::tie(std::ignore, size) = getAlignmentAndSize(elemType, rule, stride);
 
 
       return {(elemCount == 3 ? 4 : elemCount) * size, elemCount * size};
       return {(elemCount == 3 ? 4 : elemCount) * size, elemCount * size};
     }
     }
@@ -1500,12 +1513,14 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
     if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
     if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
       uint32_t alignment = 0, size = 0;
       uint32_t alignment = 0, size = 0;
       std::tie(alignment, std::ignore) =
       std::tie(alignment, std::ignore) =
-          getAlignmentAndSize(elemType, rule, isRowMajor, stride);
+          getAlignmentAndSize(elemType, rule, stride);
 
 
       // Matrices are treated as arrays of vectors:
       // Matrices are treated as arrays of vectors:
       // The base alignment and array stride are set to match the base alignment
       // The base alignment and array stride are set to match the base alignment
       // of a single array element, according to rules 1, 2, and 3, and rounded
       // of a single array element, according to rules 1, 2, and 3, and rounded
       // up to the base alignment of a vec4.
       // up to the base alignment of a vec4.
+      bool isRowMajor = isRowMajorMatrix(type);
+
       const uint32_t vecStorageSize = isRowMajor ? colCount : rowCount;
       const uint32_t vecStorageSize = isRowMajor ? colCount : rowCount;
       alignment *= (vecStorageSize == 3 ? 4 : vecStorageSize);
       alignment *= (vecStorageSize == 3 ? 4 : vecStorageSize);
       if (rule == LayoutRule::GLSLStd140) {
       if (rule == LayoutRule::GLSLStd140) {
@@ -1530,9 +1545,8 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
 
 
     for (const auto *field : structType->getDecl()->fields()) {
     for (const auto *field : structType->getDecl()->fields()) {
       uint32_t memberAlignment = 0, memberSize = 0;
       uint32_t memberAlignment = 0, memberSize = 0;
-      const bool isRowMajor = isRowMajorMatrix(field->getType(), field);
       std::tie(memberAlignment, memberSize) =
       std::tie(memberAlignment, memberSize) =
-          getAlignmentAndSize(field->getType(), rule, isRowMajor, stride);
+          getAlignmentAndSize(field->getType(), rule, stride);
 
 
       alignUsingHLSLRelaxedLayout(field->getType(), memberSize,
       alignUsingHLSLRelaxedLayout(field->getType(), memberSize,
                                   &memberAlignment, &structSize);
                                   &memberAlignment, &structSize);
@@ -1556,8 +1570,8 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
   // Rule 4, 6, 8, and 10
   // Rule 4, 6, 8, and 10
   if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
   if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
     uint32_t alignment = 0, size = 0;
     uint32_t alignment = 0, size = 0;
-    std::tie(alignment, size) = getAlignmentAndSize(arrayType->getElementType(),
-                                                    rule, isRowMajor, stride);
+    std::tie(alignment, size) =
+        getAlignmentAndSize(arrayType->getElementType(), rule, stride);
 
 
     if (rule == LayoutRule::GLSLStd140) {
     if (rule == LayoutRule::GLSLStd140) {
       // The base alignment and array stride are set to match the base alignment
       // The base alignment and array stride are set to match the base alignment
@@ -1623,5 +1637,23 @@ std::string TypeTranslator::getName(QualType type) {
   return "";
   return "";
 }
 }
 
 
+QualType TypeTranslator::desugarType(QualType type) {
+  if (const auto *attrType = type->getAs<AttributedType>()) {
+    switch (auto kind = attrType->getAttrKind()) {
+    case AttributedType::attr_hlsl_row_major:
+    case AttributedType::attr_hlsl_column_major:
+      typeMatMajorAttr = kind;
+    }
+    return desugarType(
+        attrType->getLocallyUnqualifiedSingleStepDesugaredType());
+  }
+
+  if (const auto *typedefType = type->getAs<TypedefType>()) {
+    return desugarType(typedefType->desugar());
+  }
+
+  return type;
+}
+
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang

+ 22 - 16
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -16,6 +16,7 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/ModuleBuilder.h"
+#include "llvm/ADT/Optional.h"
 
 
 #include "SpirvEvalInfo.h"
 #include "SpirvEvalInfo.h"
 
 
@@ -46,15 +47,13 @@ public:
   /// the error and returns 0. If decorateLayout is true, layout decorations
   /// the error and returns 0. If decorateLayout is true, layout decorations
   /// (Offset, MatrixStride, ArrayStride, RowMajor, ColMajor) will be attached
   /// (Offset, MatrixStride, ArrayStride, RowMajor, ColMajor) will be attached
   /// to the struct or array types. If layoutRule is not Void and type is a
   /// to the struct or array types. If layoutRule is not Void and type is a
-  /// matrix or array of matrix type, isRowMajor will indicate whether it is
-  /// decorated with row_major in the source code.
+  /// matrix or array of matrix type.
   ///
   ///
   /// The translation is recursive; all the types that the target type depends
   /// The translation is recursive; all the types that the target type depends
   /// on will be generated and all with layout decorations (if decorateLayout
   /// on will be generated and all with layout decorations (if decorateLayout
   /// is true).
   /// is true).
   uint32_t translateType(QualType type,
   uint32_t translateType(QualType type,
-                         LayoutRule layoutRule = LayoutRule::Void,
-                         bool isRowMajor = false);
+                         LayoutRule layoutRule = LayoutRule::Void);
 
 
   /// \brief Generates the SPIR-V type for the counter associated with a
   /// \brief Generates the SPIR-V type for the counter associated with a
   /// {Append|Consume}StructuredBuffer: an OpTypeStruct with a single 32-bit
   /// {Append|Consume}StructuredBuffer: an OpTypeStruct with a single 32-bit
@@ -178,9 +177,9 @@ public:
                           uint32_t *rowCount = nullptr,
                           uint32_t *rowCount = nullptr,
                           uint32_t *colCount = nullptr);
                           uint32_t *colCount = nullptr);
 
 
-  /// \brief Returns true if type is a matrix and matrix is row major
-  /// If decl is not nullptr, it is checked for attributes specifying majorness.
-  bool isRowMajorMatrix(QualType type, const Decl *decl = nullptr) const;
+  /// \brief Returns true if type is a row-major matrix, either with explicit
+  /// attribute or implicit command-line option.
+  bool isRowMajorMatrix(QualType type) const;
 
 
   /// \brief Returns true if the decl type is a non-floating-point matrix and
   /// \brief Returns true if the decl type is a non-floating-point matrix and
   /// the matrix is column major, or if it is an array/struct containing such
   /// the matrix is column major, or if it is an array/struct containing such
@@ -286,26 +285,21 @@ public:
   /// according to the given LayoutRule.
   /// according to the given LayoutRule.
 
 
   /// If the type is an array/matrix type, writes the array/matrix stride to
   /// If the type is an array/matrix type, writes the array/matrix stride to
-  /// stride. If the type is a matrix, isRowMajor will be used to indicate
-  /// whether it is labelled as row_major in the source code.
+  /// stride. If the type is a matrix.
   ///
   ///
   /// Note that the size returned is not exactly how many bytes the type
   /// Note that the size returned is not exactly how many bytes the type
   /// will occupy in memory; rather it is used in conjunction with alignment
   /// will occupy in memory; rather it is used in conjunction with alignment
   /// to get the next available location (alignment + size), which means
   /// to get the next available location (alignment + size), which means
   /// size contains post-paddings required by the given type.
   /// size contains post-paddings required by the given type.
-  std::pair<uint32_t, uint32_t> getAlignmentAndSize(QualType type,
-                                                    LayoutRule rule,
-                                                    bool isRowMajor,
-                                                    uint32_t *stride);
+  std::pair<uint32_t, uint32_t>
+  getAlignmentAndSize(QualType type, LayoutRule rule, uint32_t *stride);
 
 
-public:
   /// \brief If a hint exists regarding the usage of literal types, it
   /// \brief If a hint exists regarding the usage of literal types, it
   /// is returned. Otherwise, the given type itself is returned.
   /// is returned. Otherwise, the given type itself is returned.
   /// The hint is the type on top of the intendedLiteralTypes stack. This is the
   /// The hint is the type on top of the intendedLiteralTypes stack. This is the
   /// type we suspect the literal under question should be interpreted as.
   /// type we suspect the literal under question should be interpreted as.
   QualType getIntendedLiteralType(QualType type);
   QualType getIntendedLiteralType(QualType type);
 
 
-public:
   /// A RAII class for maintaining the intendedLiteralTypes stack.
   /// A RAII class for maintaining the intendedLiteralTypes stack.
   ///
   ///
   /// Instantiating an object of this class ensures that as long as the
   /// Instantiating an object of this class ensures that as long as the
@@ -334,7 +328,11 @@ private:
   /// \brief Removes the type at the top of the intendedLiteralTypes stack.
   /// \brief Removes the type at the top of the intendedLiteralTypes stack.
   void popIntendedLiteralType();
   void popIntendedLiteralType();
 
 
-private:
+  /// \brief Strip the attributes and typedefs fromthe given type and returns
+  /// the desugared one. This method will update internal bookkeeping regarding
+  /// matrix majorness.
+  QualType desugarType(QualType type);
+
   ASTContext &astContext;
   ASTContext &astContext;
   ModuleBuilder &theBuilder;
   ModuleBuilder &theBuilder;
   DiagnosticsEngine &diags;
   DiagnosticsEngine &diags;
@@ -346,6 +344,14 @@ private:
   /// float; but if the top of the stack is a double type, the literal should be
   /// float; but if the top of the stack is a double type, the literal should be
   /// evaluated as a double.
   /// evaluated as a double.
   std::stack<QualType> intendedLiteralTypes;
   std::stack<QualType> intendedLiteralTypes;
+
+  /// \brief A place to keep the matrix majorness attributes so that we can
+  /// retrieve the information when really processing the desugared matrix type.
+  /// This is needed because the majorness attribute is decorated on a
+  /// TypedefType (i.e., floatMxN) of the real matrix type (i.e., matrix<elem,
+  /// row, col>). When we reach the desugared matrix type, this information will
+  /// already be gone.
+  llvm::Optional<AttributedType::Kind> typeMatMajorAttr;
 };
 };
 
 
 } // end namespace spirv
 } // end namespace spirv

+ 18 - 3
tools/clang/test/CodeGenSPIRV/op.cbuffer.access.majorness.hlsl

@@ -6,19 +6,28 @@ struct SData {
    column_major float3x4 mat2[2];
    column_major float3x4 mat2[2];
 };
 };
 
 
-// CHECK: %type_SBufferData = OpTypeStruct %SData %_arr_mat3v4float_uint_2 %_arr_mat3v4float_uint_2_0
+// CHECK: %type_SBufferData = OpTypeStruct %SData %_arr_mat3v4float_uint_2 %_arr_mat3v4float_uint_2_0 %mat3v4float
 cbuffer SBufferData {
 cbuffer SBufferData {
                 SData    BufferData;
                 SData    BufferData;
                 float3x4 Mat1[2];
                 float3x4 Mat1[2];
    column_major float3x4 Mat2[2];
    column_major float3x4 Mat2[2];
+   row_major    float3x4 Mat3;
 };
 };
 
 
 // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SData %SBufferData %int_0
 // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SData %SBufferData %int_0
 // CHECK: [[val:%\d+]] = OpLoad %SData [[ptr]]
 // CHECK: [[val:%\d+]] = OpLoad %SData [[ptr]]
-// CHECK:     {{%\d+}} = OpCompositeExtract %_arr_mat3v4float_uint_2 %32 0
-// CHECK:     {{%\d+}} = OpCompositeExtract %_arr_mat3v4float_uint_2_0 %32 1
+// CHECK:     {{%\d+}} = OpCompositeExtract %_arr_mat3v4float_uint_2 [[val]] 0
+// CHECK:     {{%\d+}} = OpCompositeExtract %_arr_mat3v4float_uint_2_0 [[val]] 1
 static const SData Data = BufferData;
 static const SData Data = BufferData;
 
 
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SData %SBufferData %int_0
+// CHECK:     {{%\d+}} = OpAccessChain %_ptr_Uniform__arr_mat3v4float_uint_2 [[ptr]] %int_0
+static const float3x4 Matrices[2] = BufferData.mat1;
+
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SData %SBufferData %int_0
+// CHECK:     {{%\d+}} = OpAccessChain %_ptr_Uniform_mat3v4float [[ptr]] %int_1 %int_1
+static const float3x4 Matrix = BufferData.mat2[1];
+
 RWStructuredBuffer<float4> Out;
 RWStructuredBuffer<float4> Out;
 
 
 [numthreads(4, 4, 4)]
 [numthreads(4, 4, 4)]
@@ -30,5 +39,11 @@ void main() {
 // CHECK:     {{%\d+}} = OpLoad %_arr_mat3v4float_uint_2_0 [[ptr]]
 // CHECK:     {{%\d+}} = OpLoad %_arr_mat3v4float_uint_2_0 [[ptr]]
   float3x4 b[2] = Mat2;
   float3x4 b[2] = Mat2;
 
 
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform__arr_mat3v4float_uint_2_0 %SBufferData %int_2
+// CHECK:     {{%\d+}} = OpAccessChain %_ptr_Uniform_mat3v4float [[ptr]] %int_1
+  float3x4 c = Mat2[1];
+// CHECK:     {{%\d+}} = OpAccessChain %_ptr_Uniform_mat3v4float %SBufferData %int_3
+  float3x4 d = Mat3;
+
   Out[0] = Data.mat1[0][0];
   Out[0] = Data.mat1[0][0];
 }
 }