// strength_reduction_pass.cpp (6.3 KB)
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/strength_reduction_pass.h"
  15. #include <cstring>
  16. #include <memory>
  17. #include <utility>
  18. #include <vector>
  19. #include "source/opt/def_use_manager.h"
  20. #include "source/opt/ir_context.h"
  21. #include "source/opt/log.h"
  22. #include "source/opt/reflect.h"
  23. namespace spvtools {
  24. namespace opt {
  namespace {

  // Returns the number of trailing zero bits in |constVal|, i.e. the index of
  // its lowest set bit. For a power of 2 this is the shift amount k such that
  // constVal == 1 << k.
  //
  // Returns 32 when |constVal| is 0 (no bit is set). That is the conventional
  // count-trailing-zeros result for a 32-bit word, and it avoids the infinite
  // loop a plain bit scan would otherwise enter on 0. In this file the only
  // caller passes values already accepted by IsPowerOf2, so they are non-zero.
  uint32_t CountTrailingZeros(uint32_t constVal) {
    if (constVal == 0) return 32;
    // A hardware count-trailing-zeros instruction (or a lookup table) would
    // be faster; a simple scan is sufficient for the values seen here.
    uint32_t shiftAmount = 0;
    while ((constVal & 1u) == 0) {
      ++shiftAmount;
      constVal >>= 1;
    }
    return shiftAmount;
  }

  // Returns true if |val| is a power of 2, i.e. exactly one bit is set.
  // val & (val - 1) clears the lowest set bit, so the result is 0 only when
  // that was the sole set bit. 0 itself has no set bit and is rejected.
  bool IsPowerOf2(uint32_t val) { return val != 0 && (val & (val - 1)) == 0; }

  }  // namespace
  47. Pass::Status StrengthReductionPass::Process() {
  48. // Initialize the member variables on a per module basis.
  49. bool modified = false;
  50. int32_type_id_ = 0;
  51. uint32_type_id_ = 0;
  52. std::memset(constant_ids_, 0, sizeof(constant_ids_));
  53. FindIntTypesAndConstants();
  54. modified = ScanFunctions();
  55. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  56. }
  57. bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
  58. BasicBlock::iterator* inst) {
  59. assert((*inst)->opcode() == spv::Op::OpIMul &&
  60. "Only works for multiplication of integers.");
  61. bool modified = false;
  62. // Currently only works on 32-bit integers.
  63. if ((*inst)->type_id() != int32_type_id_ &&
  64. (*inst)->type_id() != uint32_type_id_) {
  65. return modified;
  66. }
  67. // Check the operands for a constant that is a power of 2.
  68. for (int i = 0; i < 2; i++) {
  69. uint32_t opId = (*inst)->GetSingleWordInOperand(i);
  70. Instruction* opInst = get_def_use_mgr()->GetDef(opId);
  71. if (opInst->opcode() == spv::Op::OpConstant) {
  72. // We found a constant operand.
  73. uint32_t constVal = opInst->GetSingleWordOperand(2);
  74. if (IsPowerOf2(constVal)) {
  75. modified = true;
  76. uint32_t shiftAmount = CountTrailingZeros(constVal);
  77. uint32_t shiftConstResultId = GetConstantId(shiftAmount);
  78. // Create the new instruction.
  79. uint32_t newResultId = TakeNextId();
  80. std::vector<Operand> newOperands;
  81. newOperands.push_back((*inst)->GetInOperand(1 - i));
  82. Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
  83. {shiftConstResultId});
  84. newOperands.push_back(shiftOperand);
  85. std::unique_ptr<Instruction> newInstruction(
  86. new Instruction(context(), spv::Op::OpShiftLeftLogical,
  87. (*inst)->type_id(), newResultId, newOperands));
  88. // Insert the new instruction and update the data structures.
  89. (*inst) = (*inst).InsertBefore(std::move(newInstruction));
  90. get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
  91. ++(*inst);
  92. context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
  93. // Remove the old instruction.
  94. Instruction* inst_to_delete = &*(*inst);
  95. --(*inst);
  96. context()->KillInst(inst_to_delete);
  97. // We do not want to replace the instruction twice if both operands
  98. // are constants that are a power of 2. So we break here.
  99. break;
  100. }
  101. }
  102. }
  103. return modified;
  104. }
  105. void StrengthReductionPass::FindIntTypesAndConstants() {
  106. analysis::Integer int32(32, true);
  107. int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
  108. analysis::Integer uint32(32, false);
  109. uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
  110. for (auto iter = get_module()->types_values_begin();
  111. iter != get_module()->types_values_end(); ++iter) {
  112. switch (iter->opcode()) {
  113. case spv::Op::OpConstant:
  114. if (iter->type_id() == uint32_type_id_) {
  115. uint32_t value = iter->GetSingleWordOperand(2);
  116. if (value <= 32) constant_ids_[value] = iter->result_id();
  117. }
  118. break;
  119. default:
  120. break;
  121. }
  122. }
  123. }
  124. uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
  125. assert(val <= 32 &&
  126. "This function does not handle constants larger than 32.");
  127. if (constant_ids_[val] == 0) {
  128. if (uint32_type_id_ == 0) {
  129. analysis::Integer uint(32, false);
  130. uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
  131. }
  132. // Construct the constant.
  133. uint32_t resultId = TakeNextId();
  134. Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
  135. {val});
  136. std::unique_ptr<Instruction> newConstant(new Instruction(
  137. context(), spv::Op::OpConstant, uint32_type_id_, resultId, {constant}));
  138. get_module()->AddGlobalValue(std::move(newConstant));
  139. // Notify the DefUseManager about this constant.
  140. auto constantIter = --get_module()->types_values_end();
  141. get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
  142. // Store the result id for next time.
  143. constant_ids_[val] = resultId;
  144. }
  145. return constant_ids_[val];
  146. }
  147. bool StrengthReductionPass::ScanFunctions() {
  148. // I did not use |ForEachInst| in the module because the function that acts on
  149. // the instruction gets a pointer to the instruction. We cannot use that to
  150. // insert a new instruction. I want an iterator.
  151. bool modified = false;
  152. for (auto& func : *get_module()) {
  153. for (auto& bb : func) {
  154. for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
  155. switch (inst->opcode()) {
  156. case spv::Op::OpIMul:
  157. if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
  158. break;
  159. default:
  160. break;
  161. }
  162. }
  163. }
  164. }
  165. return modified;
  166. }
  167. } // namespace opt
  168. } // namespace spvtools