Explorar el Código

[spirv] Add support for HLSLMatrixTruncationCast (#753)

Ehsan hace 7 años
padre
commit
4b7fab259f

+ 16 - 0
docs/SPIR-V.rst

@@ -877,6 +877,22 @@ Casting between (vectors) of scalar types is translated according to the followi
 |   Float    |                   | ``OpConvertFToS`` | ``OpConvertFToU`` |      no-op        |
 +------------+-------------------+-------------------+-------------------+-------------------+
 
+It is also feasible in HLSL to cast a float matrix to another float matrix with a smaller size.
+This is known as matrix truncation cast. For instance, the following code casts a 3x4 matrix
+into a 2x3 matrix.
+
+.. code:: hlsl
+
+  float3x4 m = { 1,  2,  3, 4,
+                 5,  6,  7, 8,
+                 9, 10, 11, 12 };
+
+  float2x3 a = (float2x3)m;
+
+Such casting takes the upper-left most corner of the original matrix to generate the result.
+In the above example, matrix ``a`` will have 2 rows, with 3 columns each. First row will be
+``1, 2, 3`` and the second row will be ``5, 6, 7``.
+
 Indexing operator
 -----------------
 

+ 58 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1339,6 +1339,13 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
   switch (expr->getCastKind()) {
   case CastKind::CK_LValueToRValue: {
     auto info = doExpr(subExpr);
+
+    // There are cases where the AST includes incorrect LValueToRValue nodes in
+    // the tree where not necessary. To make sure we emit the correct SPIR-V, we
+    // should bypass such casts.
+    if (subExpr->IgnoreParenNoopCasts(astContext)->isRValue())
+      return info;
+
     if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr) ||
         isBufferTextureIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr)) ||
         isTextureMipsSampleIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr))) {
@@ -1466,6 +1473,57 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
       return theBuilder.createCompositeConstruct(matType, vectors);
     }
   }
