Pārlūkot izejas kodu

[spirv] Support ++ and -- for floating point matrices (#512)

Lei Zhang 8 gadi atpakaļ
vecāks
revīzija
d6c810d23d

+ 90 - 29
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -1071,6 +1071,34 @@ public:
     return isCompoundAssignment ? lhsPtr : rhs;
   }
 
+  /// Processes each vector within the given matrix by calling actOnEachVector.
+  /// matrixVal should be the loaded value of the matrix. actOnEachVector takes
+  /// three parameters for the current vector: the index, the <type-id>, and
+  /// the value. It returns the <result-id> of the processed vector.
+  uint32_t processEachVectorInMatrix(
+      const Expr *matrix, const uint32_t matrixVal,
+      llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
+          actOnEachVector) {
+    const auto matType = matrix->getType();
+    assert(TypeTranslator::isSpirvAcceptableMatrixType(matType));
+    const uint32_t vecType = typeTranslator.getComponentVectorType(matType);
+
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);
+
+    llvm::SmallVector<uint32_t, 4> vectors;
+    // Extract each component vector and do operation on it
+    for (uint32_t i = 0; i < rowCount; ++i) {
+      const uint32_t lhsVec =
+          theBuilder.createCompositeExtract(vecType, matrixVal, {i});
+      vectors.push_back(actOnEachVector(i, vecType, lhsVec));
+    }
+
+    // Construct the result matrix
+    return theBuilder.createCompositeConstruct(
+        typeTranslator.translateType(matType), vectors);
+  }
+
   /// Generates the necessary instructions for conducting the given binary
   /// operation on lhs and rhs.
   ///
@@ -1107,24 +1135,16 @@ public:
     case BO_DivAssign:
     case BO_RemAssign: {
       const uint32_t vecType = typeTranslator.getComponentVectorType(lhsType);
-
-      uint32_t rowCount = 0, colCount = 0;
-      hlsl::GetHLSLMatRowColCount(lhsType, rowCount, colCount);
-
-      llvm::SmallVector<uint32_t, 4> vectors;
-      // Extract each component vector and do operation on it
-      for (uint32_t i = 0; i < rowCount; ++i) {
-        const uint32_t lhsVec =
-            theBuilder.createCompositeExtract(vecType, lhsVal, {i});
+      const auto actOnEachVec = [this, spvOp, rhsVal](
+          uint32_t index, uint32_t vecType, uint32_t lhsVec) {
+        // For each vector of lhs, we need to load the corresponding vector of
+        // rhs and do the operation on them.
         const uint32_t rhsVec =
-            theBuilder.createCompositeExtract(vecType, rhsVal, {i});
-        vectors.push_back(
-            theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec));
-      }
+            theBuilder.createCompositeExtract(vecType, rhsVal, {index});
+        return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec);
 
-      // Construct the result matrix
-      return theBuilder.createCompositeConstruct(
-          typeTranslator.translateType(lhsType), vectors);
+      };
+      return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec);
     }
     case BO_Assign:
       llvm_unreachable("assignment should not be handled here");
@@ -1270,10 +1290,23 @@ public:
       const bool isInc = opcode == UO_PreInc || opcode == UO_PostInc;
 
       const spv::Op spvOp = translateOp(isInc ? BO_Add : BO_Sub, subType);
-      const uint32_t one = getValueOne(subType);
       const uint32_t originValue = theBuilder.createLoad(subTypeId, subValue);
-      const uint32_t incValue =
-          theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
+      const uint32_t one = hlsl::IsHLSLMatType(subType)
+                               ? getMatElemValueOne(subType)
+                               : getValueOne(subType);
+      uint32_t incValue = 0;
+      if (TypeTranslator::isSpirvAcceptableMatrixType(subType)) {
+        // For matrices, we can only incremnt/decrement each vector of it.
+        const auto actOnEachVec = [this, spvOp, one](
+            uint32_t /*index*/, uint32_t vecType, uint32_t lhsVec) {
+          return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, one);
+        };
+        incValue =
+            processEachVectorInMatrix(subExpr, originValue, actOnEachVec);
+      } else {
+        incValue =
+            theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
+      }
       theBuilder.createStore(subValue, incValue);
 
       // Prefix increment/decrement operator returns a lvalue, while postfix
