Kaynağa Gözat

[spirv] Respect default matrix order set from command line(#961)

Sebastian Tafuri 7 yıl önce
ebeveyn
işleme
1eaf553cfb

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h

@@ -17,6 +17,7 @@ namespace clang {
 struct EmitSPIRVOptions {
   /// Disable legalization and optimization and emit raw SPIR-V
   bool codeGenHighLevel;
+  bool defaultRowMajor;
   bool disableValidation;
   bool ignoreUnusedResources;
   bool enable16BitTypes;

+ 3 - 2
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -396,8 +396,9 @@ uint32_t DeclResultIdMapper::createVarOfExplicitLayoutStruct(
     auto varType = declDecl->getType();
     varType.removeLocalConst();
 
-    fieldTypes.push_back(typeTranslator.translateType(
-        varType, layoutRule, declDecl->hasAttr<HLSLRowMajorAttr>()));
+    const bool isRowMajor = typeTranslator.isRowMajorMatrix(varType, declDecl);
+    fieldTypes.push_back(
+        typeTranslator.translateType(varType, layoutRule, isRowMajor));
     fieldNames.push_back(declDecl->getName());
 
     // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate

+ 15 - 5
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -342,8 +342,8 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
     llvm::SmallVector<uint32_t, 4> fieldTypes;
     llvm::SmallVector<llvm::StringRef, 4> fieldNames;
     for (const auto *field : decl->fields()) {
-      fieldTypes.push_back(translateType(field->getType(), rule,
-                                         field->hasAttr<HLSLRowMajorAttr>()));
+      fieldTypes.push_back(translateType(
+          field->getType(), rule, isRowMajorMatrix(field->getType(), field)));
       fieldNames.push_back(field->getName());
     }
 
@@ -671,6 +671,15 @@ bool TypeTranslator::isMxNMatrix(QualType type, QualType *elemType,
   return true;
 }
 
+bool TypeTranslator::isRowMajorMatrix(QualType type, const Decl *decl) const {
+  if (!isMxNMatrix(type) && !type->isArrayType())
+    return false;
+  if (!decl)
+    return spirvOptions.defaultRowMajor;
+  return decl->hasAttr<HLSLRowMajorAttr>() ||
+         !decl->hasAttr<HLSLColumnMajorAttr>() && spirvOptions.defaultRowMajor;
+}
+
 bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {
   QualType elemType = {};
   return isMxNMatrix(type, &elemType) && elemType->isFloatingType();
@@ -724,7 +733,7 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule) {
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // HLSLBufferDecls).
     auto fieldType = cast<DeclaratorDecl>(field)->getType();
-    const bool isRowMajor = field->hasAttr<HLSLRowMajorAttr>();
+    const bool isRowMajor = isRowMajorMatrix(fieldType, field);
 
     uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
     std::tie(memberAlignment, memberSize) =
@@ -1056,8 +1065,9 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
 
     for (const auto *field : structType->getDecl()->fields()) {
       uint32_t memberAlignment = 0, memberSize = 0;
-      std::tie(memberAlignment, memberSize) = getAlignmentAndSize(
-          field->getType(), rule, field->hasAttr<HLSLRowMajorAttr>(), stride);
+      const bool isRowMajor = isRowMajorMatrix(field->getType(), field);
+      std::tie(memberAlignment, memberSize) =
+          getAlignmentAndSize(field->getType(), rule, isRowMajor, stride);
 
       // The base alignment of the structure is N, where N is the largest
       // base alignment value of any of its members...

+ 4 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -153,6 +153,10 @@ public:
                           uint32_t *rowCount = nullptr,
                           uint32_t *colCount = nullptr);
 
+  /// \broef returns true if type is a matrix and matrix is row major
+  /// If decl is not nullptr, is is checked for attributes specifying majorness
+  bool isRowMajorMatrix(QualType type, const Decl *decl = nullptr) const;
+
   /// \brief Returns true if the given type is a SPIR-V acceptable matrix type,
   /// i.e., with floating point elements and greater than 1 row and column
   /// counts.

+ 49 - 0
tools/clang/test/CodeGenSPIRV/type.matrix.majorness.zpr.hlsl

@@ -0,0 +1,49 @@
+// Run: %dxc -T vs_6_0 -E main /Zpr
+
+struct S {
+// CHECK: OpMemberDecorate %S 0 ColMajor
+               float2x3 mat1[2];
+// CHECK: OpMemberDecorate %S 1 ColMajor
+  row_major    float2x3 mat2[2];
+// CHECK: OpMemberDecorate %S 2 RowMajor
+  column_major float2x3 mat3[2];
+               float    f;
+};
+
+cbuffer MyCBuffer {
+// CHECK: OpMemberDecorate %type_MyCBuffer 0 ColMajor
+               float2x3 field1;
+// CHECK: OpMemberDecorate %type_MyCBuffer 1 ColMajor
+  row_major    float2x3 field2;
+// CHECK: OpMemberDecorate %type_MyCBuffer 2 RowMajor
+  column_major float2x3 field3;
+               S        field4;
+}
+
+struct T {
+               float    f;
+// CHECK: OpMemberDecorate %T 1 ColMajor
+               float2x3 mat1;
+// CHECK: OpMemberDecorate %T 2 ColMajor
+  row_major    float2x3 mat2;
+// CHECK: OpMemberDecorate %T 3 RowMajor
+  column_major float2x3 mat3;
+};
+
+struct U {
+               T        t;
+// CHECK: OpMemberDecorate %U 1 ColMajor
+               float2x3 mat1[2];
+// CHECK: OpMemberDecorate %U 2 ColMajor
+  row_major    float2x3 mat2[2];
+// CHECK: OpMemberDecorate %U 3 RowMajor
+  column_major float2x3 mat3[2];
+               float    f;
+};
+
+
+RWStructuredBuffer<U> MySBuffer;
+
+float3 main() : A {
+  return MySBuffer[0].mat1[1][1];
+}

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -468,6 +468,7 @@ public:
           spirvOpts.codeGenHighLevel = opts.CodeGenHighLevel;
           spirvOpts.disableValidation = opts.DisableValidation;
           spirvOpts.ignoreUnusedResources = opts.VkIgnoreUnusedResources;
+          spirvOpts.defaultRowMajor = opts.DefaultRowMajor;
           spirvOpts.stageIoOrder = opts.VkStageIoOrder;
           spirvOpts.bShift = opts.VkBShift;
           spirvOpts.tShift = opts.VkTShift;

+ 6 - 1
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -45,6 +45,9 @@ TEST_F(WholeFileTest, EmptyStructInterfaceVS) {
 TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
 TEST_F(FileTest, VectorTypes) { runFileTest("type.vector.hlsl"); }
 TEST_F(FileTest, MatrixTypes) { runFileTest("type.matrix.hlsl"); }
+TEST_F(FileTest, MatrixTypesMajornessZpr) {
+  runFileTest("type.matrix.majorness.zpr.hlsl");
+}
 TEST_F(FileTest, MatrixTypesMajorness) {
   runFileTest("type.matrix.majorness.hlsl", FileTest::Expect::Warning);
 }
@@ -1068,7 +1071,9 @@ TEST_F(FileTest, VulkanStructuredBufferCounter) {
 }
 
 TEST_F(FileTest, VulkanPushConstant) { runFileTest("vk.push-constant.hlsl"); }
-TEST_F(FileTest, VulkanPushConstantOffset) { runFileTest("vk.push-constant.offset.hlsl"); }
+TEST_F(FileTest, VulkanPushConstantOffset) {
+  runFileTest("vk.push-constant.offset.hlsl");
+}
 TEST_F(FileTest, VulkanMultiplePushConstant) {
   runFileTest("vk.push-constant.multiple.hlsl", FileTest::Expect::Failure);
 }