Browse Source

[spirv] Consider majorness when accessing matrix in cbuffer (#1116)

Moved majorness info into SpirvEvalInfo so that we can translate
type correctly based on it. This affects accessing matrices inside
cbuffer; otherwise, we will have wrong type for the matrix field
loaded out of a cbuffer.

Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/1112
Lei Zhang 7 years ago
parent
commit
307c860193

+ 12 - 12
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -304,15 +304,13 @@ SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
           cast<VarDecl>(decl)->getType(),
           // We need to set decorateLayout here to avoid creating SPIR-V
           // instructions for the current type without decorations.
-          info->info.getLayoutRule(), info->isRowMajor);
+          info->info.getLayoutRule(), info->info.isRowMajor());
 
       const uint32_t elemId = theBuilder.createAccessChain(
           theBuilder.getPointerType(varType, info->info.getStorageClass()),
           info->info, {theBuilder.getConstantInt32(info->indexInCTBuffer)});
 
-      return SpirvEvalInfo(elemId)
-          .setStorageClass(info->info.getStorageClass())
-          .setLayoutRule(info->info.getLayoutRule());
+      return info->info.substResultId(elemId);
     } else {
       return *info;
     }
@@ -432,11 +430,12 @@ uint32_t DeclResultIdMapper::createExternVar(const VarDecl *var) {
   astDecls[var] =
       SpirvEvalInfo(id).setStorageClass(storageClass).setLayoutRule(rule);
   if (isMatType) {
+    astDecls[var].info.setRowMajor(
+        typeTranslator.isRowMajorMatrix(var->getType(), var));
+
     // We have wrapped the stand-alone matrix inside a struct. Mark it as
     // needing an extra index to access.
     astDecls[var].indexInCTBuffer = 0;
-    astDecls[var].isRowMajor =
-        typeTranslator.isRowMajorMatrix(var->getType(), var);
   }
 
   // Variables in Workgroup do not need descriptor decorations.
@@ -579,12 +578,13 @@ uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
     const auto *varDecl = cast<VarDecl>(subDecl);
     const bool isRowMajor =
         typeTranslator.isRowMajorMatrix(varDecl->getType(), varDecl);
-    astDecls[varDecl] = {SpirvEvalInfo(bufferVar)
-                             .setStorageClass(spv::StorageClass::Uniform)
-                             .setLayoutRule(decl->isCBuffer()
-                                                ? LayoutRule::GLSLStd140
-                                                : LayoutRule::GLSLStd430),
-                         index++, isRowMajor};
+    astDecls[varDecl] =
+        SpirvEvalInfo(bufferVar)
+            .setStorageClass(spv::StorageClass::Uniform)
+            .setLayoutRule(decl->isCBuffer() ? LayoutRule::GLSLStd140
+                                             : LayoutRule::GLSLStd430)
+            .setRowMajor(isRowMajor);
+    astDecls[varDecl].indexInCTBuffer = index++;
   }
   resourceVars.emplace_back(
       bufferVar, ResourceVar::Category::Other, getResourceBinding(decl),

+ 2 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -356,8 +356,8 @@ private:
     /// Default constructor to satisfy DenseMap
     DeclSpirvInfo() : info(0), indexInCTBuffer(-1) {}
 
-    DeclSpirvInfo(const SpirvEvalInfo &info_, int index = -1, bool row = false)
-        : info(info_), indexInCTBuffer(index), isRowMajor(row) {}
+    DeclSpirvInfo(const SpirvEvalInfo &info_, int index = -1)
+        : info(info_), indexInCTBuffer(index) {}
 
     /// Implicit conversion to SpirvEvalInfo.
     operator SpirvEvalInfo() const { return info; }
@@ -366,8 +366,6 @@ private:
     /// Value >= 0 means that this decl is a VarDecl inside a cbuffer/tbuffer
     /// and this is the index; value < 0 means this is just a standalone decl.
     int indexInCTBuffer;
-    /// Whether this decl should be row major.
-    bool isRowMajor;
   };
 
   /// \brief Returns the SPIR-V information for the given decl.

+ 7 - 5
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -783,8 +783,8 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
   if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
     valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
   } else {
-    valType =
-        typeTranslator.translateType(expr->getType(), info.getLayoutRule());
+    valType = typeTranslator.translateType(
+        expr->getType(), info.getLayoutRule(), info.isRowMajor());
   }
   return info.setResultId(theBuilder.createLoad(valType, info)).setRValue();
 }
@@ -4603,13 +4603,15 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
   } else if (const auto *recordType = lhsValType->getAs<RecordType>()) {
     uint32_t index = 0;
     for (const auto *field : recordType->getDecl()->fields()) {
+      bool isRowMajor =
+          typeTranslator.isRowMajorMatrix(field->getType(), field);
       const auto subRhsValType = typeTranslator.translateType(
-          field->getType(), rhsVal.getLayoutRule());
+          field->getType(), rhsVal.getLayoutRule(), isRowMajor);
       const auto subRhsVal =
           theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
       const auto subLhsPtrType = theBuilder.getPointerType(
-          typeTranslator.translateType(field->getType(),
-                                       lhsPtr.getLayoutRule()),
+          typeTranslator.translateType(field->getType(), lhsPtr.getLayoutRule(),
+                                       isRowMajor),
           lhsPtr.getStorageClass());
       const auto subLhsPtr = theBuilder.createAccessChain(
           subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});

