|
@@ -122,6 +122,18 @@ bool isCompoundAssignment(BinaryOperatorKind opcode) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+bool isSpirvMatrixOp(spv::Op opcode) {
|
|
|
|
+ switch (opcode) {
|
|
|
|
+ case spv::Op::OpMatrixTimesMatrix:
|
|
|
|
+ case spv::Op::OpMatrixTimesVector:
|
|
|
|
+ case spv::Op::OpMatrixTimesScalar:
|
|
|
|
+ return true;
|
|
|
|
+ default:
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ return false;
|
|
|
|
+}
|
|
|
|
+
|
|
} // namespace
|
|
} // namespace
|
|
|
|
|
|
SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
|
|
SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
|
|
@@ -731,8 +743,10 @@ uint32_t SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
|
|
if (opcode == BO_Assign)
|
|
if (opcode == BO_Assign)
|
|
return processAssignment(expr->getLHS(), doExpr(expr->getRHS()), false);
|
|
return processAssignment(expr->getLHS(), doExpr(expr->getRHS()), false);
|
|
|
|
|
|
- // Try to optimize floatN * float case
|
|
|
|
|
|
+ // Try to optimize floatMxN * float and floatN * float case
|
|
if (opcode == BO_Mul) {
|
|
if (opcode == BO_Mul) {
|
|
|
|
+ if (const uint32_t result = tryToGenFloatMatrixScale(expr))
|
|
|
|
+ return result;
|
|
if (const uint32_t result = tryToGenFloatVectorScale(expr))
|
|
if (const uint32_t result = tryToGenFloatVectorScale(expr))
|
|
return result;
|
|
return result;
|
|
}
|
|
}
|
|
@@ -940,8 +954,10 @@ uint32_t
|
|
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
|
|
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
|
|
const auto opcode = expr->getOpcode();
|
|
const auto opcode = expr->getOpcode();
|
|
|
|
|
|
- // Try to optimize floatN *= float case
|
|
|
|
|
|
+ // Try to optimize floatMxN *= float and floatN *= float case
|
|
if (opcode == BO_MulAssign) {
|
|
if (opcode == BO_MulAssign) {
|
|
|
|
+ if (const uint32_t result = tryToGenFloatMatrixScale(expr))
|
|
|
|
+ return result;
|
|
if (const uint32_t result = tryToGenFloatVectorScale(expr))
|
|
if (const uint32_t result = tryToGenFloatVectorScale(expr))
|
|
return result;
|
|
return result;
|
|
}
|
|
}
|
|
@@ -1370,7 +1386,11 @@ uint32_t SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
|
|
const uint32_t resultType,
|
|
const uint32_t resultType,
|
|
uint32_t *lhsResultId,
|
|
uint32_t *lhsResultId,
|
|
const spv::Op mandateGenOpcode) {
|
|
const spv::Op mandateGenOpcode) {
|
|
- if (TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
|
|
|
|
|
|
+ // If the operands are of matrix type, we need to dispatch the operation
|
|
|
|
+ // onto each element vector iff the operands are not degenerated matrices
|
|
|
|
+ // and we don't have a matrix specific SPIR-V instruction for the operation.
|
|
|
|
+ if (!isSpirvMatrixOp(mandateGenOpcode) &&
|
|
|
|
+ TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
|
|
return processMatrixBinaryOp(lhs, rhs, opcode);
|
|
return processMatrixBinaryOp(lhs, rhs, opcode);
|
|
}
|
|
}
|
|
|
|
|
|
@@ -1647,6 +1667,68 @@ uint32_t SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
|
|
return 0;
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+uint32_t SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
|
|
|
|
+ const QualType type = expr->getType();
|
|
|
|
+ // We can only translate floatMxN * float into OpMatrixTimesScalar.
|
|
|
|
+ // So the result type must be floatMxN.
|
|
|
|
+ if (!hlsl::IsHLSLMatType(type) ||
|
|
|
|
+ !hlsl::GetHLSLMatElementType(type)->isFloatingType())
|
|
|
|
+ return 0;
|
|
|
|
+
|
|
|
|
+ const Expr *lhs = expr->getLHS();
|
|
|
|
+ const Expr *rhs = expr->getRHS();
|
|
|
|
+ const QualType lhsType = lhs->getType();
|
|
|
|
+ const QualType rhsType = rhs->getType();
|
|
|
|
+
|
|
|
|
+ const auto selectOpcode = [](const QualType ty) {
|
|
|
|
+ return TypeTranslator::isMx1MatrixType(ty) ||
|
|
|
|
+ TypeTranslator::is1xNMatrixType(ty)
|
|
|
|
+ ? spv::Op::OpVectorTimesScalar
|
|
|
|
+ : spv::Op::OpMatrixTimesScalar;
|
|
|
|
+ };
|
|
|
|
+
|
|
|
|
+ // Multiplying a float matrix with a float scalar will be represented in
|
|
|
|
+ // AST via a binary operation with two float matrices as operands; one of
|
|
|
|
+ // the operand is from an implicit cast with kind CK_HLSLMatrixSplat.
|
|
|
|
+
|
|
|
|
+ // matrix * scalar
|
|
|
|
+ if (hlsl::IsHLSLMatType(lhsType)) {
|
|
|
|
+ if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
|
|
|
|
+ if (cast->getCastKind() == CK_HLSLMatrixSplat) {
|
|
|
|
+ const uint32_t matType = typeTranslator.translateType(expr->getType());
|
|
|
|
+ const spv::Op opcode = selectOpcode(lhsType);
|
|
|
|
+ if (isa<CompoundAssignOperator>(expr)) {
|
|
|
|
+ uint32_t lhsPtr = 0;
|
|
|
|
+ const uint32_t result =
|
|
|
|
+ processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
|
|
|
|
+ matType, &lhsPtr, opcode);
|
|
|
|
+ return processAssignment(lhs, result, true, lhsPtr);
|
|
|
|
+ } else {
|
|
|
|
+ return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
|
|
|
|
+ matType, nullptr, opcode);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // scalar * matrix
|
|
|
|
+ if (hlsl::IsHLSLMatType(rhsType)) {
|
|
|
|
+ if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
|
|
|
|
+ if (cast->getCastKind() == CK_HLSLMatrixSplat) {
|
|
|
|
+ const uint32_t matType = typeTranslator.translateType(expr->getType());
|
|
|
|
+ const spv::Op opcode = selectOpcode(rhsType);
|
|
|
|
+ // We need to switch the positions of lhs and rhs here because
|
|
|
|
+ // OpMatrixTimesScalar requires the first operand to be a matrix and
|
|
|
|
+ // the second to be a scalar.
|
|
|
|
+ return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
|
|
|
|
+ matType, nullptr, opcode);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return 0;
|
|
|
|
+}
|
|
|
|
+
|
|
uint32_t SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
|
|
uint32_t SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
|
|
const uint32_t rhs) {
|
|
const uint32_t rhs) {
|
|
// Assigning to a vector swizzling lhs is tricky if we are neither
|
|
// Assigning to a vector swizzling lhs is tricky if we are neither
|