Reducibility.cpp 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. #include "Reducibility.h"
  2. #include "llvm/ADT/SetVector.h"
  3. #include "llvm/IR/BasicBlock.h"
  4. #include "llvm/IR/CFG.h"
  5. #include "llvm/IR/Function.h"
  6. #include "llvm/IR/Instructions.h"
  7. #include "llvm/IR/LegacyPassManager.h"
  8. #include "llvm/Support/raw_ostream.h"
  9. #include "llvm/Transforms/Utils/Cloning.h"
  10. #include "llvm/Transforms/Scalar.h"
  11. #include "LLVMUtils.h"
  12. #include <fstream>
  13. #include <vector>
  14. #include <map>
  15. #define DBGS errs
  16. //#define DBGS dbgs
  17. using namespace llvm;
  18. struct Node
  19. {
  20. SetVector<Node*> in;
  21. SetVector<Node*> out;
  22. SetVector<BasicBlock*> blocks; // block 0 dominates all others in this node
  23. size_t numInstructions = 0;
  24. Node() {}
  25. Node(BasicBlock* B) { insert(B); }
  26. void insert(BasicBlock* B)
  27. {
  28. numInstructions += B->size();
  29. blocks.insert(B);
  30. }
  31. };
  32. static void printDotGraph(const std::vector<Node*> nodes, const std::string& filename)
  33. {
  34. DBGS() << "Writing " << filename << " ...";
  35. std::ofstream out(filename);
  36. if (!out)
  37. {
  38. DBGS() << "FAILED\n";
  39. return;
  40. }
  41. // Give nodes a numerical index to make the output cleaner
  42. std::map<Node*, int> idxMap;
  43. for (size_t i = 0; i < nodes.size(); ++i)
  44. idxMap[nodes[i]] = i;
  45. // Error check - make sure that all the out/in nodes are in the map
  46. for (Node* N : nodes)
  47. {
  48. for (Node* P : N->in)
  49. {
  50. if (idxMap.find(P) == idxMap.end())
  51. DBGS() << "MISSING INPUT NODE\n";
  52. if (P->out.count(N) == 0)
  53. DBGS() << "MISSING OUTGOING EDGE FROM PREDECESSOR.\n";
  54. }
  55. for (Node* S : N->out)
  56. {
  57. if (idxMap.find(S) == idxMap.end())
  58. DBGS() << "MISSING OUTPUT NODE\n";
  59. if (S->in.count(N) == 0)
  60. DBGS() << "MISSING INCOMING EDGE FROM SUCCESSOR.\n";
  61. }
  62. }
  63. // Print header
  64. out << "digraph g {\n";
  65. out << "node [\n";
  66. out << " fontsize = \"12\"\n";
  67. out << " labeljust = \"l\"\n";
  68. out << "]\n";
  69. for (unsigned i = 0; i < nodes.size(); ++i)
  70. {
  71. Node* N = nodes[i];
  72. // node
  73. out << " N" << i << " [shape=record,label=\"";
  74. for (BasicBlock* B : N->blocks)
  75. out << B->getName().str() << "\\n";
  76. out << "\"];\n";
  77. // out edges
  78. for (Node* S : N->out)
  79. out << " N" << i << " -> N" << idxMap[S] << ";\n";
  80. // in edges
  81. //for( Node* P : N->in )
  82. // out << " N" << idxMap[P] << " -> N" << i << " [style=dotted];\n";
  83. }
  84. out << "}\n";
  85. DBGS() << "\n";
  86. }
  87. static void printDotGraph(const std::vector<Node*> nodes, Function* F, int step)
  88. {
  89. printDotGraph(nodes, ("red." + F->getName() + "_" + std::to_string(step) + ".dot").str());
  90. }
  91. static Node* split(Node* N, std::map<BasicBlock*, Node*>& bbToNode, bool firstSplit)
  92. {
  93. // Remove one predecessor P from N
  94. assert(N->in.size() > 1);
  95. Node* P = N->in.pop_back_val();
  96. P->out.remove(N);
  97. // Point P to the clone of N, Np
  98. Node* Np = new Node();
  99. P->out.insert(Np);
  100. Np->in.insert(P);
  101. // Copy successors of N to Np
  102. for (Node* S : N->out)
  103. {
  104. Np->out.insert(S);
  105. S->in.insert(Np);
  106. }
  107. #if 1
  108. // Run reg2mem on the whole function so we don't have to deal with phis
  109. if (firstSplit)
  110. {
  111. runPasses(N->blocks[0]->getParent(), {
  112. createDemoteRegisterToMemoryPass()
  113. });
  114. }
  115. // Clone N into Np
  116. ValueToValueMapTy VMap;
  117. for (BasicBlock* B : N->blocks)
  118. {
  119. BasicBlock* Bp = CloneBasicBlock(B, VMap, ".c", B->getParent());
  120. Np->insert(Bp);
  121. VMap[B] = Bp;
  122. }
  123. for (BasicBlock* B : Np->blocks)
  124. for (Instruction& I : *B)
  125. RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
  126. // Remap terminators of P from N to Np
  127. for (BasicBlock* B : P->blocks)
  128. RemapInstruction(B->getTerminator(), VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
  129. #else
  130. // Clone N into Np
  131. ValueToValueMapTy VMap;
  132. for (BasicBlock* B : N->blocks)
  133. {
  134. BasicBlock* Bp = CloneBasicBlock(B, VMap, ".c", B->getParent());
  135. Np->insert(Bp);
  136. VMap[B] = Bp;
  137. }
  138. for (BasicBlock* B : Np->blocks)
  139. for (Instruction& I : *B)
  140. RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
  141. // Remove incoming values from phis in Np that don't come from actual predecessors
  142. BasicBlock* NpEntry = Np->blocks[0];
  143. std::set<BasicBlock*> predSet(pred_begin(NpEntry), pred_end(NpEntry));
  144. auto I = NpEntry->begin();
  145. while (PHINode* phi = dyn_cast<PHINode>(I++))
  146. {
  147. if (phi->getNumIncomingValues() == predSet.size())
  148. continue;
  149. for (unsigned i = 0; i < phi->getNumIncomingValues(); )
  150. {
  151. BasicBlock* B = phi->getIncomingBlock(i);
  152. if (!predSet.count(B))
  153. {
  154. phi->removeIncomingValue(B);
  155. continue;
  156. }
  157. ++i;
  158. }
  159. }
  160. // Remove phi references to P in N. (Do this before remapping terminators.)
  161. BasicBlock* Nentry = N->blocks[0];
  162. for (BasicBlock* PB : predecessors(Nentry))
  163. {
  164. if (P->blocks.count(PB))
  165. Nentry->removePredecessor(PB);
  166. }
  167. // Remap terminators of P from N to Np
  168. for (BasicBlock* B : P->blocks)
  169. RemapInstruction(B->getTerminator(), VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
  170. // Update phis in successors of Np.
  171. // There are several cases for a value Vs reaching S. Vs may be defined in N and
  172. // a clone Vsp in Np or only passing through one or the other. Furthermore, Vs may
  173. // either appear in a phi in the entry block of S or not.
  174. // 1) Vs defined in N (and clone Vsp in Np) and in phi:
  175. // Add incoming value [Vsp, Bp] for cloned value Vsp from predecessor basic
  176. // block Bp in Np wherever [Vs, B] appears
  177. // 2) Vs defined in N (and clone Vsp in Np) and not in phi:
  178. // Add phi [Vs,B],[Vsp,Bp] if Vs reaches a use in or through S
  179. // 3) Vs passing through N or Np and in phi
  180. // Change [Vs,B] to [Vs,Bp] in phis in S if Vs reached S through P
  181. // 4) Vs passing through N or Np and not in a phi
  182. // Do nothing
  183. //
  184. // TODO: Only 1) is implemented below and it isn't checking for definition in N
  185. for (Node* S : Np->out)
  186. {
  187. BasicBlock* Sentry = S->blocks[0];
  188. auto I = Sentry->begin();
  189. while (PHINode* phi = dyn_cast<PHINode>(I++))
  190. {
  191. for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i)
  192. {
  193. BasicBlock* B = phi->getIncomingBlock(i);
  194. if (N->blocks.count(B))
  195. {
  196. Value* V = phi->getIncomingValue(i);
  197. Value* Vp = VMap[V];
  198. if (!Vp)
  199. Vp = V; // Def not in N
  200. BasicBlock* Bp = dyn_cast<BasicBlock>(VMap[B]);
  201. phi->addIncoming(Vp, Bp);
  202. }
  203. }
  204. }
  205. }
  206. #endif
  207. return Np;
  208. }
  209. // Returns the number of splits
  210. int makeReducible(Function* F)
  211. {
  212. // Break critical edges now in case we need to do mem2reg in split(). mem2reg
  213. // will break critical edges and the CFG needs to remain unchanged.
  214. runPasses(F, {
  215. createBreakCriticalEdgesPass()
  216. });
  217. // initialize nodes
  218. std::vector<Node*> nodes;
  219. std::map<BasicBlock*, Node*> bbToNode;
  220. for (BasicBlock& B : *F)
  221. {
  222. nodes.push_back(new Node(&B));
  223. bbToNode[&B] = nodes.back();
  224. }
  225. // initialize edges
  226. for (Node* N : nodes)
  227. {
  228. for (BasicBlock* B : successors(N->blocks[0]))
  229. {
  230. Node* BN = bbToNode[B];
  231. N->out.insert(BN);
  232. BN->in.insert(N);
  233. }
  234. }
  235. int step = 0;
  236. bool print = false;
  237. if (print) printDotGraph(nodes, F, step++);
  238. int numSplits = 0;
  239. while (!nodes.empty())
  240. {
  241. bool changed;
  242. do
  243. {
  244. // It might more efficient to use a worklist based implementation instead
  245. // of iterating over the vector.
  246. changed = false;
  247. for (size_t i = 0; i < nodes.size(); )
  248. {
  249. Node* N = nodes[i];
  250. // Remove self references
  251. if (N->in.count(N))
  252. {
  253. N->in.remove(N);
  254. N->out.remove(N);
  255. changed = true;
  256. }
  257. // Remove singletons
  258. if (N->in.size() == 0 && N->out.size() == 0)
  259. {
  260. nodes.erase(nodes.begin() + i);
  261. changed = true;
  262. if (print) printDotGraph(nodes, F, step++);
  263. continue;
  264. }
  265. // Remove nodes with only one incoming edge
  266. if (N->in.size() == 1)
  267. {
  268. // fold into predecessor
  269. Node* P = N->in.back();
  270. P->blocks.insert(N->blocks.begin(), N->blocks.end());
  271. P->out.remove(N);
  272. for (Node* S : N->out)
  273. {
  274. S->in.remove(N);
  275. P->out.insert(S);
  276. S->in.insert(P);
  277. }
  278. P->numInstructions += N->numInstructions;
  279. nodes.erase(nodes.begin() + i);
  280. changed = true;
  281. if (print) printDotGraph(nodes, F, step++);
  282. continue;
  283. }
  284. i++;
  285. }
  286. } while (changed);
  287. if (!nodes.empty())
  288. {
  289. // Duplicate the smallest node with more than one incoming edge. Better
  290. // methods exist for picking the node to split, e.g. "Making Graphs Reducible
  291. // with Controlled Node Splitting" by Janssen and Corporaal.
  292. size_t idxMin = ~0;
  293. for (size_t i = 0; i < nodes.size(); ++i)
  294. {
  295. if (nodes[i]->in.size() <= 1)
  296. continue;
  297. if (idxMin == ~0u || nodes[i]->numInstructions < nodes[idxMin]->numInstructions)
  298. idxMin = i;
  299. }
  300. nodes.push_back(split(nodes[idxMin], bbToNode, numSplits == 0));
  301. numSplits++;
  302. if (print) printDotGraph(nodes, F, step++);
  303. }
  304. }
  305. return numSplits;
  306. }