mem_pass.cpp 17 KB

  1. // Copyright (c) 2017 The Khronos Group Inc.
  2. // Copyright (c) 2017 Valve Corporation
  3. // Copyright (c) 2017 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/opt/mem_pass.h"
  17. #include <memory>
  18. #include <set>
  19. #include <vector>
  20. #include "source/cfa.h"
  21. #include "source/opt/basic_block.h"
  22. #include "source/opt/dominator_analysis.h"
  23. #include "source/opt/ir_context.h"
  24. #include "source/opt/iterator.h"
  25. namespace spvtools {
  26. namespace opt {
namespace {
// In-operand indices for the instructions this pass inspects.
constexpr uint32_t kCopyObjectOperandInIdx = 0;        // OpCopyObject: source id.
constexpr uint32_t kTypePointerStorageClassInIdx = 0;  // OpTypePointer: storage class.
constexpr uint32_t kTypePointerTypeIdInIdx = 1;        // OpTypePointer: pointee type id.
}  // namespace
  32. bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
  33. switch (typeInst->opcode()) {
  34. case spv::Op::OpTypeInt:
  35. case spv::Op::OpTypeFloat:
  36. case spv::Op::OpTypeBool:
  37. case spv::Op::OpTypeVector:
  38. case spv::Op::OpTypeMatrix:
  39. case spv::Op::OpTypeImage:
  40. case spv::Op::OpTypeSampler:
  41. case spv::Op::OpTypeSampledImage:
  42. case spv::Op::OpTypePointer:
  43. return true;
  44. default:
  45. break;
  46. }
  47. return false;
  48. }
  49. bool MemPass::IsTargetType(const Instruction* typeInst) const {
  50. if (IsBaseTargetType(typeInst)) return true;
  51. if (typeInst->opcode() == spv::Op::OpTypeArray) {
  52. if (!IsTargetType(
  53. get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
  54. return false;
  55. }
  56. return true;
  57. }
  58. if (typeInst->opcode() != spv::Op::OpTypeStruct) return false;
  59. // All struct members must be math type
  60. return typeInst->WhileEachInId([this](const uint32_t* tid) {
  61. Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
  62. if (!IsTargetType(compTypeInst)) return false;
  63. return true;
  64. });
  65. }
  66. bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
  67. return opcode == spv::Op::OpAccessChain ||
  68. opcode == spv::Op::OpInBoundsAccessChain;
  69. }
  70. bool MemPass::IsPtr(uint32_t ptrId) {
  71. uint32_t varId = ptrId;
  72. Instruction* ptrInst = get_def_use_mgr()->GetDef(varId);
  73. while (ptrInst->opcode() == spv::Op::OpCopyObject) {
  74. varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
  75. ptrInst = get_def_use_mgr()->GetDef(varId);
  76. }
  77. const spv::Op op = ptrInst->opcode();
  78. if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
  79. const uint32_t varTypeId = ptrInst->type_id();
  80. if (varTypeId == 0) return false;
  81. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  82. return varTypeInst->opcode() == spv::Op::OpTypePointer;
  83. }
// Returns the instruction defining pointer |ptrId| with any chain of
// OpCopyObject stripped, and stores in |varId| the id of the base
// OpVariable, or 0 when the base is not an OpVariable (e.g. a function
// parameter or a null pointer).
Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
  *varId = ptrId;
  Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId);
  Instruction* varInst;
  if (ptrInst->opcode() == spv::Op::OpConstantNull) {
    // A null pointer has no base variable.
    *varId = 0;
    return ptrInst;
  }
  if (ptrInst->opcode() != spv::Op::OpVariable &&
      ptrInst->opcode() != spv::Op::OpFunctionParameter) {
    // Follow access chains / copies down to the base address.
    varInst = ptrInst->GetBaseAddress();
  } else {
    varInst = ptrInst;
  }
  if (varInst->opcode() == spv::Op::OpVariable) {
    *varId = varInst->result_id();
  } else {
    // Base is not a variable (e.g. function parameter): report no var id.
    *varId = 0;
  }
  // Strip OpCopyObject wrappers from the returned pointer instruction.
  while (ptrInst->opcode() == spv::Op::OpCopyObject) {
    uint32_t temp = ptrInst->GetSingleWordInOperand(0);
    ptrInst = get_def_use_mgr()->GetDef(temp);
  }
  return ptrInst;
}
// Convenience overload: extracts the pointer operand of the memory access
// |ip| and forwards to GetPtr(uint32_t, uint32_t*). |ip| must be one of the
// opcodes asserted below.
Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) {
  assert(ip->opcode() == spv::Op::OpStore || ip->opcode() == spv::Op::OpLoad ||
         ip->opcode() == spv::Op::OpImageTexelPointer ||
         ip->IsAtomicWithLoad());
  // All of these opcode place the pointer in position 0.
  const uint32_t ptrId = ip->GetSingleWordInOperand(0);
  return GetPtr(ptrId, varId);
}
  117. bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const {
  118. return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) {
  119. spv::Op op = user->opcode();
  120. if (op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
  121. return false;
  122. }
  123. return true;
  124. });
  125. }