@@ -2071,6 +2104,42 @@ case BO_##kind : {                                                             \
     return spv::Op::OpNop;
   }
 
+  /// Returns the <result-id> for a constant one vector of the given size and
+  /// element type.
+  uint32_t getVecValueOne(QualType elemType, uint32_t size) {
+    const uint32_t elemOneId = getValueOne(elemType);
+
+    if (size == 1)
+      return elemOneId;
+
+    llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemOneId);
+    const uint32_t vecType =
+        theBuilder.getVecType(typeTranslator.translateType(elemType), size);
+
+    return theBuilder.getConstantComposite(vecType, elements);
+  }
+
+  /// Returns the <result-id> for a constant one (vector) having the same
+  /// element type as the given matrix type.
+  ///
+  /// If a 1x1 matrix is given, the returned value one will be a scalar;
+  /// if a Mx1 or 1xN matrix is given, the returned value one will be a
+  /// vector of size M or N; if a MxN matrix is given, the returned value
+  /// one will be a vector of size N.
+  uint32_t getMatElemValueOne(QualType type) {
+    assert(hlsl::IsHLSLMatType(type));
+    const auto elemType = hlsl::GetHLSLMatElementType(type);
+
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+    if (rowCount == 1 && colCount == 1)
+      return getValueOne(elemType);
+    if (colCount == 1)
+      return getVecValueOne(elemType, rowCount);
+    return getVecValueOne(elemType, colCount);
+  }
+
   /// Returns the <result-id> for constant value 1 of the given type.
   uint32_t getValueOne(QualType type) {
     if (type->isSignedIntegerType()) {
@@ -2087,16 +2156,8 @@ case BO_##kind : {                                                             \
 
     if (hlsl::IsHLSLVecType(type)) {
       const QualType elemType = hlsl::GetHLSLVecElementType(type);
-      const uint32_t elemOneId = getValueOne(elemType);
-
-      const size_t size = hlsl::GetHLSLVecSize(type);
-      if (size == 1)
-        return elemOneId;
-
-      llvm::SmallVector<uint32_t, 4> elements(size, elemOneId);
-
-      const uint32_t vecTypeId = typeTranslator.translateType(type);
-      return theBuilder.getConstantComposite(vecTypeId, elements);
+      const auto size = hlsl::GetHLSLVecSize(type);
+      return getVecValueOne(elemType, size);
     }
 
     emitError("getting value 1 for type '%0' unimplemented") << type;

+ 43 - 0
tools/clang/test/CodeGenSPIRV/unary-op.postfix-dec.matrix.hlsl

@@ -0,0 +1,43 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: [[v2f1:%\d+]] = OpConstantComposite %v2float %float_1 %float_1
+// CHECK: [[v3f1:%\d+]] = OpConstantComposite %v3float %float_1 %float_1 %float_1
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    // 1x1
+    float1x1 a, b;
+// CHECK:      [[a0:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[a1:%\d+]] = OpFSub %float [[a0]] %float_1
+// CHECK-NEXT: OpStore %a [[a1]]
+// CHECK-NEXT: OpStore %b [[a0]]
+    b = a--;
+
+    // Mx1
+    float2x1 c, d;
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[c1:%\d+]] = OpFSub %v2float [[c0]] [[v2f1]]
+// CHECK-NEXT: OpStore %c [[c1]]
+// CHECK-NEXT: OpStore %d [[c0]]
+    d = c--;
+
+    // 1xN
+    float1x3 e, f;
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[e1:%\d+]] = OpFSub %v3float [[e0]] [[v3f1]]
+// CHECK-NEXT: OpStore %e [[e1]]
+// CHECK-NEXT: OpStore %f [[e0]]
+    f = e--;
+
+    // MxN
+    float2x3 g, h;
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[g0v0:%\d+]] = OpCompositeExtract %v3float [[g0]] 0
+// CHECK-NEXT: [[inc0:%\d+]] = OpFSub %v3float [[g0v0]] [[v3f1]]
+// CHECK-NEXT: [[g0v1:%\d+]] = OpCompositeExtract %v3float [[g0]] 1
+// CHECK-NEXT: [[inc1:%\d+]] = OpFSub %v3float [[g0v1]] [[v3f1]]
+// CHECK-NEXT: [[g1:%\d+]] = OpCompositeConstruct %mat2v3float [[inc0]] [[inc1]]
+// CHECK-NEXT: OpStore %g [[g1]]
+// CHECK-NEXT: OpStore %h [[g0]]
+    h = g--;
+}

