|
@@ -2497,6 +2497,8 @@ uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
|
|
switch (static_cast<hlsl::IntrinsicOp>(opcode)) {
|
|
switch (static_cast<hlsl::IntrinsicOp>(opcode)) {
|
|
case hlsl::IntrinsicOp::IOP_dot:
|
|
case hlsl::IntrinsicOp::IOP_dot:
|
|
return processIntrinsicDot(callExpr);
|
|
return processIntrinsicDot(callExpr);
|
|
|
|
+ case hlsl::IntrinsicOp::IOP_mul:
|
|
|
|
+ return processIntrinsicMul(callExpr);
|
|
case hlsl::IntrinsicOp::IOP_all:
|
|
case hlsl::IntrinsicOp::IOP_all:
|
|
return processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
|
|
return processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
|
|
case hlsl::IntrinsicOp::IOP_any:
|
|
case hlsl::IntrinsicOp::IOP_any:
|
|
@@ -2550,6 +2552,139 @@ uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
|
|
return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, actOnEachVecInMat);
|
|
return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, actOnEachVecInMat);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+uint32_t SPIRVEmitter::processIntrinsicMul(const CallExpr *callExpr) {
|
|
|
|
+ const QualType returnType = callExpr->getType();
|
|
|
|
+ const uint32_t returnTypeId =
|
|
|
|
+ typeTranslator.translateType(callExpr->getType());
|
|
|
|
+
|
|
|
|
+ // Get the function parameters. Expect 2 parameters.
|
|
|
|
+ assert(callExpr->getNumArgs() == 2u);
|
|
|
|
+ const Expr *arg0 = callExpr->getArg(0);
|
|
|
|
+ const Expr *arg1 = callExpr->getArg(1);
|
|
|
|
+ const QualType arg0Type = arg0->getType();
|
|
|
|
+ const QualType arg1Type = arg1->getType();
|
|
|
|
+
|
|
|
|
+ // The HLSL mul() function takes 2 arguments. Each argument may be a scalar,
|
|
|
|
+ // vector, or matrix. The frontend ensures that the two arguments have the
|
|
|
|
+ // same component type. The only allowed component types are int and float.
|
|
|
|
+
|
|
|
|
+ // mul(scalar, vector)
|
|
|
|
+ {
|
|
|
|
+ uint32_t elemCount = 0;
|
|
|
|
+ if (TypeTranslator::isScalarType(arg0Type) &&
|
|
|
|
+ TypeTranslator::isVectorType(arg1Type, nullptr, &elemCount)) {
|
|
|
|
+
|
|
|
|
+ const uint32_t arg1Id = doExpr(arg1);
|
|
|
|
+
|
|
|
|
+ // We can use OpVectorTimesScalar if arguments are floats.
|
|
|
|
+ if (arg0Type->isFloatingType())
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
|
|
|
|
+ returnTypeId, arg1Id, doExpr(arg0));
|
|
|
|
+
|
|
|
|
+ // Use OpIMul for integers
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId,
|
|
|
|
+ createVectorSplat(arg0, elemCount),
|
|
|
|
+ arg1Id);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // mul(vector, scalar)
|
|
|
|
+ {
|
|
|
|
+ uint32_t elemCount = 0;
|
|
|
|
+ if (TypeTranslator::isVectorType(arg0Type, nullptr, &elemCount) &&
|
|
|
|
+ TypeTranslator::isScalarType(arg1Type)) {
|
|
|
|
+
|
|
|
|
+ const uint32_t arg0Id = doExpr(arg0);
|
|
|
|
+
|
|
|
|
+ // We can use OpVectorTimesScalar if arguments are floats.
|
|
|
|
+ if (arg1Type->isFloatingType())
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
|
|
|
|
+ returnTypeId, arg0Id, doExpr(arg1));
|
|
|
|
+
|
|
|
|
+ // Use OpIMul for integers
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId, arg0Id,
|
|
|
|
+ createVectorSplat(arg1, elemCount));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // mul(vector, vector)
|
|
|
|
+ if (TypeTranslator::isVectorType(arg0Type) &&
|
|
|
|
+ TypeTranslator::isVectorType(arg1Type))
|
|
|
|
+ return processIntrinsicDot(callExpr);
|
|
|
|
+
|
|
|
|
+ // All the following cases require handling arg0 and arg1 expressions first.
|
|
|
|
+ const uint32_t arg0Id = doExpr(arg0);
|
|
|
|
+ const uint32_t arg1Id = doExpr(arg1);
|
|
|
|
+
|
|
|
|
+ // mul(scalar, scalar)
|
|
|
|
+ if (TypeTranslator::isScalarType(arg0Type) &&
|
|
|
|
+ TypeTranslator::isScalarType(arg1Type))
|
|
|
|
+ return theBuilder.createBinaryOp(translateOp(BO_Mul, arg0Type),
|
|
|
|
+ returnTypeId, arg0Id, arg1Id);
|
|
|
|
+
|
|
|
|
+ // mul(scalar, matrix)
|
|
|
|
+ if (TypeTranslator::isScalarType(arg0Type) &&
|
|
|
|
+ TypeTranslator::isMxNMatrix(arg1Type)) {
|
|
|
|
+ // We currently only support float matrices. So we can use
|
|
|
|
+ // OpMatrixTimesScalar
|
|
|
|
+ if (arg0Type->isFloatingType())
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
|
|
|
|
+ returnTypeId, arg1Id, arg0Id);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // mul(matrix, scalar)
|
|
|
|
+ if (TypeTranslator::isScalarType(arg1Type) &&
|
|
|
|
+ TypeTranslator::isMxNMatrix(arg0Type)) {
|
|
|
|
+ // We currently only support float matrices. So we can use
|
|
|
|
+ // OpMatrixTimesScalar
|
|
|
|
+ if (arg1Type->isFloatingType())
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
|
|
|
|
+ returnTypeId, arg0Id, arg1Id);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // mul(vector, matrix)
|
|
|
|
+ {
|
|
|
|
+ QualType elemType = {};
|
|
|
|
+ uint32_t elemCount = 0, numRows = 0;
|
|
|
|
+ if (TypeTranslator::isVectorType(arg0Type, &elemType, &elemCount) &&
|
|
|
|
+ TypeTranslator::isMxNMatrix(arg1Type, nullptr, &numRows, nullptr) &&
|
|
|
|
+ elemType->isFloatingType()) {
|
|
|
|
+ assert(elemCount == numRows);
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesVector,
|
|
|
|
+ returnTypeId, arg1Id, arg0Id);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // mul(matrix, vector)
|
|
|
|
+ {
|
|
|
|
+ QualType elemType = {};
|
|
|
|
+ uint32_t elemCount = 0, numCols = 0;
|
|
|
|
+ if (TypeTranslator::isMxNMatrix(arg0Type, nullptr, nullptr, &numCols) &&
|
|
|
|
+ TypeTranslator::isVectorType(arg1Type, &elemType, &elemCount) &&
|
|
|
|
+ elemType->isFloatingType()) {
|
|
|
|
+ assert(elemCount == numCols);
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpVectorTimesMatrix,
|
|
|
|
+ returnTypeId, arg1Id, arg0Id);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // mul(matrix, matrix)
|
|
|
|
+ {
|
|
|
|
+ QualType elemType = {};
|
|
|
|
+ uint32_t arg0Cols = 0, arg1Rows = 0;
|
|
|
|
+ if (TypeTranslator::isMxNMatrix(arg0Type, &elemType, nullptr, &arg0Cols) &&
|
|
|
|
+ TypeTranslator::isMxNMatrix(arg1Type, nullptr, &arg1Rows, nullptr) &&
|
|
|
|
+ elemType->isFloatingType()) {
|
|
|
|
+ assert(arg0Cols == arg1Rows);
|
|
|
|
+ return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesMatrix,
|
|
|
|
+ returnTypeId, arg1Id, arg0Id);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ emitError("Unsupported arguments passed to mul() function.");
|
|
|
|
+ return 0;
|
|
|
|
+}
|
|
|
|
+
|
|
uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
|
|
uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
|
|
const QualType returnType = callExpr->getType();
|
|
const QualType returnType = callExpr->getType();
|
|
const uint32_t returnTypeId =
|
|
const uint32_t returnTypeId =
|