|
@@ -27,6 +27,8 @@ namespace spirv {
|
|
|
|
|
|
namespace {
|
|
namespace {
|
|
|
|
|
|
|
|
+// TODO: Maybe we should move these type probing functions to TypeTranslator.
|
|
|
|
+
|
|
/// Returns true if the given type is a bool or vector of bool type.
|
|
/// Returns true if the given type is a bool or vector of bool type.
|
|
bool isBoolOrVecOfBoolType(QualType type) {
|
|
bool isBoolOrVecOfBoolType(QualType type) {
|
|
return type->isBooleanType() ||
|
|
return type->isBooleanType() ||
|
|
@@ -57,6 +59,36 @@ bool isFloatOrVecOfFloatType(QualType type) {
|
|
hlsl::GetHLSLVecElementType(type)->isFloatingType());
|
|
hlsl::GetHLSLVecElementType(type)->isFloatingType());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+/// Returns true if the given type is a bool or vector/matrix of bool type.
|
|
|
|
+bool isBoolOrVecMatOfBoolType(QualType type) {
|
|
|
|
+ return isBoolOrVecOfBoolType(type) ||
|
|
|
|
+ (hlsl::IsHLSLMatType(type) &&
|
|
|
|
+ hlsl::GetHLSLMatElementType(type)->isBooleanType());
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+/// Returns true if the given type is a signed integer or vector/matrix of
|
|
|
|
+/// signed integer type.
|
|
|
|
+bool isSintOrVecMatOfSintType(QualType type) {
|
|
|
|
+ return isSintOrVecOfSintType(type) ||
|
|
|
|
+ (hlsl::IsHLSLMatType(type) &&
|
|
|
|
+ hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+/// Returns true if the given type is an unsigned integer or vector/matrix of
|
|
|
|
+/// unsigned integer type.
|
|
|
|
+bool isUintOrVecMatOfUintType(QualType type) {
|
|
|
|
+ return isUintOrVecOfUintType(type) ||
|
|
|
|
+ (hlsl::IsHLSLMatType(type) &&
|
|
|
|
+ hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+/// Returns true if the given type is a float or vector/matrix of float type.
|
|
|
|
+bool isFloatOrVecMatOfFloatType(QualType type) {
|
|
|
|
+ return isFloatOrVecOfFloatType(type) ||
|
|
|
|
+ (hlsl::IsHLSLMatType(type) &&
|
|
|
|
+ hlsl::GetHLSLMatElementType(type)->isFloatingType());
|
|
|
|
+}
|
|
|
|
+
|
|
bool isCompoundAssignment(BinaryOperatorKind opcode) {
|
|
bool isCompoundAssignment(BinaryOperatorKind opcode) {
|
|
switch (opcode) {
|
|
switch (opcode) {
|
|
case BO_AddAssign:
|
|
case BO_AddAssign:
|
|
@@ -1039,6 +1071,72 @@ public:
|
|
return isCompoundAssignment ? lhsPtr : rhs;
|
|
return isCompoundAssignment ? lhsPtr : rhs;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /// Generates the necessary instructions for conducting the given binary
|
|
|
|
+ /// operation on lhs and rhs.
|
|
|
|
+ ///
|
|
|
|
+ /// This method expects that both lhs and rhs are SPIR-V acceptable matrices.
|
|
|
|
+ uint32_t processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
|
|
|
|
+ const BinaryOperatorKind opcode) {
|
|
|
|
+ // TODO: some code are duplicated from processBinaryOp. Try to unify them.
|
|
|
|
+ const auto lhsType = lhs->getType();
|
|
|
|
+ assert(TypeTranslator::isSpirvAcceptableMatrixType(lhsType));
|
|
|
|
+ const spv::Op spvOp = translateOp(opcode, lhsType);
|
|
|
|
+
|
|
|
|
+ uint32_t rhsVal, lhsPtr, lhsVal;
|
|
|
|
+ if (isCompoundAssignment(opcode)) {
|
|
|
|
+ // Evalute rhs before lhs
|
|
|
|
+ rhsVal = doExpr(rhs);
|
|
|
|
+ lhsPtr = doExpr(lhs);
|
|
|
|
+ const uint32_t lhsTy = typeTranslator.translateType(lhsType);
|
|
|
|
+ lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
|
|
|
|
+ } else {
|
|
|
|
+ // Evalute lhs before rhs
|
|
|
|
+ lhsVal = lhsPtr = doExpr(lhs);
|
|
|
|
+ rhsVal = doExpr(rhs);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ switch (opcode) {
|
|
|
|
+ case BO_Add:
|
|
|
|
+ case BO_Sub:
|
|
|
|
+ case BO_Mul:
|
|
|
|
+ case BO_Div:
|
|
|
|
+ case BO_Rem:
|
|
|
|
+ case BO_AddAssign:
|
|
|
|
+ case BO_SubAssign:
|
|
|
|
+ case BO_MulAssign:
|
|
|
|
+ case BO_DivAssign:
|
|
|
|
+ case BO_RemAssign: {
|
|
|
|
+ const uint32_t vecType = typeTranslator.getComponentVectorType(lhsType);
|
|
|
|
+
|
|
|
|
+ uint32_t rowCount = 0, colCount = 0;
|
|
|
|
+ hlsl::GetHLSLMatRowColCount(lhsType, rowCount, colCount);
|
|
|
|
+
|
|
|
|
+ llvm::SmallVector<uint32_t, 4> vectors;
|
|
|
|
+ // Extract each component vector and do operation on it
|
|
|
|
+ for (uint32_t i = 0; i < rowCount; ++i) {
|
|
|
|
+ const uint32_t lhsVec =
|
|
|
|
+ theBuilder.createCompositeExtract(vecType, lhsVal, {i});
|
|
|
|
+ const uint32_t rhsVec =
|
|
|
|
+ theBuilder.createCompositeExtract(vecType, rhsVal, {i});
|
|
|
|
+ vectors.push_back(
|
|
|
|
+ theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Construct the result matrix
|
|
|
|
+ return theBuilder.createCompositeConstruct(
|
|
|
|
+ typeTranslator.translateType(lhsType), vectors);
|
|
|
|
+ }
|
|
|
|
+ case BO_Assign:
|
|
|
|
+ llvm_unreachable("assignment should not be handled here");
|
|
|
|
+ default:
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ emitError("BinaryOperator '%0' for matrices not supported yet")
|
|
|
|
+ << BinaryOperator::getOpcodeStr(opcode);
|
|
|
|
+ return 0;
|
|
|
|
+ }
|
|
|
|
+
|
|
/// Generates the necessary instructions for conducting the given binary
|
|
/// Generates the necessary instructions for conducting the given binary
|
|
/// operation on lhs and rhs. If lhsResultId is not nullptr, the evaluated
|
|
/// operation on lhs and rhs. If lhsResultId is not nullptr, the evaluated
|
|
/// pointer from lhs during the process will be written into it. If
|
|
/// pointer from lhs during the process will be written into it. If
|
|
@@ -1049,6 +1147,11 @@ public:
|
|
const uint32_t resultType,
|
|
const uint32_t resultType,
|
|
uint32_t *lhsResultId = nullptr,
|
|
uint32_t *lhsResultId = nullptr,
|
|
const spv::Op mandateGenOpcode = spv::Op::Max) {
|
|
const spv::Op mandateGenOpcode = spv::Op::Max) {
|
|
|
|
+
|
|
|
|
+ if (TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
|
|
|
|
+ return processMatrixBinaryOp(lhs, rhs, opcode);
|
|
|
|
+ }
|
|
|
|
+
|
|
const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
|
|
const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
|
|
? translateOp(opcode, lhs->getType())
|
|
? translateOp(opcode, lhs->getType())
|
|
: mandateGenOpcode;
|
|
: mandateGenOpcode;
|
|
@@ -1484,6 +1587,28 @@ public:
|
|
assert(hlsl::GetHLSLVecSize(subExpr->getType()) == 1);
|
|
assert(hlsl::GetHLSLVecSize(subExpr->getType()) == 1);
|
|
return doExpr(subExpr);
|
|
return doExpr(subExpr);
|
|
}
|
|
}
|
|
|
|
+ case CastKind::CK_HLSLVectorToMatrixCast: {
|
|
|
|
+ // The target type should already be a 1xN matrix type.
|
|
|
|
+ assert(TypeTranslator::is1xNMatrixType(toType));
|
|
|
|
+ return doExpr(subExpr);
|
|
|
|
+ }
|
|
|
|
+ case CastKind::CK_HLSLMatrixSplat: {
|
|
|
|
+ if (TypeTranslator::is1x1MatrixType(toType))
|
|
|
|
+ return doExpr(subExpr);
|
|
|
|
+
|
|
|
|
+ emitError("matrix splatting not supported yet");
|
|
|
|
+ return 0;
|
|
|
|
+ }
|
|
|
|
+ case CastKind::CK_HLSLMatrixToScalarCast: {
|
|
|
|
+ // The underlying should already be a matrix of 1x1.
|
|
|
|
+ assert(TypeTranslator::is1x1MatrixType(subExpr->getType()));
|
|
|
|
+ return doExpr(subExpr);
|
|
|
|
+ }
|
|
|
|
+ case CastKind::CK_HLSLMatrixToVectorCast: {
|
|
|
|
+ // The underlying should already be a matrix of 1xN.
|
|
|
|
+ assert(TypeTranslator::is1xNMatrixType(subExpr->getType()));
|
|
|
|
+ return doExpr(subExpr);
|
|
|
|
+ }
|
|
case CastKind::CK_FunctionToPointerDecay:
|
|
case CastKind::CK_FunctionToPointerDecay:
|
|
// Just need to return the function id
|
|
// Just need to return the function id
|
|
return doExpr(subExpr);
|
|
return doExpr(subExpr);
|
|
@@ -1840,10 +1965,9 @@ public:
|
|
/// Translates the given frontend binary operator into its SPIR-V equivalent
|
|
/// Translates the given frontend binary operator into its SPIR-V equivalent
|
|
/// taking consideration of the operand type.
|
|
/// taking consideration of the operand type.
|
|
spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
|
|
spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
|
|
- // TODO: the following is not considering vector types yet.
|
|
|
|
- const bool isSintType = isSintOrVecOfSintType(type);
|
|
|
|
- const bool isUintType = isUintOrVecOfUintType(type);
|
|
|
|
- const bool isFloatType = isFloatOrVecOfFloatType(type);
|
|
|
|
|
|
+ const bool isSintType = isSintOrVecMatOfSintType(type);
|
|
|
|
+ const bool isUintType = isUintOrVecMatOfUintType(type);
|
|
|
|
+ const bool isFloatType = isFloatOrVecMatOfFloatType(type);
|
|
|
|
|
|
#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp) \
|
|
#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp) \
|
|
\
|
|
\
|