+ 10 - 1
tools/clang/lib/SPIRV/SpirvEvalInfo.h

@@ -100,6 +100,9 @@ public:
   inline SpirvEvalInfo &setRelaxedPrecision();
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
 
+  inline SpirvEvalInfo &setRowMajor(bool);
+  bool isRowMajor() const { return isRowMajor_; }
+
 private:
   uint32_t resultId;
   /// Indicates whether this evaluation result contains alias variables
@@ -119,13 +122,14 @@ private:
   bool isConstant_;
   bool isSpecConstant_;
   bool isRelaxedPrecision_;
+  bool isRowMajor_;
 };
 
 SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
     : resultId(id), containsAlias(false),
       storageClass(spv::StorageClass::Function), layoutRule(LayoutRule::Void),
       isRValue_(false), isConstant_(false), isSpecConstant_(false),
-      isRelaxedPrecision_(false) {}
+      isRelaxedPrecision_(false), isRowMajor_(false) {}
 
 SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) {
   resultId = id;
@@ -174,6 +178,11 @@ SpirvEvalInfo &SpirvEvalInfo::setRelaxedPrecision() {
   return *this;
 }
 
+SpirvEvalInfo &SpirvEvalInfo::setRowMajor(bool rm) {
+  isRowMajor_ = rm;
+  return *this;
+}
+
 } // end namespace spirv
 } // end namespace clang
 

+ 3 - 1
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -412,8 +412,10 @@ uint32_t TypeTranslator::getElementSpirvBitwidth(QualType type) {
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
                                        bool isRowMajor) {
   // We can only apply row_major to matrices or arrays of matrices.
+  // isRowMajor will be ignored for scalar and vector types.
   if (isRowMajor)
-    assert(isMxNMatrix(type) || type->isArrayType());
+    assert(type->isScalarType() || type->isArrayType() ||
+           hlsl::IsHLSLVecMatType(type));
 
   // Try to translate the canonical type first
   const auto canonicalType = type.getCanonicalType();

+ 34 - 0
tools/clang/test/CodeGenSPIRV/op.cbuffer.access.majorness.hlsl

@@ -0,0 +1,34 @@
+// Run: %dxc -T cs_6_0 -E main -Zpr
+
+// CHECK: %SData = OpTypeStruct %_arr_mat3v4float_uint_2 %_arr_mat3v4float_uint_2_0
+struct SData {
+                float3x4 mat1[2];
+   column_major float3x4 mat2[2];
+};
+
+// CHECK: %type_SBufferData = OpTypeStruct %SData %_arr_mat3v4float_uint_2 %_arr_mat3v4float_uint_2_0
+cbuffer SBufferData {
+                SData    BufferData;
+                float3x4 Mat1[2];
+   column_major float3x4 Mat2[2];
+};
+
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SData %SBufferData %int_0
+// CHECK: [[val:%\d+]] = OpLoad %SData [[ptr]]
+// CHECK:     {{%\d+}} = OpCompositeExtract %_arr_mat3v4float_uint_2 %32 0
+// CHECK:     {{%\d+}} = OpCompositeExtract %_arr_mat3v4float_uint_2_0 %32 1
+static const SData Data = BufferData;
+
+RWStructuredBuffer<float4> Out;
+
+[numthreads(4, 4, 4)]
+void main() {
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform__arr_mat3v4float_uint_2 %SBufferData %int_1
+// CHECK:     {{%\d+}} = OpLoad %_arr_mat3v4float_uint_2 [[ptr]]
+  float3x4 a[2] = Mat1;
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform__arr_mat3v4float_uint_2_0 %SBufferData %int_2
+// CHECK:     {{%\d+}} = OpLoad %_arr_mat3v4float_uint_2_0 [[ptr]]
+  float3x4 b[2] = Mat2;
+
+  Out[0] = Data.mat1[0][0];
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -278,6 +278,10 @@ TEST_F(FileTest, OpArrayAccess) { runFileTest("op.array.access.hlsl"); }
 TEST_F(FileTest, OpBufferAccess) { runFileTest("op.buffer.access.hlsl"); }
 TEST_F(FileTest, OpRWBufferAccess) { runFileTest("op.rwbuffer.access.hlsl"); }
 TEST_F(FileTest, OpCBufferAccess) { runFileTest("op.cbuffer.access.hlsl"); }
+TEST_F(FileTest, OpCBufferAccessMajorness) {
+  /// Tests that we correctly consider majorness when accessing matrices
+  runFileTest("op.cbuffer.access.majorness.hlsl");
+}
 TEST_F(FileTest, OpConstantBufferAccess) {
   runFileTest("op.constant-buffer.access.hlsl");
 }