||
- // Copyright (c) 2018 Google LLC.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- #include "source/opt/scalar_analysis.h"
- #include <functional>
- #include <string>
- #include <utility>
- #include "source/opt/ir_context.h"
// Transforms a given scalar operation instruction into a DAG representation.
//
// 1. Take an instruction and traverse its operands until we reach a
// constant node or an instruction whose value we do not know how to compute,
// such as a load.
//
// 2. Create a new node for each instruction traversed and build the nodes for
// the in operands of that instruction as well.
//
// 3. Add the operand nodes as children of the newly created node and hash the
// node. Use the hash to see if the node is already in the cache. We ensure the
// children are always in sorted order so that two nodes with the same children
// but inserted in a different order have the same hash and so that the
// overloaded operator== will return true. If the node is already in the cache
// return the cached version instead.
//
// 4. The created DAG can then be simplified by
// ScalarEvolutionAnalysis::SimplifyExpression, implemented in
// scalar_analysis_simplification.cpp. See that file for further information on
// the simplification process.
//
- namespace spvtools {
- namespace opt {
- uint32_t SENode::NumberOfNodes = 0;
- ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
- : context_(context), pretend_equal_{} {
- // Create and cached the CantComputeNode.
- cached_cant_compute_ =
- GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this)));
- }
- SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
- // If operand is can't compute then the whole graph is can't compute.
- if (operand->IsCantCompute()) return CreateCantComputeNode();
- if (operand->GetType() == SENode::Constant) {
- return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue());
- }
- std::unique_ptr<SENode> negation_node{new SENegative(this)};
- negation_node->AddChild(operand);
- return GetCachedOrAdd(std::move(negation_node));
- }
- SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
- return GetCachedOrAdd(
- std::unique_ptr<SENode>(new SEConstantNode(this, integer)));
- }
- SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
- const Loop* loop, SENode* offset, SENode* coefficient) {
- assert(loop && "Recurrent add expressions must have a valid loop.");
- // If operands are can't compute then the whole graph is can't compute.
- if (offset->IsCantCompute() || coefficient->IsCantCompute())
- return CreateCantComputeNode();
- const Loop* loop_to_use = nullptr;
- if (pretend_equal_[loop]) {
- loop_to_use = pretend_equal_[loop];
- } else {
- loop_to_use = loop;
- }
- std::unique_ptr<SERecurrentNode> phi_node{
- new SERecurrentNode(this, loop_to_use)};
- phi_node->AddOffset(offset);
- phi_node->AddCoefficient(coefficient);
- return GetCachedOrAdd(std::move(phi_node));
- }
- SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
- const Instruction* multiply) {
- assert(multiply->opcode() == spv::Op::OpIMul &&
- "Multiply node did not come from a multiply instruction");
- analysis::DefUseManager* def_use = context_->get_def_use_mgr();
- SENode* op1 =
- AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0)));
- SENode* op2 =
- AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1)));
- return CreateMultiplyNode(op1, op2);
- }
- SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
- SENode* operand_2) {
- // If operands are can't compute then the whole graph is can't compute.
- if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
- return CreateCantComputeNode();
- if (operand_1->GetType() == SENode::Constant &&
- operand_2->GetType() == SENode::Constant) {
- return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() *
- operand_2->AsSEConstantNode()->FoldToSingleValue());
- }
- std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};
- multiply_node->AddChild(operand_1);
- multiply_node->AddChild(operand_2);
- return GetCachedOrAdd(std::move(multiply_node));
- }
- SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
- SENode* operand_2) {
- // Fold if both operands are constant.
- if (operand_1->GetType() == SENode::Constant &&
- operand_2->GetType() == SENode::Constant) {
- return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
- operand_2->AsSEConstantNode()->FoldToSingleValue());
- }
- return CreateAddNode(operand_1, CreateNegation(operand_2));
- }
- SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
- SENode* operand_2) {
- // Fold if both operands are constant and the |simplify| flag is true.
- if (operand_1->GetType() == SENode::Constant &&
- operand_2->GetType() == SENode::Constant) {
- return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
- operand_2->AsSEConstantNode()->FoldToSingleValue());
- }
- // If operands are can't compute then the whole graph is can't compute.
- if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
- return CreateCantComputeNode();
- std::unique_ptr<SENode> add_node{new SEAddNode(this)};
- add_node->AddChild(operand_1);
- add_node->AddChild(operand_2);
- return GetCachedOrAdd(std::move(add_node));
- }
- SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
- auto itr = recurrent_node_map_.find(inst);
- if (itr != recurrent_node_map_.end()) return itr->second;
- SENode* output = nullptr;
- switch (inst->opcode()) {
- case spv::Op::OpPhi: {
- output = AnalyzePhiInstruction(inst);
- break;
- }
- case spv::Op::OpConstant:
- case spv::Op::OpConstantNull: {
- output = AnalyzeConstant(inst);
- break;
- }
- case spv::Op::OpISub:
- case spv::Op::OpIAdd: {
- output = AnalyzeAddOp(inst);
- break;
- }
- case spv::Op::OpIMul: {
- output = AnalyzeMultiplyOp(inst);
- break;
- }
- default: {
- output = CreateValueUnknownNode(inst);
- break;
- }
- }
- return output;
- }
- SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
- if (inst->opcode() == spv::Op::OpConstantNull) return CreateConstant(0);
- assert(inst->opcode() == spv::Op::OpConstant);
- assert(inst->NumInOperands() == 1);
- int64_t value = 0;
- // Look up the instruction in the constant manager.
- const analysis::Constant* constant =
- context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
- if (!constant) return CreateCantComputeNode();
- const analysis::IntConstant* int_constant = constant->AsIntConstant();
- // Exit out if it is a 64 bit integer.
- if (!int_constant || int_constant->words().size() != 1)
- return CreateCantComputeNode();
- if (int_constant->type()->AsInteger()->IsSigned()) {
- value = int_constant->GetS32BitValue();
- } else {
- value = int_constant->GetU32BitValue();
- }
- return CreateConstant(value);
- }
- // Handles both addition and subtraction. If the |sub| flag is set then the
- // addition will be op1+(-op2) otherwise op1+op2.
- SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
- assert((inst->opcode() == spv::Op::OpIAdd ||
- inst->opcode() == spv::Op::OpISub) &&
- "Add node must be created from a OpIAdd or OpISub instruction");
- analysis::DefUseManager* def_use = context_->get_def_use_mgr();
- SENode* op1 =
- AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
- SENode* op2 =
- AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
- // To handle subtraction we wrap the second operand in a unary negation node.
- if (inst->opcode() == spv::Op::OpISub) {
- op2 = CreateNegation(op2);
- }
- return CreateAddNode(op1, op2);
- }
- SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
- // The phi should only have two incoming value pairs.
- if (phi->NumInOperands() != 4) {
- return CreateCantComputeNode();
- }
- analysis::DefUseManager* def_use = context_->get_def_use_mgr();
- // Get the basic block this instruction belongs to.
- BasicBlock* basic_block =
- context_->get_instr_block(const_cast<Instruction*>(phi));
- // And then the function that the basic blocks belongs to.
- Function* function = basic_block->GetParent();
- // Use the function to get the loop descriptor.
- LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);
- // We only handle phis in loops at the moment.
- if (!loop_descriptor) return CreateCantComputeNode();
- // Get the innermost loop which this block belongs to.
- Loop* loop = (*loop_descriptor)[basic_block->id()];
- // If the loop doesn't exist or doesn't have a preheader or latch block, exit
- // out.
- if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
- loop->GetHeaderBlock() != basic_block)
- return recurrent_node_map_[phi] = CreateCantComputeNode();
- const Loop* loop_to_use = nullptr;
- if (pretend_equal_[loop]) {
- loop_to_use = pretend_equal_[loop];
- } else {
- loop_to_use = loop;
- }
- std::unique_ptr<SERecurrentNode> phi_node{
- new SERecurrentNode(this, loop_to_use)};
- // We add the node to this map to allow it to be returned before the node is
- // fully built. This is needed as the subsequent call to AnalyzeInstruction
- // could lead back to this |phi| instruction so we return the pointer
- // immediately in AnalyzeInstruction to break the recursion.
- recurrent_node_map_[phi] = phi_node.get();
- // Traverse the operands of the instruction an create new nodes for each one.
- for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
- uint32_t value_id = phi->GetSingleWordInOperand(i);
- uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);
- Instruction* value_inst = def_use->GetDef(value_id);
- SENode* value_node = AnalyzeInstruction(value_inst);
- // If any operand is CantCompute then the whole graph is CantCompute.
- if (value_node->IsCantCompute())
- return recurrent_node_map_[phi] = CreateCantComputeNode();
- // If the value is coming from the preheader block then the value is the
- // initial value of the phi.
- if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
- phi_node->AddOffset(value_node);
- } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
- // Assumed to be in the form of step + phi.
- if (value_node->GetType() != SENode::Add)
- return recurrent_node_map_[phi] = CreateCantComputeNode();
- SENode* step_node = nullptr;
- SENode* phi_operand = nullptr;
- SENode* operand_1 = value_node->GetChild(0);
- SENode* operand_2 = value_node->GetChild(1);
- // Find which node is the step term.
- if (!operand_1->AsSERecurrentNode())
- step_node = operand_1;
- else if (!operand_2->AsSERecurrentNode())
- step_node = operand_2;
- // Find which node is the recurrent expression.
- if (operand_1->AsSERecurrentNode())
- phi_operand = operand_1;
- else if (operand_2->AsSERecurrentNode())
- phi_operand = operand_2;
- // If it is not in the form step + phi exit out.
- if (!(step_node && phi_operand))
- return recurrent_node_map_[phi] = CreateCantComputeNode();
- // If the phi operand is not the same phi node exit out.
- if (phi_operand != phi_node.get())
- return recurrent_node_map_[phi] = CreateCantComputeNode();
- if (!IsLoopInvariant(loop, step_node))
- return recurrent_node_map_[phi] = CreateCantComputeNode();
- phi_node->AddCoefficient(step_node);
- }
- }
- // Once the node is fully built we update the map with the version from the
- // cache (if it has already been added to the cache).
- return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
- }
- SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
- const Instruction* inst) {
- std::unique_ptr<SEValueUnknown> load_node{
- new SEValueUnknown(this, inst->result_id())};
- return GetCachedOrAdd(std::move(load_node));
- }
- SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
- return cached_cant_compute_;
- }
- // Add the created node into the cache of nodes. If it already exists return it.
- SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
- std::unique_ptr<SENode> prospective_node) {
- auto itr = node_cache_.find(prospective_node);
- if (itr != node_cache_.end()) {
- return (*itr).get();
- }
- SENode* raw_ptr_to_node = prospective_node.get();
- node_cache_.insert(std::move(prospective_node));
- return raw_ptr_to_node;
- }
- bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
- const SENode* node) const {
- for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
- if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
- const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
- // If the loop which the recurrent expression belongs to is either |loop
- // or a nested loop inside |loop| then we assume it is variant.
- if (loop->IsInsideLoop(header)) {
- return false;
- }
- } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
- // If the instruction is inside the loop we conservatively assume it is
- // loop variant.
- if (loop->IsInsideLoop(unknown->ResultId())) return false;
- }
- }
- return true;
- }
- SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
- SENode* node, const Loop* loop) {
- // Traverse the DAG to find the recurrent expression belonging to |loop|.
- for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
- SERecurrentNode* rec = itr->AsSERecurrentNode();
- if (rec && rec->GetLoop() == loop) {
- return rec->GetCoefficient();
- }
- }
- return CreateConstant(0);
- }
- SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
- SENode* old_child,
- SENode* new_child) {
- // Only handles add.
- if (parent->GetType() != SENode::Add) return parent;
- std::vector<SENode*> new_children;
- for (SENode* child : *parent) {
- if (child == old_child) {
- new_children.push_back(new_child);
- } else {
- new_children.push_back(child);
- }
- }
- std::unique_ptr<SENode> add_node{new SEAddNode(this)};
- for (SENode* child : new_children) {
- add_node->AddChild(child);
- }
- return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
- }
- // Rebuild the |node| eliminating, if it exists, the recurrent term which
- // belongs to the |loop|.
- SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
- SENode* node, const Loop* loop) {
- // If the node is already a recurrent expression belonging to loop then just
- // return the offset.
- SERecurrentNode* recurrent = node->AsSERecurrentNode();
- if (recurrent) {
- if (recurrent->GetLoop() == loop) {
- return recurrent->GetOffset();
- } else {
- return node;
- }
- }
- std::vector<SENode*> new_children;
- // Otherwise find the recurrent node in the children of this node.
- for (auto itr : *node) {
- recurrent = itr->AsSERecurrentNode();
- if (recurrent && recurrent->GetLoop() == loop) {
- new_children.push_back(recurrent->GetOffset());
- } else {
- new_children.push_back(itr);
- }
- }
- std::unique_ptr<SENode> add_node{new SEAddNode(this)};
- for (SENode* child : new_children) {
- add_node->AddChild(child);
- }
- return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
- }
- // Return the recurrent term belonging to |loop| if it appears in the graph
- // starting at |node| or null if it doesn't.
- SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
- const Loop* loop) {
- for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
- SERecurrentNode* rec = itr->AsSERecurrentNode();
- if (rec && rec->GetLoop() == loop) {
- return rec;
- }
- }
- return nullptr;
- }
- std::string SENode::AsString() const {
- switch (GetType()) {
- case Constant:
- return "Constant";
- case RecurrentAddExpr:
- return "RecurrentAddExpr";
- case Add:
- return "Add";
- case Negative:
- return "Negative";
- case Multiply:
- return "Multiply";
- case ValueUnknown:
- return "Value Unknown";
- case CanNotCompute:
- return "Can not compute";
- }
- return "NULL";
- }
- bool SENode::operator==(const SENode& other) const {
- if (GetType() != other.GetType()) return false;
- if (other.GetChildren().size() != children_.size()) return false;
- const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
- // Check the children are the same, for SERecurrentNodes we need to check the
- // offset and coefficient manually as the child vector is sorted by ids so the
- // offset/coefficient information is lost.
- if (!this_as_recurrent) {
- for (size_t index = 0; index < children_.size(); ++index) {
- if (other.GetChildren()[index] != children_[index]) return false;
- }
- } else {
- const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();
- // We've already checked the types are the same, this should not fail if
- // this->AsSERecurrentNode() succeeded.
- assert(other_as_recurrent);
- if (this_as_recurrent->GetCoefficient() !=
- other_as_recurrent->GetCoefficient())
- return false;
- if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset())
- return false;
- if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop())
- return false;
- }
- // If we're dealing with a value unknown node check both nodes were created by
- // the same instruction.
- if (GetType() == SENode::ValueUnknown) {
- if (AsSEValueUnknown()->ResultId() !=
- other.AsSEValueUnknown()->ResultId()) {
- return false;
- }
- }
- if (AsSEConstantNode()) {
- if (AsSEConstantNode()->FoldToSingleValue() !=
- other.AsSEConstantNode()->FoldToSingleValue())
- return false;
- }
- return true;
- }
- bool SENode::operator!=(const SENode& other) const { return !(*this == other); }
namespace {
// Appends an integer-like value to a 32-bit-unit hash string. Four-byte values
// take one char32_t; eight-byte values take two, high word first. This lets
// pointers (reinterpreted as uintptr_t) be hashed on both 32- and 64-bit hosts.
// The width is picked by partially specialising on sizeof(T).
template <typename T, size_t kByteWidth>
struct PushToStringImpl;

// Eight-byte values: emit the high 32 bits, then the low 32 bits.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T value, std::u32string* out) {
    const uint32_t high_word = static_cast<uint32_t>(value >> 32);
    const uint32_t low_word = static_cast<uint32_t>(value);
    out->push_back(high_word);
    out->push_back(low_word);
  }
};

