mem_pass.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. // Copyright (c) 2017 The Khronos Group Inc.
  2. // Copyright (c) 2017 Valve Corporation
  3. // Copyright (c) 2017 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/opt/mem_pass.h"
  17. #include <memory>
  18. #include <set>
  19. #include <vector>
  20. #include "source/cfa.h"
  21. #include "source/opt/basic_block.h"
  22. #include "source/opt/ir_context.h"
  23. namespace spvtools {
  24. namespace opt {
  25. namespace {
  26. constexpr uint32_t kCopyObjectOperandInIdx = 0;
  27. constexpr uint32_t kTypePointerStorageClassInIdx = 0;
  28. constexpr uint32_t kTypePointerTypeIdInIdx = 1;
  29. } // namespace
  30. bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
  31. switch (typeInst->opcode()) {
  32. case spv::Op::OpTypeInt:
  33. case spv::Op::OpTypeFloat:
  34. case spv::Op::OpTypeBool:
  35. case spv::Op::OpTypeVector:
  36. case spv::Op::OpTypeMatrix:
  37. case spv::Op::OpTypeImage:
  38. case spv::Op::OpTypeSampler:
  39. case spv::Op::OpTypeSampledImage:
  40. case spv::Op::OpTypePointer:
  41. case spv::Op::OpTypeCooperativeMatrixNV:
  42. case spv::Op::OpTypeCooperativeMatrixKHR:
  43. return true;
  44. default:
  45. break;
  46. }
  47. return false;
  48. }
  49. bool MemPass::IsTargetType(const Instruction* typeInst) const {
  50. if (IsBaseTargetType(typeInst)) return true;
  51. if (typeInst->opcode() == spv::Op::OpTypeArray) {
  52. if (!IsTargetType(
  53. get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
  54. return false;
  55. }
  56. return true;
  57. }
  58. if (typeInst->opcode() != spv::Op::OpTypeStruct) return false;
  59. // All struct members must be math type
  60. return typeInst->WhileEachInId([this](const uint32_t* tid) {
  61. Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
  62. if (!IsTargetType(compTypeInst)) return false;
  63. return true;
  64. });
  65. }
  66. bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
  67. return opcode == spv::Op::OpAccessChain ||
  68. opcode == spv::Op::OpInBoundsAccessChain;
  69. }
  70. bool MemPass::IsPtr(uint32_t ptrId) {
  71. uint32_t varId = ptrId;
  72. Instruction* ptrInst = get_def_use_mgr()->GetDef(varId);
  73. if (ptrInst->opcode() == spv::Op::OpFunction) {
  74. // A function is not a pointer, but it's return type could be, which will
  75. // erroneously lead to this function returning true later on
  76. return false;
  77. }
  78. while (ptrInst->opcode() == spv::Op::OpCopyObject) {
  79. varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
  80. ptrInst = get_def_use_mgr()->GetDef(varId);
  81. }
  82. const spv::Op op = ptrInst->opcode();
  83. if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
  84. const uint32_t varTypeId = ptrInst->type_id();
  85. if (varTypeId == 0) return false;
  86. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  87. return varTypeInst->opcode() == spv::Op::OpTypePointer;
  88. }
  89. Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
  90. *varId = ptrId;
  91. Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId);
  92. Instruction* varInst;
  93. if (ptrInst->opcode() == spv::Op::OpConstantNull) {
  94. *varId = 0;
  95. return ptrInst;
  96. }
  97. if (ptrInst->opcode() != spv::Op::OpVariable &&
  98. ptrInst->opcode() != spv::Op::OpFunctionParameter) {
  99. varInst = ptrInst->GetBaseAddress();
  100. } else {
  101. varInst = ptrInst;
  102. }
  103. if (varInst->opcode() == spv::Op::OpVariable) {
  104. *varId = varInst->result_id();
  105. } else {
  106. *varId = 0;
  107. }
  108. while (ptrInst->opcode() == spv::Op::OpCopyObject) {
  109. uint32_t temp = ptrInst->GetSingleWordInOperand(0);
  110. ptrInst = get_def_use_mgr()->GetDef(temp);
  111. }
  112. return ptrInst;
  113. }
  114. Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) {
  115. assert(ip->opcode() == spv::Op::OpStore || ip->opcode() == spv::Op::OpLoad ||
  116. ip->opcode() == spv::Op::OpImageTexelPointer ||
  117. ip->IsAtomicWithLoad());
  118. // All of these opcode place the pointer in position 0.
  119. const uint32_t ptrId = ip->GetSingleWordInOperand(0);
  120. return GetPtr(ptrId, varId);
  121. }
  122. bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const {
  123. return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) {
  124. spv::Op op = user->opcode();
  125. if (op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
  126. return false;
  127. }
  128. return true;
  129. });
  130. }
  131. void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) {
  132. bp->KillAllInsts(killLabel);
  133. }
  134. bool MemPass::HasLoads(uint32_t varId) const {
  135. return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
  136. spv::Op op = user->opcode();
  137. // TODO(): The following is slightly conservative. Could be
  138. // better handling of non-store/name.
  139. if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
  140. if (HasLoads(user->result_id())) {
  141. return false;
  142. }
  143. } else if (op != spv::Op::OpStore && op != spv::Op::OpName &&
  144. !IsNonTypeDecorate(op)) {
  145. return false;
  146. }
  147. return true;
  148. });
  149. }
  150. bool MemPass::IsLiveVar(uint32_t varId) const {
  151. const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
  152. // assume live if not a variable eg. function parameter
  153. if (varInst->opcode() != spv::Op::OpVariable) return true;
  154. // non-function scope vars are live
  155. const uint32_t varTypeId = varInst->type_id();
  156. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  157. if (spv::StorageClass(varTypeInst->GetSingleWordInOperand(
  158. kTypePointerStorageClassInIdx)) != spv::StorageClass::Function)
  159. return true;
  160. // test if variable is loaded from
  161. return HasLoads(varId);
  162. }
  163. void MemPass::AddStores(uint32_t ptr_id, std::queue<Instruction*>* insts) {
  164. get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) {
  165. spv::Op op = user->opcode();
  166. if (IsNonPtrAccessChain(op)) {
  167. AddStores(user->result_id(), insts);
  168. } else if (op == spv::Op::OpStore) {
  169. insts->push(user);
  170. }
  171. });
  172. }
  173. void MemPass::DCEInst(Instruction* inst,
  174. const std::function<void(Instruction*)>& call_back) {
  175. std::queue<Instruction*> deadInsts;
  176. deadInsts.push(inst);
  177. while (!deadInsts.empty()) {
  178. Instruction* di = deadInsts.front();
  179. // Don't delete labels
  180. if (di->opcode() == spv::Op::OpLabel) {
  181. deadInsts.pop();
  182. continue;
  183. }
  184. // Remember operands
  185. std::set<uint32_t> ids;
  186. di->ForEachInId([&ids](uint32_t* iid) { ids.insert(*iid); });
  187. uint32_t varId = 0;
  188. // Remember variable if dead load
  189. if (di->opcode() == spv::Op::OpLoad) (void)GetPtr(di, &varId);
  190. if (call_back) {
  191. call_back(di);
  192. }
  193. context()->KillInst(di);
  194. // For all operands with no remaining uses, add their instruction
  195. // to the dead instruction queue.
  196. for (auto id : ids)
  197. if (HasOnlyNamesAndDecorates(id)) {
  198. Instruction* odi = get_def_use_mgr()->GetDef(id);
  199. if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi);
  200. }
  201. // if a load was deleted and it was the variable's
  202. // last load, add all its stores to dead queue
  203. if (varId != 0 && !IsLiveVar(varId)) AddStores(varId, &deadInsts);
  204. deadInsts.pop();
  205. }
  206. }
  207. MemPass::MemPass() {}
  208. bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
  209. return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
  210. auto dbg_op = user->GetCommonDebugOpcode();
  211. if (dbg_op == CommonDebugInfoDebugDeclare ||
  212. dbg_op == CommonDebugInfoDebugValue) {
  213. return true;
  214. }
  215. spv::Op op = user->opcode();
  216. if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
  217. op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
  218. return false;
  219. }
  220. return true;
  221. });
  222. }
  223. uint32_t MemPass::Type2Undef(uint32_t type_id) {
  224. const auto uitr = type2undefs_.find(type_id);
  225. if (uitr != type2undefs_.end()) return uitr->second;
  226. const uint32_t undefId = TakeNextId();
  227. if (undefId == 0) {
  228. return 0;
  229. }
  230. std::unique_ptr<Instruction> undef_inst(
  231. new Instruction(context(), spv::Op::OpUndef, type_id, undefId, {}));
  232. get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst);
  233. get_module()->AddGlobalValue(std::move(undef_inst));
  234. type2undefs_[type_id] = undefId;
  235. return undefId;
  236. }
  237. bool MemPass::IsTargetVar(uint32_t varId) {
  238. if (varId == 0) {
  239. return false;
  240. }
  241. if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
  242. return false;
  243. if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true;
  244. const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
  245. if (varInst->opcode() != spv::Op::OpVariable) return false;
  246. const uint32_t varTypeId = varInst->type_id();
  247. const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
  248. if (spv::StorageClass(varTypeInst->GetSingleWordInOperand(
  249. kTypePointerStorageClassInIdx)) != spv::StorageClass::Function) {
  250. seen_non_target_vars_.insert(varId);
  251. return false;
  252. }
  253. const uint32_t varPteTypeId =
  254. varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
  255. Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId);
  256. if (!IsTargetType(varPteTypeInst)) {
  257. seen_non_target_vars_.insert(varId);
  258. return false;
  259. }
  260. seen_target_vars_.insert(varId);
  261. return true;
  262. }
  263. // Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in
  264. // |reachable_blocks|). There are two types of removal that this function can
  265. // perform:
  266. //
  267. // 1- Any operand that comes directly from an unreachable block is completely
  268. // removed. Since the block is unreachable, the edge between the unreachable
  269. // block and the block holding |phi| has been removed.
  270. //
  271. // 2- Any operand that comes via a live block and was defined at an unreachable
  272. // block gets its value replaced with an OpUndef value. Since the argument
  273. // was generated in an unreachable block, it no longer exists, so it cannot
  274. // be referenced. However, since the value does not reach |phi| directly
  275. // from the unreachable block, the operand cannot be removed from |phi|.
  276. // Therefore, we replace the argument value with OpUndef.
  277. //
  278. // For example, in the switch() below, assume that we want to remove the
  279. // argument with value %11 coming from block %41.
  280. //
  281. // [ ... ]
  282. // %41 = OpLabel <--- Unreachable block
  283. // %11 = OpLoad %int %y
  284. // [ ... ]
  285. // OpSelectionMerge %16 None
  286. // OpSwitch %12 %16 10 %13 13 %14 18 %15
  287. // %13 = OpLabel
  288. // OpBranch %16
  289. // %14 = OpLabel
  290. // OpStore %outparm %int_14
  291. // OpBranch %16
  292. // %15 = OpLabel
  293. // OpStore %outparm %int_15
  294. // OpBranch %16
  295. // %16 = OpLabel
  296. // %30 = OpPhi %int %11 %41 %int_42 %13 %11 %14 %11 %15
  297. //
  298. // Since %41 is now an unreachable block, the first operand of |phi| needs to
  299. // be removed completely. But the operands (%11 %14) and (%11 %15) cannot be
  300. // removed because %14 and %15 are reachable blocks. Since %11 no longer exist,
  301. // in those arguments, we replace all references to %11 with an OpUndef value.
  302. // This results in |phi| looking like:
  303. //
  304. // %50 = OpUndef %int
  305. // [ ... ]
  306. // %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15
  307. void MemPass::RemovePhiOperands(
  308. Instruction* phi, const std::unordered_set<BasicBlock*>& reachable_blocks) {
  309. std::vector<Operand> keep_operands;
  310. uint32_t type_id = 0;
  311. // The id of an undefined value we've generated.
  312. uint32_t undef_id = 0;
  313. // Traverse all the operands in |phi|. Build the new operand vector by adding
  314. // all the original operands from |phi| except the unwanted ones.
  315. for (uint32_t i = 0; i < phi->NumOperands();) {
  316. if (i < 2) {
  317. // The first two arguments are always preserved.
  318. keep_operands.push_back(phi->GetOperand(i));
  319. ++i;
  320. continue;
  321. }
  322. // The remaining Phi arguments come in pairs. Index 'i' contains the
  323. // variable id, index 'i + 1' is the originating block id.
  324. assert(i % 2 == 0 && i < phi->NumOperands() - 1 &&
  325. "malformed Phi arguments");
  326. BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1));
  327. if (reachable_blocks.find(in_block) == reachable_blocks.end()) {
  328. // If the incoming block is unreachable, remove both operands as this
  329. // means that the |phi| has lost an incoming edge.
  330. i += 2;
  331. continue;
  332. }
  333. // In all other cases, the operand must be kept but may need to be changed.
  334. uint32_t arg_id = phi->GetSingleWordOperand(i);
  335. Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id);
  336. BasicBlock* def_block = context()->get_instr_block(arg_def_instr);
  337. if (def_block &&
  338. reachable_blocks.find(def_block) == reachable_blocks.end()) {
  339. // If the current |phi| argument was defined in an unreachable block, it
  340. // means that this |phi| argument is no longer defined. Replace it with
  341. // |undef_id|.
  342. if (!undef_id) {
  343. type_id = arg_def_instr->type_id();
  344. undef_id = Type2Undef(type_id);
  345. }
  346. keep_operands.push_back(
  347. Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id}));
  348. } else {
  349. // Otherwise, the argument comes from a reachable block or from no block
  350. // at all (meaning that it was defined in the global section of the
  351. // program). In both cases, keep the argument intact.
  352. keep_operands.push_back(phi->GetOperand(i));
  353. }
  354. keep_operands.push_back(phi->GetOperand(i + 1));
  355. i += 2;
  356. }
  357. context()->ForgetUses(phi);
  358. phi->ReplaceOperands(keep_operands);
  359. context()->AnalyzeUses(phi);
  360. }
  361. void MemPass::RemoveBlock(Function::iterator* bi) {
  362. auto& rm_block = **bi;
  363. // Remove instructions from the block.
  364. rm_block.ForEachInst([&rm_block, this](Instruction* inst) {
  365. // Note that we do not kill the block label instruction here. The label
  366. // instruction is needed to identify the block, which is needed by the
  367. // removal of phi operands.
  368. if (inst != rm_block.GetLabelInst()) {
  369. context()->KillInst(inst);
  370. }
  371. });
  372. // Remove the label instruction last.
  373. auto label = rm_block.GetLabelInst();
  374. context()->KillInst(label);
  375. *bi = bi->Erase();
  376. }
  377. bool MemPass::RemoveUnreachableBlocks(Function* func) {
  378. if (func->IsDeclaration()) return false;
  379. bool modified = false;
  380. // Mark reachable all blocks reachable from the function's entry block.
  381. std::unordered_set<BasicBlock*> reachable_blocks;
  382. std::unordered_set<BasicBlock*> visited_blocks;
  383. std::queue<BasicBlock*> worklist;
  384. reachable_blocks.insert(func->entry().get());
  385. // Initially mark the function entry point as reachable.
  386. worklist.push(func->entry().get());
  387. auto mark_reachable = [&reachable_blocks, &visited_blocks, &worklist,
  388. this](uint32_t label_id) {
  389. auto successor = cfg()->block(label_id);
  390. if (visited_blocks.count(successor) == 0) {
  391. reachable_blocks.insert(successor);
  392. worklist.push(successor);
  393. visited_blocks.insert(successor);
  394. }
  395. };
  396. // Transitively mark all blocks reachable from the entry as reachable.
  397. while (!worklist.empty()) {
  398. BasicBlock* block = worklist.front();
  399. worklist.pop();
  400. // All the successors of a live block are also live.
  401. static_cast<const BasicBlock*>(block)->ForEachSuccessorLabel(
  402. mark_reachable);
  403. // All the Merge and ContinueTarget blocks of a live block are also live.
  404. block->ForMergeAndContinueLabel(mark_reachable);
  405. }
  406. // Update operands of Phi nodes that reference unreachable blocks.
  407. for (auto& block : *func) {
  408. // If the block is about to be removed, don't bother updating its
  409. // Phi instructions.
  410. if (reachable_blocks.count(&block) == 0) {
  411. continue;
  412. }
  413. // If the block is reachable and has Phi instructions, remove all
  414. // operands from its Phi instructions that reference unreachable blocks.
  415. // If the block has no Phi instructions, this is a no-op.
  416. block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) {
  417. RemovePhiOperands(phi, reachable_blocks);
  418. });
  419. }
  420. // Erase unreachable blocks.
  421. for (auto ebi = func->begin(); ebi != func->end();) {
  422. if (reachable_blocks.count(&*ebi) == 0) {
  423. RemoveBlock(&ebi);
  424. modified = true;
  425. } else {
  426. ++ebi;
  427. }
  428. }
  429. return modified;
  430. }
  431. bool MemPass::CFGCleanup(Function* func) {
  432. bool modified = false;
  433. modified |= RemoveUnreachableBlocks(func);
  434. return modified;
  435. }
  436. void MemPass::CollectTargetVars(Function* func) {
  437. seen_target_vars_.clear();
  438. seen_non_target_vars_.clear();
  439. type2undefs_.clear();
  440. // Collect target (and non-) variable sets. Remove variables with
  441. // non-load/store refs from target variable set
  442. for (auto& blk : *func) {
  443. for (auto& inst : blk) {
  444. switch (inst.opcode()) {
  445. case spv::Op::OpStore:
  446. case spv::Op::OpLoad: {
  447. uint32_t varId;
  448. (void)GetPtr(&inst, &varId);
  449. if (!IsTargetVar(varId)) break;
  450. if (HasOnlySupportedRefs(varId)) break;
  451. seen_non_target_vars_.insert(varId);
  452. seen_target_vars_.erase(varId);
  453. } break;
  454. default:
  455. break;
  456. }
  457. }
  458. }
  459. }
  460. } // namespace opt
  461. } // namespace spvtools