// Deletes every instruction in |bp|; the label is deleted too when
// |killLabel| is true. Forwards to BasicBlock::KillAllInsts.
void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) {
  bp->KillAllInsts(killLabel);
}
  129. bool MemPass::HasLoads(uint32_t varId) const {
  130. return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
  131. spv::Op op = user->opcode();
  132. // TODO(): The following is slightly conservative. Could be
  133. // better handling of non-store/name.
  134. if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
  135. if (HasLoads(user->result_id())) {
  136. return false;
  137. }
  138. } else if (op != spv::Op::OpStore && op != spv::Op::OpName &&
  139. !IsNonTypeDecorate(op)) {
  140. return false;
  141. }
  142. return true;
  143. });
  144. }
  145. bool MemPass::IsLiveVar(uint32_t varId) const {
  146. const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
  147. // assume live if not a variable eg. function parameter
  148. if (varInst->opcode() != spv::Op::OpVariable) return true;
  149. // non-function scope vars are live
  150. const uint32_t varTypeId = varInst->type_id();
  151. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  152. if (spv::StorageClass(varTypeInst->GetSingleWordInOperand(
  153. kTypePointerStorageClassInIdx)) != spv::StorageClass::Function)
  154. return true;
  155. // test if variable is loaded from
  156. return HasLoads(varId);
  157. }
  158. void MemPass::AddStores(uint32_t ptr_id, std::queue<Instruction*>* insts) {
  159. get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) {
  160. spv::Op op = user->opcode();
  161. if (IsNonPtrAccessChain(op)) {
  162. AddStores(user->result_id(), insts);
  163. } else if (op == spv::Op::OpStore) {
  164. insts->push(user);
  165. }
  166. });
  167. }
