2
0

strength_reduction_pass.cpp 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/strength_reduction_pass.h"
  15. #include <algorithm>
  16. #include <cstdio>
  17. #include <cstring>
  18. #include <memory>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include <utility>
  22. #include <vector>
  23. #include "source/opt/def_use_manager.h"
  24. #include "source/opt/ir_context.h"
  25. #include "source/opt/log.h"
  26. #include "source/opt/reflect.h"
  namespace {

  // Returns the number of trailing zero bits in the binary representation of
  // |constVal|.  Zero has no set bit, so all 32 of its bits count as trailing
  // zeros; without that explicit guard the shifting loop below would never
  // terminate, because (0 & 1) == 0 forever.
  uint32_t CountTrailingZeros(uint32_t constVal) {
    if (constVal == 0) return 32;
    // A hardware count-trailing-zeros instruction (or a lookup table) would be
    // faster; this loop is the simple portable fallback.
    uint32_t shiftAmount = 0;
    while ((constVal & 1) == 0) {
      ++shiftAmount;
      constVal >>= 1;
    }
    return shiftAmount;
  }

  // Returns true if |val| is a power of 2, i.e. exactly one bit is set.
  bool IsPowerOf2(uint32_t val) {
    // Subtracting 1 clears the lowest set bit and sets every bit below it, so
    // the AND is zero exactly when only one bit was set.  Zero has no bits set
    // at all and is rejected explicitly (0 - 1 wraps to all ones, and
    // 0xFFFFFFFF & 0 would otherwise report true).
    if (val == 0) return false;
    return ((val - 1) & val) == 0;
  }

  }  // namespace
  49. namespace spvtools {
  50. namespace opt {
  51. Pass::Status StrengthReductionPass::Process() {
  52. // Initialize the member variables on a per module basis.
  53. bool modified = false;
  54. int32_type_id_ = 0;
  55. uint32_type_id_ = 0;
  56. std::memset(constant_ids_, 0, sizeof(constant_ids_));
  57. FindIntTypesAndConstants();
  58. modified = ScanFunctions();
  59. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  60. }
  61. bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
  62. BasicBlock::iterator* inst) {
  63. assert((*inst)->opcode() == SpvOp::SpvOpIMul &&
  64. "Only works for multiplication of integers.");
  65. bool modified = false;
  66. // Currently only works on 32-bit integers.
  67. if ((*inst)->type_id() != int32_type_id_ &&
  68. (*inst)->type_id() != uint32_type_id_) {
  69. return modified;
  70. }
  71. // Check the operands for a constant that is a power of 2.
  72. for (int i = 0; i < 2; i++) {
  73. uint32_t opId = (*inst)->GetSingleWordInOperand(i);
  74. Instruction* opInst = get_def_use_mgr()->GetDef(opId);
  75. if (opInst->opcode() == SpvOp::SpvOpConstant) {
  76. // We found a constant operand.
  77. uint32_t constVal = opInst->GetSingleWordOperand(2);
  78. if (IsPowerOf2(constVal)) {
  79. modified = true;
  80. uint32_t shiftAmount = CountTrailingZeros(constVal);
  81. uint32_t shiftConstResultId = GetConstantId(shiftAmount);
  82. // Create the new instruction.
  83. uint32_t newResultId = TakeNextId();
  84. std::vector<Operand> newOperands;
  85. newOperands.push_back((*inst)->GetInOperand(1 - i));
  86. Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
  87. {shiftConstResultId});
  88. newOperands.push_back(shiftOperand);
  89. std::unique_ptr<Instruction> newInstruction(
  90. new Instruction(context(), SpvOp::SpvOpShiftLeftLogical,
  91. (*inst)->type_id(), newResultId, newOperands));
  92. // Insert the new instruction and update the data structures.
  93. (*inst) = (*inst).InsertBefore(std::move(newInstruction));
  94. get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
  95. ++(*inst);
  96. context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
  97. // Remove the old instruction.
  98. Instruction* inst_to_delete = &*(*inst);
  99. --(*inst);
  100. context()->KillInst(inst_to_delete);
  101. // We do not want to replace the instruction twice if both operands
  102. // are constants that are a power of 2. So we break here.
  103. break;
  104. }
  105. }
  106. }
  107. return modified;
  108. }
  109. void StrengthReductionPass::FindIntTypesAndConstants() {
  110. analysis::Integer int32(32, true);
  111. int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
  112. analysis::Integer uint32(32, false);
  113. uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
  114. for (auto iter = get_module()->types_values_begin();
  115. iter != get_module()->types_values_end(); ++iter) {
  116. switch (iter->opcode()) {
  117. case SpvOp::SpvOpConstant:
  118. if (iter->type_id() == uint32_type_id_) {
  119. uint32_t value = iter->GetSingleWordOperand(2);
  120. if (value <= 32) constant_ids_[value] = iter->result_id();
  121. }
  122. break;
  123. default:
  124. break;
  125. }
  126. }
  127. }
  128. uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
  129. assert(val <= 32 &&
  130. "This function does not handle constants larger than 32.");
  131. if (constant_ids_[val] == 0) {
  132. if (uint32_type_id_ == 0) {
  133. analysis::Integer uint(32, false);
  134. uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
  135. }
  136. // Construct the constant.
  137. uint32_t resultId = TakeNextId();
  138. Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
  139. {val});
  140. std::unique_ptr<Instruction> newConstant(
  141. new Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_,
  142. resultId, {constant}));
  143. get_module()->AddGlobalValue(std::move(newConstant));
  144. // Notify the DefUseManager about this constant.
  145. auto constantIter = --get_module()->types_values_end();
  146. get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
  147. // Store the result id for next time.
  148. constant_ids_[val] = resultId;
  149. }
  150. return constant_ids_[val];
  151. }
  152. bool StrengthReductionPass::ScanFunctions() {
  153. // I did not use |ForEachInst| in the module because the function that acts on
  154. // the instruction gets a pointer to the instruction. We cannot use that to
  155. // insert a new instruction. I want an iterator.
  156. bool modified = false;
  157. for (auto& func : *get_module()) {
  158. for (auto& bb : func) {
  159. for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
  160. switch (inst->opcode()) {
  161. case SpvOp::SpvOpIMul:
  162. if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
  163. break;
  164. default:
  165. break;
  166. }
  167. }
  168. }
  169. }
  170. return modified;
  171. }
  172. } // namespace opt
  173. } // namespace spvtools