|
@@ -20,6 +20,7 @@
|
|
|
#include "clang/SPIRV/TypeTranslator.h"
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
+#include "llvm/ADT/StringExtras.h"
|
|
|
|
|
|
namespace clang {
|
|
|
namespace spirv {
|
|
@@ -325,7 +326,7 @@ public:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- void doStmt(const Stmt *stmt) {
|
|
|
+ void doStmt(const Stmt *stmt, llvm::ArrayRef<const Attr *> attrs = {}) {
|
|
|
if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
|
|
|
for (auto *st : compoundStmt->body())
|
|
|
doStmt(st);
|
|
@@ -337,6 +338,14 @@ public:
|
|
|
}
|
|
|
} else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
|
|
|
doIfStmt(ifStmt);
|
|
|
+ } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
|
|
|
+ doSwitchStmt(switchStmt, attrs);
|
|
|
+ } else if (const auto *caseStmt = dyn_cast<CaseStmt>(stmt)) {
|
|
|
+ processCaseStmtOrDefaultStmt(stmt);
|
|
|
+ } else if (const auto *defaultStmt = dyn_cast<DefaultStmt>(stmt)) {
|
|
|
+ processCaseStmtOrDefaultStmt(stmt);
|
|
|
+ } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
|
|
|
+ doBreakStmt(breakStmt);
|
|
|
} else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
|
|
|
doForStmt(forStmt);
|
|
|
} else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
|
|
@@ -344,6 +353,8 @@ public:
|
|
|
} else if (const auto *expr = dyn_cast<Expr>(stmt)) {
|
|
|
// All cases for expressions used as statements
|
|
|
doExpr(expr);
|
|
|
+ } else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
|
|
|
+ doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
|
|
|
} else {
|
|
|
emitError("Stmt '%0' is not supported yet.") << stmt->getStmtClassName();
|
|
|
}
|
|
@@ -411,6 +422,225 @@ public:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// \brief Returns true iff *all* the case values in the given switch
|
|
|
+ /// statement are integer literals. In such cases OpSwitch can be used to
|
|
|
+ /// represent the switch statement.
|
|
|
+ /// We only care about the case values to be compared with the selector. They
|
|
|
+ /// may appear in the top level CaseStmt or be nested in a CompoundStmt.Fall
|
|
|
+ /// through cases will result in the second situation.
|
|
|
+ bool allSwitchCasesAreIntegerLiterals(const Stmt *root) {
|
|
|
+ if (!root)
|
|
|
+ return false;
|
|
|
+
|
|
|
+ const auto *caseStmt = dyn_cast<CaseStmt>(root);
|
|
|
+ const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
|
|
|
+ if (!caseStmt && !compoundStmt)
|
|
|
+ return true;
|
|
|
+
|
|
|
+ if (caseStmt) {
|
|
|
+ const Expr *caseExpr = caseStmt->getLHS();
|
|
|
+ return caseExpr && caseExpr->isEvaluatable(astContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Recurse down if facing a compound statement.
|
|
|
+ for (auto *st : compoundStmt->body())
|
|
|
+ if (!allSwitchCasesAreIntegerLiterals(st))
|
|
|
+ return false;
|
|
|
+
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ /// \brief Recursively discovers all CaseStmt and DefaultStmt under the
|
|
|
+ /// sub-tree of the given root. Recursively goes down the tree iff it finds a
|
|
|
+ /// CaseStmt, DefaultStmt, or CompoundStmt. It does not recurse on other
|
|
|
+ /// statement types. For each discovered case, a basic block is created and
|
|
|
+ /// registered within the module, and added as a successor to the current
|
|
|
+ /// active basic block.
|
|
|
+ ///
|
|
|
+ /// Writes a vector of (integer, basic block label) pairs for all cases to the
|
|
|
+ /// given 'targets' argument. If a DefaultStmt is found, it also returns the
|
|
|
+ /// label for the default basic block through the defaultBB parameter. This
|
|
|
+ /// method panics if it finds a case value that is not an integer literal.
|
|
|
+ void discoverAllCaseStmtInSwitchStmt(
|
|
|
+ const Stmt *root, uint32_t *defaultBB,
|
|
|
+ std::vector<std::pair<uint32_t, uint32_t>> *targets) {
|
|
|
+ if (!root)
|
|
|
+ return;
|
|
|
+
|
|
|
+ // A switch case can only appear in DefaultStmt, CaseStmt, or
|
|
|
+ // CompoundStmt. For the rest, we can just return.
|
|
|
+ const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
|
|
|
+ const auto *caseStmt = dyn_cast<CaseStmt>(root);
|
|
|
+ const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
|
|
|
+ if (!defaultStmt && !caseStmt && !compoundStmt)
|
|
|
+ return;
|
|
|
+
|
|
|
+ // Recurse down if facing a compound statement.
|
|
|
+ if (compoundStmt) {
|
|
|
+ for (auto *st : compoundStmt->body())
|
|
|
+ discoverAllCaseStmtInSwitchStmt(st, defaultBB, targets);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ std::string caseLabel;
|
|
|
+ uint32_t caseValue = 0;
|
|
|
+ if (defaultStmt) {
|
|
|
+ // This is the default branch.
|
|
|
+ caseLabel = "switch.default";
|
|
|
+ } else if (caseStmt) {
|
|
|
+ // This is a non-default case.
|
|
|
+ // When using OpSwitch, we only allow integer literal cases. e.g:
|
|
|
+ // case <literal_integer>: {...; break;}
|
|
|
+ const Expr *caseExpr = caseStmt->getLHS();
|
|
|
+ assert(caseExpr && caseExpr->isEvaluatable(astContext));
|
|
|
+ auto bitWidth = astContext.getIntWidth(caseExpr->getType());
|
|
|
+ if (bitWidth != 32)
|
|
|
+ emitError("Switch statement translation currently only supports 32-bit "
|
|
|
+ "integer case values.");
|
|
|
+ Expr::EvalResult evalResult;
|
|
|
+ caseExpr->EvaluateAsRValue(evalResult, astContext);
|
|
|
+ const int64_t value = evalResult.Val.getInt().getSExtValue();
|
|
|
+ caseValue = static_cast<uint32_t>(value);
|
|
|
+ caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
|
|
|
+ llvm::itostr(std::abs(value));
|
|
|
+ }
|
|
|
+ const uint32_t caseBB = theBuilder.createBasicBlock(caseLabel);
|
|
|
+ theBuilder.addSuccessor(caseBB);
|
|
|
+ stmtBasicBlock[root] = caseBB;
|
|
|
+
|
|
|
+ // Add all cases to the 'targets' vector.
|
|
|
+ if (caseStmt)
|
|
|
+ targets->emplace_back(caseValue, caseBB);
|
|
|
+
|
|
|
+ // The default label is not part of the 'targets' vector that is passed
|
|
|
+ // to the OpSwitch instruction.
|
|
|
+ // If default statement was discovered, return its label via defaultBB.
|
|
|
+ if (defaultStmt)
|
|
|
+ *defaultBB = caseBB;
|
|
|
+
|
|
|
+ // Process cases nested in other cases. It happens when we have fall through
|
|
|
+ // cases. For example:
|
|
|
+ // case 1: case 2: ...; break;
|
|
|
+ // will result in the CaseSmt for case 2 nested in the one for case 1.
|
|
|
+ discoverAllCaseStmtInSwitchStmt(caseStmt ? caseStmt->getSubStmt()
|
|
|
+ : defaultStmt->getSubStmt(),
|
|
|
+ defaultBB, targets);
|
|
|
+ }
|
|
|
+
|
|
|
+ void processSwitchStmtUsingSpirvOpSwitch(const SwitchStmt *switchStmt) {
|
|
|
+ // First handle the condition variable DeclStmt if one exists.
|
|
|
+ // For example: handle 'int a = b' in the following:
|
|
|
+ // switch (int a = b) {...}
|
|
|
+ const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt();
|
|
|
+ if (condVarDeclStmt)
|
|
|
+ doStmt(condVarDeclStmt);
|
|
|
+
|
|
|
+ const uint32_t selector = doExpr(switchStmt->getCond());
|
|
|
+
|
|
|
+ // We need a merge block regardless of the number of switch cases.
|
|
|
+ // Since OpSwitch always requires a default label, if the switch statement
|
|
|
+ // does not have a default branch, we use the merge block as the default
|
|
|
+ // target.
|
|
|
+ const uint32_t mergeBB = theBuilder.createBasicBlock("switch.merge");
|
|
|
+ theBuilder.setMergeTarget(mergeBB);
|
|
|
+ breakStack.push(mergeBB);
|
|
|
+ uint32_t defaultBB = mergeBB;
|
|
|
+
|
|
|
+ // (literal, labelId) pairs to pass to the OpSwitch instruction.
|
|
|
+ std::vector<std::pair<uint32_t, uint32_t>> targets;
|
|
|
+ discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB,
|
|
|
+ &targets);
|
|
|
+
|
|
|
+ // Create the OpSelectionMerge and OpSwitch.
|
|
|
+ theBuilder.createSwitch(mergeBB, selector, defaultBB, targets);
|
|
|
+
|
|
|
+ // Handle the switch body.
|
|
|
+ doStmt(switchStmt->getBody());
|
|
|
+
|
|
|
+ if (!theBuilder.isCurrentBasicBlockTerminated())
|
|
|
+ theBuilder.createBranch(mergeBB);
|
|
|
+ theBuilder.setInsertPoint(mergeBB);
|
|
|
+ breakStack.pop();
|
|
|
+ }
|
|
|
+
|
|
|
+ void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
|
|
|
+ emitError("Translating Switch statements using If statements is not "
|
|
|
+ "implemented yet.");
|
|
|
+ }
|
|
|
+
|
|
|
+ void doSwitchStmt(const SwitchStmt *switchStmt,
|
|
|
+ llvm::ArrayRef<const Attr *> attrs = {}) {
|
|
|
+ // Switch statements are composed of:
|
|
|
+ // switch (<condition variable>) {
|
|
|
+ // <CaseStmt>
|
|
|
+ // <CaseStmt>
|
|
|
+ // <CaseStmt>
|
|
|
+ // <DefaultStmt> (optional)
|
|
|
+ // }
|
|
|
+ //
|
|
|
+ // +-------+
|
|
|
+ // | check |
|
|
|
+ // +-------+
|
|
|
+ // |
|
|
|
+ // +-------+-------+----------------+---------------+
|
|
|
+ // | 1 | 2 | 3 | (others)
|
|
|
+ // v v v v
|
|
|
+ // +-------+ +-------------+ +-------+ +------------+
|
|
|
+ // | case1 | | case2 | | case3 | ... | default |
|
|
|
+ // | | |(fallthrough)|---->| | | (optional) |
|
|
|
+ // +-------+ |+------------+ +-------+ +------------+
|
|
|
+ // | | |
|
|
|
+ // | | |
|
|
|
+ // | +-------+ | |
|
|
|
+ // | | | <--------------------+ |
|
|
|
+ // +-> | merge | |
|
|
|
+ // | | <-------------------------------------+
|
|
|
+ // +-------+
|
|
|
+
|
|
|
+ // If no attributes are given, or if "forcecase" attribute was provided,
|
|
|
+ // we'll do our best to use OpSwitch if possible.
|
|
|
+ // If any of the cases compares to a variable (rather than an integer
|
|
|
+ // literal), we cannot use OpSwitch because OpSwitch expects literal
|
|
|
+ // numbers as parameters.
|
|
|
+ const bool isAttrForceCase =
|
|
|
+ !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase;
|
|
|
+ const bool canUseSpirvOpSwitch =
|
|
|
+ (attrs.empty() || isAttrForceCase) &&
|
|
|
+ allSwitchCasesAreIntegerLiterals(switchStmt->getBody());
|
|
|
+
|
|
|
+ if (isAttrForceCase && !canUseSpirvOpSwitch)
|
|
|
+ emitWarning("Ignored 'forcecase' attribute for the switch statement "
|
|
|
+ "since one or more case values are not integer literals.");
|
|
|
+
|
|
|
+ if (canUseSpirvOpSwitch)
|
|
|
+ processSwitchStmtUsingSpirvOpSwitch(switchStmt);
|
|
|
+ else
|
|
|
+ processSwitchStmtUsingIfStmts(switchStmt);
|
|
|
+ }
|
|
|
+
|
|
|
+ void processCaseStmtOrDefaultStmt(const Stmt *stmt) {
|
|
|
+ auto *caseStmt = dyn_cast<CaseStmt>(stmt);
|
|
|
+ auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
|
|
|
+ assert(caseStmt || defaultStmt);
|
|
|
+
|
|
|
+ uint32_t caseBB = stmtBasicBlock[stmt];
|
|
|
+ if (!theBuilder.isCurrentBasicBlockTerminated()) {
|
|
|
+ // We are about to handle the case passed in as parameter. If the current
|
|
|
+ // basic block is not terminated, it means the previous case is a fall
|
|
|
+ // through case. We need to link it to the case to be processed.
|
|
|
+ theBuilder.createBranch(caseBB);
|
|
|
+ theBuilder.addSuccessor(caseBB);
|
|
|
+ }
|
|
|
+ theBuilder.setInsertPoint(caseBB);
|
|
|
+ doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
|
|
|
+ }
|
|
|
+
|
|
|
+ void doBreakStmt(const BreakStmt *breakStmt) {
|
|
|
+ uint32_t breakTargetBB = breakStack.top();
|
|
|
+ theBuilder.addSuccessor(breakTargetBB);
|
|
|
+ theBuilder.createBranch(breakTargetBB);
|
|
|
+ }
|
|
|
+
|
|
|
void doIfStmt(const IfStmt *ifStmt) {
|
|
|
// if statements are composed of:
|
|
|
// if (<check>) { <then> } else { <else> }
|
|
@@ -1862,6 +2092,24 @@ private:
|
|
|
uint32_t entryFunctionId;
|
|
|
/// The current function under traversal.
|
|
|
const FunctionDecl *curFunction;
|
|
|
+
|
|
|
+ /// For loops, while loops, and switch statements may encounter "break"
|
|
|
+ /// statements that alter their control flow. At any point the break statement
|
|
|
+ /// is observed, the control flow jumps to the inner-most scope's merge block.
|
|
|
+ /// For instance: the break in the following example should cause a branch to
|
|
|
+ /// the SwitchMergeBB, not ForLoopMergeBB:
|
|
|
+ /// for (...) {
|
|
|
+ /// switch(...) {
|
|
|
+ /// case 1: break;
|
|
|
+ /// }
|
|
|
+ /// <--- SwitchMergeBB ---->
|
|
|
+ /// }
|
|
|
+ /// <----- ForLoopMergeBB --->
|
|
|
+ /// This stack keeps track of the basic blocks to which branching could occur.
|
|
|
+ std::stack<uint32_t> breakStack;
|
|
|
+
|
|
|
+ /// Maps a given statement to the basic block that is associated with it.
|
|
|
+ llvm::DenseMap<const Stmt *, uint32_t> stmtBasicBlock;
|
|
|
};
|
|
|
|
|
|
} // end namespace spirv
|