+ 43 - 0
tools/clang/test/CodeGenSPIRV/unary-op.postfix-inc.matrix.hlsl

@@ -0,0 +1,43 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: [[v2f1:%\d+]] = OpConstantComposite %v2float %float_1 %float_1
+// CHECK: [[v3f1:%\d+]] = OpConstantComposite %v3float %float_1 %float_1 %float_1
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    // 1x1
+    float1x1 a, b;
+// CHECK:      [[a0:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[a1:%\d+]] = OpFAdd %float [[a0]] %float_1
+// CHECK-NEXT: OpStore %a [[a1]]
+// CHECK-NEXT: OpStore %b [[a0]]
+    b = a++;
+
+    // Mx1
+    float2x1 c, d;
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[c1:%\d+]] = OpFAdd %v2float [[c0]] [[v2f1]]
+// CHECK-NEXT: OpStore %c [[c1]]
+// CHECK-NEXT: OpStore %d [[c0]]
+    d = c++;
+
+    // 1xN
+    float1x3 e, f;
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[e1:%\d+]] = OpFAdd %v3float [[e0]] [[v3f1]]
+// CHECK-NEXT: OpStore %e [[e1]]
+// CHECK-NEXT: OpStore %f [[e0]]
+    f = e++;
+
+    // MxN
+    float2x3 g, h;
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[g0v0:%\d+]] = OpCompositeExtract %v3float [[g0]] 0
+// CHECK-NEXT: [[inc0:%\d+]] = OpFAdd %v3float [[g0v0]] [[v3f1]]
+// CHECK-NEXT: [[g0v1:%\d+]] = OpCompositeExtract %v3float [[g0]] 1
+// CHECK-NEXT: [[inc1:%\d+]] = OpFAdd %v3float [[g0v1]] [[v3f1]]
+// CHECK-NEXT: [[g1:%\d+]] = OpCompositeConstruct %mat2v3float [[inc0]] [[inc1]]
+// CHECK-NEXT: OpStore %g [[g1]]
+// CHECK-NEXT: OpStore %h [[g0]]
+    h = g++;
+}

+ 75 - 0
tools/clang/test/CodeGenSPIRV/unary-op.prefix-dec.matrix.hlsl

