Browse Source

[spirv] Support MatrixToVector cast for 2x2 matrices. (#2113)

Ehsan 6 years ago
parent
commit
428296a22d

+ 27 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2146,6 +2146,7 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) {
   const Expr *subExpr = expr->getSubExpr();
   const Expr *subExpr = expr->getSubExpr();
   const QualType subExprType = subExpr->getType();
   const QualType subExprType = subExpr->getType();
   const QualType toType = expr->getType();
   const QualType toType = expr->getType();
+  const auto srcLoc = expr->getExprLoc();
 
 
   switch (expr->getCastKind()) {
   switch (expr->getCastKind()) {
   case CastKind::CK_LValueToRValue:
   case CastKind::CK_LValueToRValue:
@@ -2343,9 +2344,32 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) {
     return doExpr(subExpr);
     return doExpr(subExpr);
   }
   }
   case CastKind::CK_HLSLMatrixToVectorCast: {
   case CastKind::CK_HLSLMatrixToVectorCast: {
-    // The underlying should already be a matrix of 1xN.
-    assert(is1xNMatrix(subExprType) || isMx1Matrix(subExprType));
-    return doExpr(subExpr);
+    // If the underlying matrix is Mx1 or 1xM for M in {1, 2, 3, 4}, we can
+    // return the underlying matrix because it'll be evaluated as a vector by
+    // default.
+    if (is1x1Matrix(subExprType) || is1xNMatrix(subExprType) ||
+        isMx1Matrix(subExprType))
+      return doExpr(subExpr);
+
+    // A vector can have no more than 4 elements. The only remaining case
+    // is casting from a 2x2 matrix to a vector of size 4.
+
+    auto *mat = loadIfGLValue(subExpr);
+    QualType elemType = {};
+    uint32_t rowCount = 0, colCount = 0, elemCount = 0;
+    const bool isMat =
+        isMxNMatrix(subExprType, &elemType, &rowCount, &colCount);
+    const bool isVec = isVectorType(toType, nullptr, &elemCount);
+    assert(isMat && rowCount == 2 && colCount == 2);
+    assert(isVec && elemCount == 4);
+    (void)isMat;
+    (void)isVec;
+    QualType vec2Type = astContext.getExtVectorType(elemType, 2);
+    auto *row0 = spvBuilder.createCompositeExtract(vec2Type, mat, {0}, srcLoc);
+    auto *row1 = spvBuilder.createCompositeExtract(vec2Type, mat, {1}, srcLoc);
+    auto *vec = spvBuilder.createVectorShuffle(toType, row0, row1, {0, 1, 2, 3},
+                                               srcLoc);
+    vec->setRValue();
+    return vec;
   }
   }
   case CastKind::CK_FunctionToPointerDecay:
   case CastKind::CK_FunctionToPointerDecay:
     // Just need to return the function id
     // Just need to return the function id

+ 33 - 0
tools/clang/test/CodeGenSPIRV/cast.mat-to-vec.hlsl

@@ -0,0 +1,33 @@
+// Run: %dxc -T ps_6_0 -E main
+
+float4 main(float4 input : A) : SV_Target {
+  float2x2 floatMat;
+  int2x2   intMat;
+  bool2x2  boolMat;
+
+// CHECK:      [[floatMat:%\d+]] = OpLoad %mat2v2float %floatMat
+// CHECK-NEXT:     [[row0:%\d+]] = OpCompositeExtract %v2float [[floatMat]] 0
+// CHECK-NEXT:     [[row1:%\d+]] = OpCompositeExtract %v2float [[floatMat]] 1
+// CHECK-NEXT:      [[vec:%\d+]] = OpVectorShuffle %v4float [[row0]] [[row1]] 0 1 2 3
+// CHECK-NEXT:                     OpStore %c [[vec]]
+  float4 c = floatMat;
+
+// CHECK:        [[intMat:%\d+]] = OpLoad %_arr_v2int_uint_2 %intMat
+// CHECK-NEXT:     [[row0:%\d+]] = OpCompositeExtract %v2int [[intMat]] 0
+// CHECK-NEXT:     [[row1:%\d+]] = OpCompositeExtract %v2int [[intMat]] 1
+// CHECK-NEXT:   [[vecInt:%\d+]] = OpVectorShuffle %v4int [[row0]] [[row1]] 0 1 2 3
+// CHECK-NEXT: [[vecFloat:%\d+]] = OpConvertSToF %v4float [[vecInt]]
+// CHECK-NEXT:                     OpStore %d [[vecFloat]]
+  float4 d = intMat;
+
+// CHECK:       [[boolMat:%\d+]] = OpLoad %_arr_v2bool_uint_2 %boolMat
+// CHECK-NEXT:     [[row0:%\d+]] = OpCompositeExtract %v2bool [[boolMat]] 0
+// CHECK-NEXT:     [[row1:%\d+]] = OpCompositeExtract %v2bool [[boolMat]] 1
+// CHECK-NEXT:      [[vec:%\d+]] = OpVectorShuffle %v4bool [[row0]] [[row1]] 0 1 2 3
+// CHECK-NEXT: [[vecFloat:%\d+]] = OpSelect %v4float [[vec]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT:                     OpStore %e [[vecFloat]]
+  float4 e = boolMat;
+
+  return 0.xxxx;
+}
+

+ 11 - 0
tools/clang/test/CodeGenSPIRV/type.append.consume-structured-buffer.cast.hlsl

@@ -30,6 +30,9 @@ AppendStructuredBuffer<int> append_int;
 RWStructuredBuffer<bool> rw_bool;
 RWStructuredBuffer<bool> rw_bool;
 RWStructuredBuffer<bool2> rw_v2bool;
 RWStructuredBuffer<bool2> rw_v2bool;
 
 
+ConsumeStructuredBuffer<float2x2> consume_float2x2;
+AppendStructuredBuffer<float4> append_v4float;
+
 void main() {
 void main() {
 // CHECK:       [[p_0:%\d+]] = OpAccessChain %_ptr_Uniform_uint %append_bool %uint_0 {{%\d+}}
 // CHECK:       [[p_0:%\d+]] = OpAccessChain %_ptr_Uniform_uint %append_bool %uint_0 {{%\d+}}
 
 
@@ -360,4 +363,12 @@ void main() {
 // CHECK-NEXT: [[bi_61:%\d+]] = OpSelect %float [[b_61]] %float_1 %float_0
 // CHECK-NEXT: [[bi_61:%\d+]] = OpSelect %float [[b_61]] %float_1 %float_0
 // CHECK-NEXT:                  OpStore {{%\d+}} [[bi_61]]
 // CHECK-NEXT:                  OpStore {{%\d+}} [[bi_61]]
   append_float.Append(rw_v2bool[0].x);
   append_float.Append(rw_v2bool[0].x);
+
+// CHECK:      [[matPtr:%\d+]] = OpAccessChain %_ptr_Uniform_mat2v2float %consume_float2x2 %uint_0 {{%\d+}}
+// CHECK-NEXT:    [[mat:%\d+]] = OpLoad %mat2v2float [[matPtr]]
+// CHECK-NEXT:   [[row0:%\d+]] = OpCompositeExtract %v2float [[mat]] 0
+// CHECK-NEXT:   [[row1:%\d+]] = OpCompositeExtract %v2float [[mat]] 1
+// CHECK-NEXT:   [[vec4:%\d+]] = OpVectorShuffle %v4float [[row0]] [[row1]] 0 1 2 3
+// CHECK-NEXT:                   OpStore {{%\d+}} [[vec4]]
+  append_v4float.Append(consume_float2x2.Consume());
 }
 }

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -417,6 +417,9 @@ TEST_F(FileTest, CastFlatConversionDecomposeVector) {
 TEST_F(FileTest, CastExplicitVecToMat) {
 TEST_F(FileTest, CastExplicitVecToMat) {
   runFileTest("cast.vec-to-mat.explicit.hlsl");
   runFileTest("cast.vec-to-mat.explicit.hlsl");
 }
 }
+TEST_F(FileTest, CastMatrixToVector) {
+  runFileTest("cast.mat-to-vec.hlsl");
+}
 TEST_F(FileTest, CastBitwidth) { runFileTest("cast.bitwidth.hlsl"); }
 TEST_F(FileTest, CastBitwidth) { runFileTest("cast.bitwidth.hlsl"); }
 
 
 TEST_F(FileTest, CastLiteralTypeForArraySubscript) {
 TEST_F(FileTest, CastLiteralTypeForArraySubscript) {