// Deletes |inst| and, transitively, any instruction whose only remaining
// uses are names/decorations once its users are gone. |call_back|, when
// provided, is invoked on each instruction just before it is killed.
void MemPass::DCEInst(Instruction* inst,
                      const std::function<void(Instruction*)>& call_back) {
  std::queue<Instruction*> deadInsts;
  deadInsts.push(inst);
  while (!deadInsts.empty()) {
    Instruction* di = deadInsts.front();
    // Don't delete labels
    if (di->opcode() == spv::Op::OpLabel) {
      deadInsts.pop();
      continue;
    }
    // Remember operands
    std::set<uint32_t> ids;
    di->ForEachInId([&ids](uint32_t* iid) { ids.insert(*iid); });
    uint32_t varId = 0;
    // Remember variable if dead load
    if (di->opcode() == spv::Op::OpLoad) (void)GetPtr(di, &varId);
    if (call_back) {
      call_back(di);
    }
    context()->KillInst(di);
    // For all operands with no remaining uses, add their instruction
    // to the dead instruction queue.
    for (auto id : ids)
      if (HasOnlyNamesAndDecorates(id)) {
        Instruction* odi = get_def_use_mgr()->GetDef(id);
        // Only combinator (side-effect-free) instructions may be removed.
        if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi);
      }
    // if a load was deleted and it was the variable's
    // last load, add all its stores to dead queue
    if (varId != 0 && !IsLiveVar(varId)) AddStores(varId, &deadInsts);
    deadInsts.pop();
  }
}
// Trivial constructor; members rely on their default initialization.
MemPass::MemPass() {}
  203. bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
  204. return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
  205. auto dbg_op = user->GetCommonDebugOpcode();
  206. if (dbg_op == CommonDebugInfoDebugDeclare ||
  207. dbg_op == CommonDebugInfoDebugValue) {
  208. return true;
  209. }
  210. spv::Op op = user->opcode();
  211. if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
  212. op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
  213. return false;
  214. }
  215. return true;
  216. });
  217. }
  218. uint32_t MemPass::Type2Undef(uint32_t type_id) {
  219. const auto uitr = type2undefs_.find(type_id);
  220. if (uitr != type2undefs_.end()) return uitr->second;
  221. const uint32_t undefId = TakeNextId();
  222. if (undefId == 0) {
  223. return 0;
  224. }
  225. std::unique_ptr<Instruction> undef_inst(
  226. new Instruction(context(), spv::Op::OpUndef, type_id, undefId, {}));
  227. get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst);
  228. get_module()->AddGlobalValue(std::move(undef_inst));
  229. type2undefs_[type_id] = undefId;
  230. return undefId;
  231. }
  232. bool MemPass::IsTargetVar(uint32_t varId) {
  233. if (varId == 0) {
  234. return false;
  235. }
  236. if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
  237. return false;
  238. if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true;
  239. const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
  240. if (varInst->opcode() != spv::Op::OpVariable) return false;
  241. const uint32_t varTypeId = varInst->type_id();
  242. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  243. if (spv::StorageClass(varTypeInst->GetSingleWordInOperand(
  244. kTypePointerStorageClassInIdx)) != spv::StorageClass::Function) {
  245. seen_non_target_vars_.insert(varId);
  246. return false;
  247. }
  248. const uint32_t varPteTypeId =
  249. varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
  250. Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId);
  251. if (!IsTargetType(varPteTypeInst)) {
  252. seen_non_target_vars_.insert(varId);
  253. return false;
  254. }
  255. seen_target_vars_.insert(varId);
  256. return true;
  257. }
  258. // Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in
  259. // |reachable_blocks|). There are two types of removal that this function can
  260. // perform:
  261. //
  262. // 1- Any operand that comes directly from an unreachable block is completely
  263. // removed. Since the block is unreachable, the edge between the unreachable
  264. // block and the block holding |phi| has been removed.
  265. //
  266. // 2- Any operand that comes via a live block and was defined at an unreachable
  267. // block gets its value replaced with an OpUndef value. Since the argument
  268. // was generated in an unreachable block, it no longer exists, so it cannot
  269. // be referenced. However, since the value does not reach |phi| directly
  270. // from the unreachable block, the operand cannot be removed from |phi|.
  271. // Therefore, we replace the argument value with OpUndef.
  272. //
  273. // For example, in the switch() below, assume that we want to remove the
  274. // argument with value %11 coming from block %41.
  275. //
  276. // [ ... ]
  277. // %41 = OpLabel <--- Unreachable block
  278. // %11 = OpLoad %int %y
  279. // [ ... ]
  280. // OpSelectionMerge %16 None
  281. // OpSwitch %12 %16 10 %13 13 %14 18 %15
  282. // %13 = OpLabel
  283. // OpBranch %16
  284. // %14 = OpLabel
  285. // OpStore %outparm %int_14
  286. // OpBranch %16
  287. // %15 = OpLabel
  288. // OpStore %outparm %int_15
  289. // OpBranch %16
  290. // %16 = OpLabel
  291. // %30 = OpPhi %int %11 %41 %int_42 %13 %11 %14 %11 %15
  292. //
  293. // Since %41 is now an unreachable block, the first operand of |phi| needs to
  294. // be removed completely. But the operands (%11 %14) and (%11 %15) cannot be
  295. // removed because %14 and %15 are reachable blocks. Since %11 no longer exist,
  296. // in those arguments, we replace all references to %11 with an OpUndef value.
  297. // This results in |phi| looking like:
  298. //
  299. // %50 = OpUndef %int
  300. // [ ... ]
  301. // %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15