@@ -0,0 +1,75 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: [[v2f1:%\d+]] = OpConstantComposite %v2float %float_1 %float_1
+// CHECK: [[v3f1:%\d+]] = OpConstantComposite %v3float %float_1 %float_1 %float_1
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    // 1x1
+    float1x1 a, b;
+// CHECK:      [[a0:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[a1:%\d+]] = OpFSub %float [[a0]] %float_1
+// CHECK-NEXT: OpStore %a [[a1]]
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: OpStore %b [[a2]]
+    b = --a;
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[a4:%\d+]] = OpFSub %float [[a3]] %float_1
+// CHECK-NEXT: OpStore %a [[a4]]
+// CHECK-NEXT: OpStore %a [[b0]]
+    --a = b;
+
+    // Mx1
+    float2x1 c, d;
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[c1:%\d+]] = OpFSub %v2float [[c0]] [[v2f1]]
+// CHECK-NEXT: OpStore %c [[c1]]
+// CHECK-NEXT: [[c2:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: OpStore %d [[c2]]
+    d = --c;
+// CHECK-NEXT: [[d0:%\d+]] = OpLoad %v2float %d
+// CHECK-NEXT: [[c3:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[c4:%\d+]] = OpFSub %v2float [[c3]] [[v2f1]]
+// CHECK-NEXT: OpStore %c [[c4]]
+// CHECK-NEXT: OpStore %c [[d0]]
+    --c = d;
+
+    // 1xN
+    float1x3 e, f;
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[e1:%\d+]] = OpFSub %v3float [[e0]] [[v3f1]]
+// CHECK-NEXT: OpStore %e [[e1]]
+// CHECK-NEXT: [[e2:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: OpStore %f [[e2]]
+    f = --e;
+// CHECK-NEXT: [[f0:%\d+]] = OpLoad %v3float %f
+// CHECK-NEXT: [[e3:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[e4:%\d+]] = OpFSub %v3float [[e3]] [[v3f1]]
+// CHECK-NEXT: OpStore %e [[e4]]
+// CHECK-NEXT: OpStore %e [[f0]]
+    --e = f;
+
+    // MxN
+    float2x3 g, h;
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[g0v0:%\d+]] = OpCompositeExtract %v3float [[g0]] 0
+// CHECK-NEXT: [[inc0:%\d+]] = OpFSub %v3float [[g0v0]] [[v3f1]]
+// CHECK-NEXT: [[g0v1:%\d+]] = OpCompositeExtract %v3float [[g0]] 1
+// CHECK-NEXT: [[inc1:%\d+]] = OpFSub %v3float [[g0v1]] [[v3f1]]
+// CHECK-NEXT: [[g1:%\d+]] = OpCompositeConstruct %mat2v3float [[inc0]] [[inc1]]
+// CHECK-NEXT: OpStore %g [[g1]]
+// CHECK-NEXT: [[g2:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: OpStore %h [[g2]]
+    h = --g;
+// CHECK-NEXT: [[h0:%\d+]] = OpLoad %mat2v3float %h
+// CHECK-NEXT: [[g3:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[g3v0:%\d+]] = OpCompositeExtract %v3float [[g3]] 0
+// CHECK-NEXT: [[inc2:%\d+]] = OpFSub %v3float [[g3v0]] [[v3f1]]
+// CHECK-NEXT: [[g3v1:%\d+]] = OpCompositeExtract %v3float [[g3]] 1
+// CHECK-NEXT: [[inc3:%\d+]] = OpFSub %v3float [[g3v1]] [[v3f1]]
+// CHECK-NEXT: [[g4:%\d+]] = OpCompositeConstruct %mat2v3float [[inc2]] [[inc3]]
+// CHECK-NEXT: OpStore %g [[g4]]
+// CHECK-NEXT: OpStore %g [[h0]]
+    --g = h;
+}

+ 75 - 0
tools/clang/test/CodeGenSPIRV/unary-op.prefix-inc.matrix.hlsl