// Four-byte values: emit the value as a single unit.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T value, std::u32string* out) {
    const uint32_t word = static_cast<uint32_t>(value);
    out->push_back(word);
  }
};

// Appends |value| to |out|, dispatching on the byte width of T.
template <typename T>
void PushToString(T value, std::u32string* out) {
  PushToStringImpl<T, sizeof(T)> append;
  append(value, out);
}
}  // namespace
- // Implements the hashing of SENodes.
- size_t SENodeHash::operator()(const SENode* node) const {
- // Concatenate the terms into a string which we can hash.
- std::u32string hash_string{};
- // Hashing the type as a string is safer than hashing the enum as the enum is
- // very likely to collide with constants.
- for (char ch : node->AsString()) {
- hash_string.push_back(static_cast<char32_t>(ch));
- }
- // We just ignore the literal value unless it is a constant.
- if (node->GetType() == SENode::Constant)
- PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
- const SERecurrentNode* recurrent = node->AsSERecurrentNode();
- // If we're dealing with a recurrent expression hash the loop as well so that
- // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
- if (recurrent) {
- PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
- &hash_string);
- // Recurrent expressions can't be hashed using the normal method as the
- // order of coefficient and offset matters to the hash.
- PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
- &hash_string);
- PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
- &hash_string);
- return std::hash<std::u32string>{}(hash_string);
- }
- // Hash the result id of the original instruction which created this node if
- // it is a value unknown node.
- if (node->GetType() == SENode::ValueUnknown) {
- PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
- }
- // Hash the pointers of the child nodes, each SENode has a unique pointer
- // associated with it.
- const std::vector<SENode*>& children = node->GetChildren();
- for (const SENode* child : children) {
- PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
- }
- return std::hash<std::u32string>{}(hash_string);
- }
- // This overload is the actual overload used by the node_cache_ set.
- size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
- return this->operator()(node.get());
- }
- void SENode::DumpDot(std::ostream& out, bool recurse) const {
- size_t unique_id = std::hash<const SENode*>{}(this);
- out << unique_id << " [label=\"" << AsString() << " ";
- if (GetType() == SENode::Constant) {
- out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
- }
- out << "\"]\n";
- for (const SENode* child : children_) {
- size_t child_unique_id = std::hash<const SENode*>{}(child);
- out << unique_id << " -> " << child_unique_id << " \n";
- if (recurse) child->DumpDot(out, true);
- }
- }
- namespace {
- class IsGreaterThanZero {
- public:
- explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
- // Determine if the value of |node| is always strictly greater than zero if
- // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
- // true. It returns true is the evaluation was able to conclude something, in
- // which case the result is stored in |result|.
- // The algorithm work by going through all the nodes and determine the
- // sign of each of them.
- bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
- *result = false;
- switch (Visit(node)) {
- case Signedness::kPositiveOrNegative: {
- return false;
- }
- case Signedness::kStrictlyNegative: {
- *result = false;
- break;
- }
- case Signedness::kNegative: {
- if (!or_equal_zero) {
- return false;
- }
- *result = false;
- break;
- }
- case Signedness::kStrictlyPositive: {
- *result = true;
- break;
- }
- case Signedness::kPositive: {
- if (!or_equal_zero) {
- return false;
- }
- *result = true;
- break;
- }
- }
- return true;
- }
- private:
- enum class Signedness {
- kPositiveOrNegative, // Yield a value positive or negative.
- kStrictlyNegative, // Yield a value strictly less than 0.
- kNegative, // Yield a value less or equal to 0.
- kStrictlyPositive, // Yield a value strictly greater than 0.
- kPositive // Yield a value greater or equal to 0.
- };
- // Combine the signedness according to arithmetic rules of a given operator.
- using Combiner = std::function<Signedness(Signedness, Signedness)>;
- // Returns a functor to interpret the signedness of 2 expressions as if they
- // were added.
- Combiner GetAddCombiner() const {
- return [](Signedness lhs, Signedness rhs) {
- switch (lhs) {
- case Signedness::kPositiveOrNegative:
- break;
- case Signedness::kStrictlyNegative:
- if (rhs == Signedness::kStrictlyNegative ||
- rhs == Signedness::kNegative)
- return lhs;
- break;
- case Signedness::kNegative: {
- if (rhs == Signedness::kStrictlyNegative)
- return Signedness::kStrictlyNegative;
- if (rhs == Signedness::kNegative) return Signedness::kNegative;
- break;
- }
- case Signedness::kStrictlyPositive: {
- if (rhs == Signedness::kStrictlyPositive ||
- rhs == Signedness::kPositive) {
- return Signedness::kStrictlyPositive;
- }
- break;
- }
- case Signedness::kPositive: {
- if (rhs == Signedness::kStrictlyPositive)
- return Signedness::kStrictlyPositive;
- if (rhs == Signedness::kPositive) return Signedness::kPositive;
- break;
- }
- }
- return Signedness::kPositiveOrNegative;
- };
- }
- // Returns a functor to interpret the signedness of 2 expressions as if they
- // were multiplied.
- Combiner GetMulCombiner() const {
- return [](Signedness lhs, Signedness rhs) {
- switch (lhs) {
- case Signedness::kPositiveOrNegative:
- break;
- case Signedness::kStrictlyNegative: {
- switch (rhs) {
- case Signedness::kPositiveOrNegative: {
- break;
- }
- case Signedness::kStrictlyNegative: {
- return Signedness::kStrictlyPositive;
- }
- case Signedness::kNegative: {
- return Signedness::kPositive;
- }
- case Signedness::kStrictlyPositive: {
- return Signedness::kStrictlyNegative;
- }
- case Signedness::kPositive: {
- return Signedness::kNegative;
- }
- }
- break;
- }
- case Signedness::kNegative: {
- switch (rhs) {
- case Signedness::kPositiveOrNegative: {
- break;
- }
- case Signedness::kStrictlyNegative:
- case Signedness::kNegative: {
- return Signedness::kPositive;
- }
- case Signedness::kStrictlyPositive:
- case Signedness::kPositive: {
- return Signedness::kNegative;
- }
- }
- break;
- }
- case Signedness::kStrictlyPositive: {
- return rhs;
- }
- case Signedness::kPositive: {
- switch (rhs) {
- case Signedness::kPositiveOrNegative: {
- break;
- }
- case Signedness::kStrictlyNegative:
- case Signedness::kNegative: {
- return Signedness::kNegative;
- }
- case Signedness::kStrictlyPositive:
- case Signedness::kPositive: {
- return Signedness::kPositive;
- }
- }
- break;
- }
- }
- return Signedness::kPositiveOrNegative;
- };
- }
- Signedness Visit(const SENode* node) {
- switch (node->GetType()) {
- case SENode::Constant:
- return Visit(node->AsSEConstantNode());
- break;
- case SENode::RecurrentAddExpr:
- return Visit(node->AsSERecurrentNode());
- break;
- case SENode::Negative:
- return Visit(node->AsSENegative());
- break;
- case SENode::CanNotCompute:
- return Visit(node->AsSECantCompute());
- break;
- case SENode::ValueUnknown:
- return Visit(node->AsSEValueUnknown());
- break;
- case SENode::Add:
- return VisitExpr(node, GetAddCombiner());
- break;
- case SENode::Multiply:
- return VisitExpr(node, GetMulCombiner());
- break;
- }
- return Signedness::kPositiveOrNegative;
- }
- // Returns the signedness of a constant |node|.
- Signedness Visit(const SEConstantNode* node) {
- if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
- if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
- if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
- return Signedness::kPositiveOrNegative;
- }
- // Returns the signedness of an unknown |node| based on its type.
- Signedness Visit(const SEValueUnknown* node) {
- Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
- analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
- assert(type && "Can't retrieve a type for the instruction");
- analysis::Integer* int_type = type->AsInteger();
- assert(type && "Can't retrieve an integer type for the instruction");
- return int_type->IsSigned() ? Signedness::kPositiveOrNegative
- : Signedness::kPositive;
- }
- // Returns the signedness of a recurring expression.
- Signedness Visit(const SERecurrentNode* node) {
- Signedness coeff_sign = Visit(node->GetCoefficient());
- // SERecurrentNode represent an affine expression in the range [0,
- // loop_bound], so the result cannot be strictly positive or negative.
- switch (coeff_sign) {
- default:
- break;
- case Signedness::kStrictlyNegative:
- coeff_sign = Signedness::kNegative;
- break;
- case Signedness::kStrictlyPositive:
- coeff_sign = Signedness::kPositive;
- break;
- }
- return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
- }
- // Returns the signedness of a negation |node|.
- Signedness Visit(const SENegative* node) {
- switch (Visit(*node->begin())) {
- case Signedness::kPositiveOrNegative: {
- return Signedness::kPositiveOrNegative;
- }
- case Signedness::kStrictlyNegative: {
- return Signedness::kStrictlyPositive;
- }
- case Signedness::kNegative: {
- return Signedness::kPositive;
- }
- case Signedness::kStrictlyPositive: {
- return Signedness::kStrictlyNegative;
- }
- case Signedness::kPositive: {
- return Signedness::kNegative;
- }
- }
- return Signedness::kPositiveOrNegative;
- }
- Signedness Visit(const SECantCompute*) {
- return Signedness::kPositiveOrNegative;
- }
- // Returns the signedness of a binary expression by using the combiner
- // |reduce|.
- Signedness VisitExpr(
- const SENode* node,
- std::function<Signedness(Signedness, Signedness)> reduce) {
- Signedness result = Visit(*node->begin());
- for (const SENode* operand : make_range(++node->begin(), node->end())) {
- if (result == Signedness::kPositiveOrNegative) {
- return Signedness::kPositiveOrNegative;
- }
- result = reduce(result, Visit(operand));
- }
- return result;
- }
- IRContext* context_;
- };
- } // namespace
- bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
- bool* is_gt_zero) const {
- return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
- }
- bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
- SENode* node, bool* is_ge_zero) const {
- return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
- }
- namespace {
- // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
- // if |node| is not in the chain, returns the original chain.
- SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
- const SENode* node) {
- SENode* lhs = mul->GetChildren()[0];
- SENode* rhs = mul->GetChildren()[1];
- if (lhs == node) {
- return rhs;
- }
- if (rhs == node) {
- return lhs;
- }
- if (lhs->AsSEMultiplyNode()) {
- SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
- if (res != lhs)
- return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
- }
- if (rhs->AsSEMultiplyNode()) {
- SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
- if (res != rhs)
- return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
- }
- return mul;
- }
- } // namespace
- std::pair<SExpression, int64_t> SExpression::operator/(
- SExpression rhs_wrapper) const {
- SENode* lhs = node_;
- SENode* rhs = rhs_wrapper.node_;
- // Check for division by 0.
- if (rhs->AsSEConstantNode() &&
- !rhs->AsSEConstantNode()->FoldToSingleValue()) {
- return {scev_->CreateCantComputeNode(), 0};
- }
- // Trivial case.
- if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
- int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
- int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
- return {scev_->CreateConstant(lhs_value / rhs_value),
- lhs_value % rhs_value};
- }
- // look for a "c U / U" pattern.
- if (lhs->AsSEMultiplyNode()) {
- assert(lhs->GetChildren().size() == 2 &&
- "More than 2 operand for a multiply node.");
- SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
- if (res != lhs) {
- return {res, 0};
- }
- }
- return {scev_->CreateCantComputeNode(), 0};
- }
- } // namespace opt
- } // namespace spvtools
|