void MemPass::RemovePhiOperands(
    Instruction* phi, const std::unordered_set<BasicBlock*>& reachable_blocks) {
  std::vector<Operand> keep_operands;
  uint32_t type_id = 0;
  // The id of an undefined value we've generated.
  uint32_t undef_id = 0;
  // Traverse all the operands in |phi|. Build the new operand vector by adding
  // all the original operands from |phi| except the unwanted ones.
  for (uint32_t i = 0; i < phi->NumOperands();) {
    if (i < 2) {
      // The first two arguments are always preserved.
      keep_operands.push_back(phi->GetOperand(i));
      ++i;
      continue;
    }
    // The remaining Phi arguments come in pairs. Index 'i' contains the
    // variable id, index 'i + 1' is the originating block id.
    assert(i % 2 == 0 && i < phi->NumOperands() - 1 &&
           "malformed Phi arguments");
    BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1));
    if (reachable_blocks.find(in_block) == reachable_blocks.end()) {
      // If the incoming block is unreachable, remove both operands as this
      // means that the |phi| has lost an incoming edge.
      i += 2;
      continue;
    }
    // In all other cases, the operand must be kept but may need to be changed.
    uint32_t arg_id = phi->GetSingleWordOperand(i);
    Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id);
    BasicBlock* def_block = context()->get_instr_block(arg_def_instr);
    if (def_block &&
        reachable_blocks.find(def_block) == reachable_blocks.end()) {
      // If the current |phi| argument was defined in an unreachable block, it
      // means that this |phi| argument is no longer defined. Replace it with
      // |undef_id|.
      if (!undef_id) {
        // Lazily create a single OpUndef of the argument's type and reuse it
        // for every such argument in this |phi|.
        type_id = arg_def_instr->type_id();
        undef_id = Type2Undef(type_id);
      }
      keep_operands.push_back(
          Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id}));
    } else {
      // Otherwise, the argument comes from a reachable block or from no block
      // at all (meaning that it was defined in the global section of the
      // program). In both cases, keep the argument intact.
      keep_operands.push_back(phi->GetOperand(i));
    }
    // The originating block id of a preserved argument is always kept.
    keep_operands.push_back(phi->GetOperand(i + 1));
    i += 2;
  }
  // Rewrite |phi| in place and refresh its def-use information.
  context()->ForgetUses(phi);
  phi->ReplaceOperands(keep_operands);
  context()->AnalyzeUses(phi);
}
  356. void MemPass::RemoveBlock(Function::iterator* bi) {
  357. auto& rm_block = **bi;
  358. // Remove instructions from the block.
  359. rm_block.ForEachInst([&rm_block, this](Instruction* inst) {
  360. // Note that we do not kill the block label instruction here. The label
  361. // instruction is needed to identify the block, which is needed by the
  362. // removal of phi operands.
  363. if (inst != rm_block.GetLabelInst()) {
  364. context()->KillInst(inst);
  365. }
  366. });
  367. // Remove the label instruction last.
  368. auto label = rm_block.GetLabelInst();
  369. context()->KillInst(label);
  370. *bi = bi->Erase();
  371. }
  372. bool MemPass::RemoveUnreachableBlocks(Function* func) {
  373. bool modified = false;
  374. // Mark reachable all blocks reachable from the function's entry block.
  375. std::unordered_set<BasicBlock*> reachable_blocks;
  376. std::unordered_set<BasicBlock*> visited_blocks;
  377. std::queue<BasicBlock*> worklist;
  378. reachable_blocks.insert(func->entry().get());
  379. // Initially mark the function entry point as reachable.
  380. worklist.push(func->entry().get());
  381. auto mark_reachable = [&reachable_blocks, &visited_blocks, &worklist,
  382. this](uint32_t label_id) {
  383. auto successor = cfg()->block(label_id);
  384. if (visited_blocks.count(successor) == 0) {
  385. reachable_blocks.insert(successor);
  386. worklist.push(successor);
  387. visited_blocks.insert(successor);
  388. }
  389. };
  390. // Transitively mark all blocks reachable from the entry as reachable.
  391. while (!worklist.empty()) {
  392. BasicBlock* block = worklist.front();
  393. worklist.pop();
  394. // All the successors of a live block are also live.
  395. static_cast<const BasicBlock*>(block)->ForEachSuccessorLabel(
  396. mark_reachable);
  397. // All the Merge and ContinueTarget blocks of a live block are also live.
  398. block->ForMergeAndContinueLabel(mark_reachable);
  399. }
  400. // Update operands of Phi nodes that reference unreachable blocks.
  401. for (auto& block : *func) {
  402. // If the block is about to be removed, don't bother updating its
  403. // Phi instructions.
  404. if (reachable_blocks.count(&block) == 0) {
  405. continue;
  406. }
  407. // If the block is reachable and has Phi instructions, remove all
  408. // operands from its Phi instructions that reference unreachable blocks.
  409. // If the block has no Phi instructions, this is a no-op.
  410. block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) {
  411. RemovePhiOperands(phi, reachable_blocks);
  412. });
  413. }
  414. // Erase unreachable blocks.
  415. for (auto ebi = func->begin(); ebi != func->end();) {
  416. if (reachable_blocks.count(&*ebi) == 0) {
  417. RemoveBlock(&ebi);
  418. modified = true;
  419. } else {
  420. ++ebi;
  421. }
  422. }
  423. return modified;
  424. }
  425. bool MemPass::CFGCleanup(Function* func) {
  426. bool modified = false;
  427. modified |= RemoveUnreachableBlocks(func);
  428. return modified;
  429. }
  430. void MemPass::CollectTargetVars(Function* func) {
  431. seen_target_vars_.clear();
  432. seen_non_target_vars_.clear();
  433. type2undefs_.clear();
  434. // Collect target (and non-) variable sets. Remove variables with
  435. // non-load/store refs from target variable set
  436. for (auto& blk : *func) {
  437. for (auto& inst : blk) {
  438. switch (inst.opcode()) {
  439. case spv::Op::OpStore:
  440. case spv::Op::OpLoad: {
  441. uint32_t varId;
  442. (void)GetPtr(&inst, &varId);
  443. if (!IsTargetVar(varId)) break;
  444. if (HasOnlySupportedRefs(varId)) break;
  445. seen_non_target_vars_.insert(varId);
  446. seen_target_vars_.erase(varId);
  447. } break;
  448. default:
  449. break;
  450. }
  451. }
  452. }
  453. }
  454. } // namespace opt
  455. } // namespace spvtools