LLVMUtils.cpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #include "llvm/Analysis/CFGPrinter.h" // needed for DOTGraphTraits<const Function*>
  2. #include "llvm/IR/Constants.h"
  3. #include "llvm/IR/Instructions.h"
  4. #include "llvm/IR/LegacyPassManager.h"
  5. #include "llvm/IR/Module.h"
  6. #include "llvm/IRReader/IRReader.h"
  7. #include "llvm/Support/FileSystem.h"
  8. #include "llvm/Support/raw_ostream.h"
  9. #include "llvm/Support/SourceMgr.h"
  10. #include "llvm/Support/GraphWriter.h"
  11. using namespace llvm;
  12. std::vector<CallInst*> getCallsToFunction(Function* callee, const Function* caller)
  13. {
  14. std::vector<CallInst*> calls;
  15. if (callee == nullptr)
  16. return calls;
  17. for (auto U = callee->user_begin(), UE = callee->user_end(); U != UE; ++U)
  18. {
  19. CallInst* CI = dyn_cast<CallInst>(*U);
  20. if (!CI) // We are not interested in uses that are not calls
  21. continue;
  22. assert(CI->getCalledFunction() == callee);
  23. if (caller == nullptr || CI->getParent()->getParent() == caller)
  24. calls.push_back(CI);
  25. }
  26. return calls;
  27. }
  28. ConstantInt* makeInt32(int val, LLVMContext& context)
  29. {
  30. return ConstantInt::get(Type::getInt32Ty(context), val);
  31. }
  32. Instruction* getInstructionAfter(Instruction* inst)
  33. {
  34. return ++BasicBlock::iterator(inst);
  35. }
  36. std::unique_ptr<Module> loadModuleFromAsmFile(LLVMContext& context, const std::string& filename)
  37. {
  38. SMDiagnostic err;
  39. std::unique_ptr<Module> mod = parseIRFile(filename, err, context);
  40. if (!mod)
  41. {
  42. err.print(filename.c_str(), errs());
  43. exit(1);
  44. }
  45. return mod;
  46. }
  47. std::unique_ptr<Module> loadModuleFromAsmString(LLVMContext& context, const std::string& str)
  48. {
  49. SMDiagnostic err;
  50. MemoryBufferRef memBuffer(str, "id");
  51. std::unique_ptr<Module> mod = parseIR(memBuffer, err, context);
  52. return mod;
  53. }
  54. void saveModuleToAsmFile(const llvm::Module* mod, const std::string& filename)
  55. {
  56. std::error_code EC;
  57. raw_fd_ostream out(filename, EC, sys::fs::F_Text);
  58. if (!out.has_error())
  59. {
  60. mod->print(out, 0);
  61. out.close();
  62. }
  63. if (out.has_error())
  64. {
  65. errs() << "Error saving to " << filename << "\n";
  66. exit(1);
  67. }
  68. }
  69. void dumpCFG(const Function* F, const std::string& suffix)
  70. {
  71. std::string filename = ("cfg." + F->getName() + "." + suffix + ".dot").str();
  72. std::error_code EC;
  73. raw_fd_ostream out(filename, EC, sys::fs::F_Text);
  74. if (!out.has_error())
  75. {
  76. errs() << "Writing '" << filename << "'...\n";
  77. WriteGraph(out, F, true, F->getName());
  78. out.close();
  79. }
  80. if (out.has_error())
  81. {
  82. errs() << "Error saving to " << filename << "\n";
  83. exit(1);
  84. }
  85. }
  86. Function* getOrCreateFunction(const std::string& name, Module* mod, FunctionType* funcType, std::map<FunctionType*, Function*>& typeToFuncMap)
  87. {
  88. auto it = typeToFuncMap.find(funcType);
  89. if (it != typeToFuncMap.end())
  90. return it->second;
  91. // Give name a numerical suffix to make it unique
  92. std::string uniqueName = name + std::to_string(typeToFuncMap.size());
  93. Function* F = dyn_cast<Function>(mod->getOrInsertFunction(uniqueName, funcType));
  94. typeToFuncMap[funcType] = F;
  95. return F;
  96. }
  97. void runPasses(llvm::Function* F, const std::vector<llvm::Pass*>& passes)
  98. {
  99. legacy::FunctionPassManager FPM(F->getParent());
  100. for (Pass* pass : passes)
  101. FPM.add(pass);
  102. FPM.doInitialization();
  103. FPM.run(*F);
  104. FPM.doFinalization();
  105. }