Răsfoiți Sursa

[spirv] Handle corner cases of transpose and mul. (#3127)

Fixes #3109.
Ehsan 5 ani în urmă
părinte
comite
c2afcb7e81

+ 14 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -38,6 +38,20 @@ bool isScalarType(QualType type, QualType *scalarType = nullptr);
 bool isVectorType(QualType type, QualType *elemType = nullptr,
                   uint32_t *elemCount = nullptr);
 
+/// Returns true if the given type will be translated into a SPIR-V scalar type
+/// or vector type.
+///
+/// This includes:
+/// scalar types
+/// vector types (vec1, vec2, vec3, and vec4)
+/// Mx1 matrices (where M can be 1,2,3,4)
+/// 1xN matrices (where N can be 1,2,3,4)
+///
+/// Writes the element type and count into *elementType and *count respectively
+/// if they are not nullptr.
+bool isScalarOrVectorType(QualType type, QualType *elemType = nullptr,
+                          uint32_t *elemCount = nullptr);
+
 /// Returns true if the given type is an array with constant known size.
 bool isConstantArrayType(const ASTContext &, QualType);
 

+ 11 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -152,6 +152,17 @@ bool isVectorType(QualType type, QualType *elemType, uint32_t *elemCount) {
   return isVec;
 }
 
+bool isScalarOrVectorType(QualType type, QualType *elemType,
+                          uint32_t *elemCount) {
+  if (isScalarType(type, elemType)) {
+    if (elemCount)
+      *elemCount = 1;
+    return true;
+  }
+
+  return isVectorType(type, elemType, elemCount);
+}
+
 bool isConstantArrayType(const ASTContext &astContext, QualType type) {
   return astContext.getAsConstantArrayType(type) != nullptr;
 }

+ 62 - 15
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2330,10 +2330,18 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) {
     llvm::SmallVector<uint32_t, 4> indexes;
 
     // It is possible that the source matrix is in fact a vector.
-    // For example: Truncate float1x3 --> float1x2.
+    // Example 1: Truncate float1x3 --> float1x2.
+    // Example 2: Truncate float1x3 --> float1x1.
     // The front-end disallows float1x3 --> float2x1.
     {
       uint32_t srcVecSize = 0, dstVecSize = 0;
+      if (isVectorType(srcType, nullptr, &srcVecSize) && isScalarType(toType)) {
+        auto *val = spvBuilder.createCompositeExtract(toType, src, {0},
+                                                      expr->getLocStart());
+        val->setRValue();
+        return val;
+      }
+
       if (isVectorType(srcType, nullptr, &srcVecSize) &&
           isVectorType(toType, nullptr, &dstVecSize)) {
         for (uint32_t i = 0; i < dstVecSize; ++i)
@@ -7177,12 +7185,17 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_transpose: {
     const Expr *mat = callExpr->getArg(0);
     const QualType matType = mat->getType();
-    if (hlsl::GetHLSLMatElementType(matType)->isFloatingType())
-      retVal =
-          processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpTranspose, false);
-    else
-      retVal = processNonFpMatrixTranspose(matType, doExpr(mat), srcLoc);
-
+    if (isVectorType(matType) || isScalarType(matType)) {
+      // A 1xN or Nx1 or 1x1 matrix is a SPIR-V vector/scalar, and its transpose
+      // is the vector/scalar itself.
+      retVal = doExpr(mat);
+    } else {
+      if (hlsl::GetHLSLMatElementType(matType)->isFloatingType())
+        retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpTranspose,
+                                                false);
+      else
+        retVal = processNonFpMatrixTranspose(matType, doExpr(mat), srcLoc);
+    }
     break;
   }
   // DXR raytracing intrinsics
@@ -8705,8 +8718,30 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMul(const CallExpr *callExpr) {
   }
 
   // mul(vector, vector)
-  if (isVectorType(arg0Type) && isVectorType(arg1Type))
-    return processIntrinsicDot(callExpr);
+  if (isVectorType(arg0Type) && isVectorType(arg1Type)) {
+    // mul( Mat(1xM), Mat(Mx1) ) results in a scalar (same as dot product)
+    if (isScalarType(returnType)) {
+      return processIntrinsicDot(callExpr);
+    }
+
+    // mul( Mat(Mx1), Mat(1xN) ) results in a MxN matrix.
+    QualType elemType = {};
+    uint32_t numRows = 0;
+    if (isMxNMatrix(returnType, &elemType, &numRows)) {
+      llvm::SmallVector<SpirvInstruction *, 4> rows;
+      auto *arg0Id = doExpr(arg0);
+      auto *arg1Id = doExpr(arg1);
+      for (uint32_t i = 0; i < numRows; ++i) {
+        auto *scalar =
+            spvBuilder.createCompositeExtract(elemType, arg0Id, {i}, loc);
+        rows.push_back(spvBuilder.createBinaryOp(
+            spv::Op::OpVectorTimesScalar, arg1Type, arg1Id, scalar, loc));
+      }
+      return spvBuilder.createCompositeConstruct(returnType, rows, loc);
+    }
+
+    llvm_unreachable("bad arguments passed to mul");
+  }
 
   // All the following cases require handling arg0 and arg1 expressions first.
   auto *arg0Id = doExpr(arg0);
@@ -8829,8 +8864,6 @@ SpirvEmitter::processIntrinsicPrintf(const CallExpr *callExpr) {
 }
 
 SpirvInstruction *SpirvEmitter::processIntrinsicDot(const CallExpr *callExpr) {
-  const QualType returnType = callExpr->getType();
-
   // Get the function parameters. Expect 2 vectors as parameters.
   assert(callExpr->getNumArgs() == 2u);
   const Expr *arg0 = callExpr->getArg(0);
@@ -8839,14 +8872,28 @@ SpirvInstruction *SpirvEmitter::processIntrinsicDot(const CallExpr *callExpr) {
   auto *arg1Id = doExpr(arg1);
   QualType arg0Type = arg0->getType();
   QualType arg1Type = arg1->getType();
-  const size_t vec0Size = hlsl::GetHLSLVecSize(arg0Type);
-  const size_t vec1Size = hlsl::GetHLSLVecSize(arg1Type);
-  const QualType vec0ComponentType = hlsl::GetHLSLVecElementType(arg0Type);
-  const QualType vec1ComponentType = hlsl::GetHLSLVecElementType(arg1Type);
+  uint32_t vec0Size = 0, vec1Size = 0;
+  QualType vec0ComponentType = {}, vec1ComponentType = {};
+  QualType returnType = {};
+  const bool arg0isScalarOrVec =
+      isScalarOrVectorType(arg0Type, &vec0ComponentType, &vec0Size);
+  const bool arg1isScalarOrVec =
+      isScalarOrVectorType(arg1Type, &vec1ComponentType, &vec1Size);
+  const bool returnIsScalar = isScalarType(callExpr->getType(), &returnType);
+  // Each argument should either be a vector or a scalar
+  assert(arg0isScalarOrVec && arg1isScalarOrVec);
+  // The result type must be a scalar.
+  assert(returnIsScalar);
+  // The element type of each argument and the return type must be the same.
   assert(returnType == vec1ComponentType);
   assert(vec0ComponentType == vec1ComponentType);
+  // The size of the two arguments must be equal.
   assert(vec0Size == vec1Size);
+  // Acceptable vector sizes are 1,2,3,4.
   assert(vec0Size >= 1 && vec0Size <= 4);
+  (void)arg0isScalarOrVec;
+  (void)arg1isScalarOrVec;
+  (void)returnIsScalar;
   (void)vec0ComponentType;
   (void)vec1ComponentType;
   (void)vec1Size;

+ 8 - 0
tools/clang/test/CodeGenSPIRV/cast.matrix.trunc.hlsl

@@ -68,6 +68,14 @@ void main() {
   float1x3 f = n;
 
 
+// TEST float1x4 --> float1x3
+
+// CHECK:       [[n:%\d+]] = OpLoad %v4float %n
+// CHECK-NEXT: [[n0:%\d+]] = OpCompositeExtract %float [[n]] 0
+// CHECK-NEXT:               OpStore %scalar [[n0]]
+  float1x1 scalar = n;
+
+
 // TEST float3x1 --> float2x1
 
 // CHECK:      [[o:%\d+]] = OpLoad %v3float %o

+ 37 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl

@@ -437,4 +437,41 @@ void main() {
 // CHECK-NEXT:    {{%\d+}} = OpCompositeConstruct %_arr_v3int_uint_2 [[t0]] [[t1]]
   int4x3 intMat4x3;
   int2x3 t = mul(intMat2x4, intMat4x3);
+
+
+//
+// 1-D matrices passed to mul
+//
+
+// mul( Mat(1xM) * Mat(MxN) ) --> Mat(1xN) vector
+// mul( Mat(1xM) * Mat(Mx1) ) --> Scalar
+// mul( Mat(Mx1) * Mat(1xN) ) --> Mat(MxN) matrix
+  float1x3 mat1x3;
+  float3x2 mat3x2;
+  float3x1 mat3x1;
+  float1x4 mat1x4;
+
+// CHECK:       [[mat1x3:%\d+]] = OpLoad %v3float %mat1x3
+// CHECK-NEXT:  [[mat3x2:%\d+]] = OpLoad %mat3v2float %mat3x2
+// CHECK-NEXT: [[result1:%\d+]] = OpMatrixTimesVector %v2float [[mat3x2]] [[mat1x3]]
+// CHECK-NEXT:                    OpStore %result1 [[result1]]
+  float1x2   result1 = mul( mat1x3, mat3x2 ); // result is float2 vector
+
+// CHECK:       [[mat1x3:%\d+]] = OpLoad %v3float %mat1x3
+// CHECK-NEXT:  [[mat3x1:%\d+]] = OpLoad %v3float %mat3x1
+// CHECK-NEXT: [[result2:%\d+]] = OpDot %float [[mat1x3]] [[mat3x1]]
+// CHECK-NEXT:                    OpStore %result2 [[result2]]
+  float      result2 = mul( mat1x3, mat3x1 ); // result is scalar
+
+// CHECK:       [[mat3x1:%\d+]] = OpLoad %v3float %mat3x1
+// CHECK-NEXT:  [[mat1x4:%\d+]] = OpLoad %v4float %mat1x4
+// CHECK-NEXT:   [[elem0:%\d+]] = OpCompositeExtract %float [[mat3x1]] 0
+// CHECK-NEXT:    [[row0:%\d+]] = OpVectorTimesScalar %v4float [[mat1x4]] [[elem0]]
+// CHECK-NEXT:   [[elem1:%\d+]] = OpCompositeExtract %float [[mat3x1]] 1
+// CHECK-NEXT:    [[row1:%\d+]] = OpVectorTimesScalar %v4float [[mat1x4]] [[elem1]]
+// CHECK-NEXT:   [[elem2:%\d+]] = OpCompositeExtract %float [[mat3x1]] 2
+// CHECK-NEXT:    [[row2:%\d+]] = OpVectorTimesScalar %v4float [[mat1x4]] [[elem2]]
+// CHECK-NEXT: [[result3:%\d+]] = OpCompositeConstruct %mat3v4float [[row0]] [[row1]] [[row2]]
+// CHECK-NEXT:                    OpStore %result3 [[result3]]
+  float3x4   result3 = mul( mat3x1, mat1x4 ); // result is float3x4 matrix
 }

+ 12 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.transpose.hlsl

@@ -61,4 +61,16 @@ void main() {
 // CHECK-NEXT:                 OpStore %rt [[rt]]
   uint4x4 r;
   uint4x4 rt = transpose(r);
+
+// A 1-D matrix is in fact a vector, and its transpose is the vector itself.
+//
+// CHECK:      [[s:%\d+]] = OpLoad %v4float %s
+// CHECK-NEXT:              OpStore %st [[s]]
+  float1x4 s;
+  float4x1 st = transpose(s);
+
+// CHECK:      [[t:%\d+]] = OpLoad %float %t
+// CHECK-NEXT:              OpStore %tt [[t]]
+  float1x1 t;
+  float1x1 tt = transpose(t);
 }