+  case CastKind::CK_HLSLMatrixTruncationCast: {
+    const QualType srcType = subExpr->getType();
+    const uint32_t srcId = doExpr(subExpr);
+    const QualType elemType = hlsl::GetHLSLMatElementType(srcType);
+    const uint32_t dstTypeId = typeTranslator.translateType(toType);
+    llvm::SmallVector<uint32_t, 4> indexes;
+
+    // It is possible that the source matrix is in fact a vector.
+    // For example: Truncate float1x3 --> float1x2.
+    // The front-end disallows float1x3 --> float2x1.
+    {
+      uint32_t srcVecSize = 0, dstVecSize = 0;
+      if (TypeTranslator::isVectorType(srcType, nullptr, &srcVecSize) &&
+          TypeTranslator::isVectorType(toType, nullptr, &dstVecSize)) {
+        for (uint32_t i = 0; i < dstVecSize; ++i)
+          indexes.push_back(i);
+        return theBuilder.createVectorShuffle(dstTypeId, srcId, srcId, indexes);
+      }
+    }
+
+    uint32_t srcRows = 0, srcCols = 0, dstRows = 0, dstCols = 0;
+    hlsl::GetHLSLMatRowColCount(srcType, srcRows, srcCols);
+    hlsl::GetHLSLMatRowColCount(toType, dstRows, dstCols);
+    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
+    const uint32_t srcRowType = theBuilder.getVecType(elemTypeId, srcCols);
+
+    // Indexes to pass to OpVectorShuffle
+    for (uint32_t i = 0; i < dstCols; ++i)
+      indexes.push_back(i);
+
+    llvm::SmallVector<uint32_t, 4> extractedVecs;
+    for (uint32_t row = 0; row < dstRows; ++row) {
+      // Extract a row
+      uint32_t rowId =
+          theBuilder.createCompositeExtract(srcRowType, srcId, {row});
+      // Extract the necessary columns from that row.
+      // The front-end ensures dstCols <= srcCols.
+      // If dstCols equals srcCols, we can use the whole row directly.
+      if (dstCols == 1) {
+        rowId = theBuilder.createCompositeExtract(elemTypeId, rowId, {0});
+      } else if (dstCols < srcCols) {
+        rowId = theBuilder.createVectorShuffle(
+            theBuilder.getVecType(elemTypeId, dstCols), rowId, rowId, indexes);
+      }
+      extractedVecs.push_back(rowId);
+    }
+    if (extractedVecs.size() == 1)
+      return extractedVecs.front();
+    return theBuilder.createCompositeConstruct(
+        typeTranslator.translateType(toType), extractedVecs);
+  }
   case CastKind::CK_HLSLMatrixToScalarCast: {
     // The underlying should already be a matrix of 1x1.
     assert(TypeTranslator::is1x1Matrix(subExpr->getType()));

+ 76 - 0
tools/clang/test/CodeGenSPIRV/cast.matrix.trunc.hlsl

@@ -0,0 +1,76 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// Note: Matrix truncation cast does not allow truncation into a mat1x1.
+
+void main() {
+  float3x4 m = { 1,  2,  3, 4,
+                 5,  6,  7, 8,
+                 9, 10, 11, 12 };
+
+  float1x4 n = {1, 2, 3, 4};
+  float3x1 o = {1, 2, 3};
+
+
+// TEST: float3x4 --> float2x3
+
+// CHECK:           [[m_0:%\d+]] = OpLoad %mat3v4float %m
+// CHECK-NEXT: [[m_0_row0:%\d+]] = OpCompositeExtract %v4float [[m_0]] 0
+// CHECK-NEXT:   [[a_row0:%\d+]] = OpVectorShuffle %v3float [[m_0_row0]] [[m_0_row0]] 0 1 2
+// CHECK-NEXT: [[m_0_row1:%\d+]] = OpCompositeExtract %v4float [[m_0]] 1
+// CHECK-NEXT:   [[a_row1:%\d+]] = OpVectorShuffle %v3float [[m_0_row1]] [[m_0_row1]] 0 1 2
+// CHECK-NEXT:          {{%\d+}} = OpCompositeConstruct %mat2v3float [[a_row0]] [[a_row1]]
+  float2x3 a = (float2x3)m;
+
+
+// TEST: float3x4 --> float1x3
+
+// CHECK:           [[m_1:%\d+]] = OpLoad %mat3v4float %m
+// CHECK-NEXT: [[m_1_row0:%\d+]] = OpCompositeExtract %v4float [[m_1]] 0
+// CHECK-NEXT:          {{%\d+}} = OpVectorShuffle %v3float [[m_1_row0]] [[m_1_row0]] 0 1 2
+  float1x3 b = m;
+
+
+// TEST: float3x4 --> float1x4
+
+// CHECK:      [[m_2:%\d+]] = OpLoad %mat3v4float %m
+// CHECK-NEXT:     {{%\d+}} = OpCompositeExtract %v4float [[m_2]] 0
+  float1x4 c = m;
+
+
+// TEST: float3x4 --> float2x1
+
+// CHECK:           [[m_3:%\d+]] = OpLoad %mat3v4float %m
+// CHECK-NEXT: [[m_3_row0:%\d+]] = OpCompositeExtract %v4float [[m_3]] 0
+// CHECK-NEXT:   [[d_row0:%\d+]] = OpCompositeExtract %float [[m_3_row0]] 0
+// CHECK-NEXT: [[m_3_row1:%\d+]] = OpCompositeExtract %v4float [[m_3]] 1
+// CHECK-NEXT:   [[d_row1:%\d+]] = OpCompositeExtract %float [[m_3_row1]] 0
+// CHECK-NEXT:          {{%\d+}} = OpCompositeConstruct %v2float [[d_row0]] [[d_row1]]
+  float2x1 d = m;
+
+
+// TEST: float3x4 --> float3x1
+
+// CHECK:           [[m_4:%\d+]] = OpLoad %mat3v4float %m
+// CHECK-NEXT: [[m_4_row0:%\d+]] = OpCompositeExtract %v4float [[m_4]] 0
+// CHECK-NEXT:   [[e_row0:%\d+]] = OpCompositeExtract %float [[m_4_row0]] 0
+// CHECK-NEXT: [[m_4_row1:%\d+]] = OpCompositeExtract %v4float [[m_4]] 1
+// CHECK-NEXT:   [[e_row1:%\d+]] = OpCompositeExtract %float [[m_4_row1]] 0
+// CHECK-NEXT: [[m_4_row2:%\d+]] = OpCompositeExtract %v4float [[m_4]] 2
+// CHECK-NEXT:   [[e_row2:%\d+]] = OpCompositeExtract %float [[m_4_row2]] 0
+// CHECK-NEXT:          {{%\d+}} = OpCompositeConstruct %v3float [[e_row0]] [[e_row1]] [[e_row2]]
+  float3x1 e = (float3x1)m;
+
+
+// TEST float1x4 --> float1x3
+
+// CHECK:      [[n:%\d+]] = OpLoad %v4float %n
+// CHECK-NEXT:   {{%\d+}} = OpVectorShuffle %v3float [[n]] [[n]] 0 1 2
+  float1x3 f = n;
+
+
+// TEST float3x1 --> float2x1
+
+// CHECK:      [[o:%\d+]] = OpLoad %v3float %o
+// CHECK-NEXT:   {{%\d+}} = OpVectorShuffle %v2float [[o]] [[o]] 0 1
+  float2x1 g = (float2x1)o;
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -276,6 +276,7 @@ TEST_F(FileTest, CastImplicitFlatConversion) {
 
 // For vector/matrix splatting and trunction
 TEST_F(FileTest, CastTruncateVector) { runFileTest("cast.vector.trunc.hlsl"); }
+TEST_F(FileTest, CastTruncateMatrix) { runFileTest("cast.matrix.trunc.hlsl"); }
 TEST_F(FileTest, CastSplatVector) { runFileTest("cast.vector.splat.hlsl"); }
 TEST_F(FileTest, CastSplatMatrix) { runFileTest("cast.matrix.splat.hlsl"); }