@@ -0,0 +1,75 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: [[v2f1:%\d+]] = OpConstantComposite %v2float %float_1 %float_1
+// CHECK: [[v3f1:%\d+]] = OpConstantComposite %v3float %float_1 %float_1 %float_1
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    // 1x1
+    float1x1 a, b;
+// CHECK:      [[a0:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[a1:%\d+]] = OpFAdd %float [[a0]] %float_1
+// CHECK-NEXT: OpStore %a [[a1]]
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: OpStore %b [[a2]]
+    b = ++a;
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[a4:%\d+]] = OpFAdd %float [[a3]] %float_1
+// CHECK-NEXT: OpStore %a [[a4]]
+// CHECK-NEXT: OpStore %a [[b0]]
+    ++a = b;
+
+    // Mx1
+    float2x1 c, d;
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[c1:%\d+]] = OpFAdd %v2float [[c0]] [[v2f1]]
+// CHECK-NEXT: OpStore %c [[c1]]
+// CHECK-NEXT: [[c2:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: OpStore %d [[c2]]
+    d = ++c;
+// CHECK-NEXT: [[d0:%\d+]] = OpLoad %v2float %d
+// CHECK-NEXT: [[c3:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[c4:%\d+]] = OpFAdd %v2float [[c3]] [[v2f1]]
+// CHECK-NEXT: OpStore %c [[c4]]
+// CHECK-NEXT: OpStore %c [[d0]]
+    ++c = d;
+
+    // 1xN
+    float1x3 e, f;
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[e1:%\d+]] = OpFAdd %v3float [[e0]] [[v3f1]]
+// CHECK-NEXT: OpStore %e [[e1]]
+// CHECK-NEXT: [[e2:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: OpStore %f [[e2]]
+    f = ++e;
+// CHECK-NEXT: [[f0:%\d+]] = OpLoad %v3float %f
+// CHECK-NEXT: [[e3:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[e4:%\d+]] = OpFAdd %v3float [[e3]] [[v3f1]]
+// CHECK-NEXT: OpStore %e [[e4]]
+// CHECK-NEXT: OpStore %e [[f0]]
+    ++e = f;
+
+    // MxN
+    float2x3 g, h;
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[g0v0:%\d+]] = OpCompositeExtract %v3float [[g0]] 0
+// CHECK-NEXT: [[inc0:%\d+]] = OpFAdd %v3float [[g0v0]] [[v3f1]]
+// CHECK-NEXT: [[g0v1:%\d+]] = OpCompositeExtract %v3float [[g0]] 1
+// CHECK-NEXT: [[inc1:%\d+]] = OpFAdd %v3float [[g0v1]] [[v3f1]]
+// CHECK-NEXT: [[g1:%\d+]] = OpCompositeConstruct %mat2v3float [[inc0]] [[inc1]]
+// CHECK-NEXT: OpStore %g [[g1]]
+// CHECK-NEXT: [[g2:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: OpStore %h [[g2]]
+    h = ++g;
+// CHECK-NEXT: [[h0:%\d+]] = OpLoad %mat2v3float %h
+// CHECK-NEXT: [[g3:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[g3v0:%\d+]] = OpCompositeExtract %v3float [[g3]] 0
+// CHECK-NEXT: [[inc2:%\d+]] = OpFAdd %v3float [[g3v0]] [[v3f1]]
+// CHECK-NEXT: [[g3v1:%\d+]] = OpCompositeExtract %v3float [[g3]] 1
+// CHECK-NEXT: [[inc3:%\d+]] = OpFAdd %v3float [[g3v1]] [[v3f1]]
+// CHECK-NEXT: [[g4:%\d+]] = OpCompositeConstruct %mat2v3float [[inc2]] [[inc3]]
+// CHECK-NEXT: OpStore %g [[g4]]
+// CHECK-NEXT: OpStore %g [[h0]]
+    ++g = h;
+}

+ 12 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -50,15 +50,27 @@ TEST_F(FileTest, VariableInitializer) { runFileTest("var.init.hlsl"); }
 TEST_F(FileTest, UnaryOpPrefixIncrement) {
   runFileTest("unary-op.prefix-inc.hlsl");
 }
+TEST_F(FileTest, UnaryOpPrefixIncrementMatrix) {
+  runFileTest("unary-op.prefix-inc.matrix.hlsl");
+}
 TEST_F(FileTest, UnaryOpPrefixDecrement) {
   runFileTest("unary-op.prefix-dec.hlsl");
 }
+TEST_F(FileTest, UnaryOpPrefixDecrementMatrix) {
+  runFileTest("unary-op.prefix-dec.matrix.hlsl");
+}
 TEST_F(FileTest, UnaryOpPostfixIncrement) {
   runFileTest("unary-op.postfix-inc.hlsl");
 }
+TEST_F(FileTest, UnaryOpPostfixIncrementMatrix) {
+  runFileTest("unary-op.postfix-inc.matrix.hlsl");
+}
 TEST_F(FileTest, UnaryOpPostfixDecrement) {
   runFileTest("unary-op.postfix-dec.hlsl");
 }
+TEST_F(FileTest, UnaryOpPostfixDecrementMatrix) {
+  runFileTest("unary-op.postfix-dec.matrix.hlsl");
+}
 
 // For unary operators
 TEST_F(FileTest, UnaryOpPlus) { runFileTest("unary-op.plus.hlsl"); }