BoundsChecking.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // This file implements a pass that instruments the code to perform run-time
  11. // bounds checking on loads, stores, and atomic cmpxchg/atomicrmw instructions.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Transforms/Instrumentation.h"
  15. #include "llvm/ADT/Statistic.h"
  16. #include "llvm/Analysis/MemoryBuiltins.h"
  17. #include "llvm/Analysis/TargetFolder.h"
  18. #include "llvm/Analysis/TargetLibraryInfo.h"
  19. #include "llvm/IR/DataLayout.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/IR/InstIterator.h"
  22. #include "llvm/IR/Intrinsics.h"
  23. #include "llvm/Pass.h"
  24. #include "llvm/Support/CommandLine.h"
  25. #include "llvm/Support/Debug.h"
  26. #include "llvm/Support/raw_ostream.h"
  27. using namespace llvm;
  28. #define DEBUG_TYPE "bounds-checking"
  29. static cl::opt<bool> SingleTrapBB("bounds-checking-single-trap",
  30. cl::desc("Use one trap block per function"));
  31. STATISTIC(ChecksAdded, "Bounds checks added");
  32. STATISTIC(ChecksSkipped, "Bounds checks skipped");
  33. STATISTIC(ChecksUnable, "Bounds checks unable to add");
  34. typedef IRBuilder<true, TargetFolder> BuilderTy;
  35. namespace {
  36. struct BoundsChecking : public FunctionPass {
  37. static char ID;
  38. BoundsChecking() : FunctionPass(ID) {
  39. initializeBoundsCheckingPass(*PassRegistry::getPassRegistry());
  40. }
  41. bool runOnFunction(Function &F) override;
  42. void getAnalysisUsage(AnalysisUsage &AU) const override {
  43. AU.addRequired<TargetLibraryInfoWrapperPass>();
  44. }
  45. private:
  46. const TargetLibraryInfo *TLI;
  47. ObjectSizeOffsetEvaluator *ObjSizeEval;
  48. BuilderTy *Builder;
  49. Instruction *Inst;
  50. BasicBlock *TrapBB;
  51. BasicBlock *getTrapBB();
  52. void emitBranchToTrap(Value *Cmp = nullptr);
  53. bool instrument(Value *Ptr, Value *Val, const DataLayout &DL);
  54. };
  55. }
  56. char BoundsChecking::ID = 0;
  57. INITIALIZE_PASS(BoundsChecking, "bounds-checking", "Run-time bounds checking",
  58. false, false)
  59. /// getTrapBB - create a basic block that traps. All overflowing conditions
  60. /// branch to this block. There's only one trap block per function.
  61. BasicBlock *BoundsChecking::getTrapBB() {
  62. if (TrapBB && SingleTrapBB)
  63. return TrapBB;
  64. Function *Fn = Inst->getParent()->getParent();
  65. IRBuilder<>::InsertPointGuard Guard(*Builder);
  66. TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
  67. Builder->SetInsertPoint(TrapBB);
  68. llvm::Value *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap);
  69. CallInst *TrapCall = Builder->CreateCall(F, {});
  70. TrapCall->setDoesNotReturn();
  71. TrapCall->setDoesNotThrow();
  72. TrapCall->setDebugLoc(Inst->getDebugLoc());
  73. Builder->CreateUnreachable();
  74. return TrapBB;
  75. }
  76. /// emitBranchToTrap - emit a branch instruction to a trap block.
  77. /// If Cmp is non-null, perform a jump only if its value evaluates to true.
  78. void BoundsChecking::emitBranchToTrap(Value *Cmp) {
  79. // check if the comparison is always false
  80. ConstantInt *C = dyn_cast_or_null<ConstantInt>(Cmp);
  81. if (C) {
  82. ++ChecksSkipped;
  83. if (!C->getZExtValue())
  84. return;
  85. else
  86. Cmp = nullptr; // unconditional branch
  87. }
  88. ++ChecksAdded;
  89. Instruction *Inst = Builder->GetInsertPoint();
  90. BasicBlock *OldBB = Inst->getParent();
  91. BasicBlock *Cont = OldBB->splitBasicBlock(Inst);
  92. OldBB->getTerminator()->eraseFromParent();
  93. if (Cmp)
  94. BranchInst::Create(getTrapBB(), Cont, Cmp, OldBB);
  95. else
  96. BranchInst::Create(getTrapBB(), OldBB);
  97. }
  98. /// instrument - adds run-time bounds checks to memory accessing instructions.
  99. /// Ptr is the pointer that will be read/written, and InstVal is either the
  100. /// result from the load or the value being stored. It is used to determine the
  101. /// size of memory block that is touched.
  102. /// Returns true if any change was made to the IR, false otherwise.
  103. bool BoundsChecking::instrument(Value *Ptr, Value *InstVal,
  104. const DataLayout &DL) {
  105. uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType());
  106. DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
  107. << " bytes\n");
  108. SizeOffsetEvalType SizeOffset = ObjSizeEval->compute(Ptr);
  109. if (!ObjSizeEval->bothKnown(SizeOffset)) {
  110. ++ChecksUnable;
  111. return false;
  112. }
  113. Value *Size = SizeOffset.first;
  114. Value *Offset = SizeOffset.second;
  115. ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size);
  116. Type *IntTy = DL.getIntPtrType(Ptr->getType());
  117. Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize);
  118. // three checks are required to ensure safety:
  119. // . Offset >= 0 (since the offset is given from the base ptr)
  120. // . Size >= Offset (unsigned)
  121. // . Size - Offset >= NeededSize (unsigned)
  122. //
  123. // optimization: if Size >= 0 (signed), skip 1st check
  124. // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows
  125. Value *ObjSize = Builder->CreateSub(Size, Offset);
  126. Value *Cmp2 = Builder->CreateICmpULT(Size, Offset);
  127. Value *Cmp3 = Builder->CreateICmpULT(ObjSize, NeededSizeVal);
  128. Value *Or = Builder->CreateOr(Cmp2, Cmp3);
  129. if (!SizeCI || SizeCI->getValue().slt(0)) {
  130. Value *Cmp1 = Builder->CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0));
  131. Or = Builder->CreateOr(Cmp1, Or);
  132. }
  133. emitBranchToTrap(Or);
  134. return true;
  135. }
  136. bool BoundsChecking::runOnFunction(Function &F) {
  137. const DataLayout &DL = F.getParent()->getDataLayout();
  138. TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
  139. TrapBB = nullptr;
  140. BuilderTy TheBuilder(F.getContext(), TargetFolder(DL));
  141. Builder = &TheBuilder;
  142. ObjectSizeOffsetEvaluator TheObjSizeEval(DL, TLI, F.getContext(),
  143. /*RoundToAlign=*/true);
  144. ObjSizeEval = &TheObjSizeEval;
  145. // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
  146. // touching instructions
  147. std::vector<Instruction*> WorkList;
  148. for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) {
  149. Instruction *I = &*i;
  150. if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<AtomicCmpXchgInst>(I) ||
  151. isa<AtomicRMWInst>(I))
  152. WorkList.push_back(I);
  153. }
  154. bool MadeChange = false;
  155. for (std::vector<Instruction*>::iterator i = WorkList.begin(),
  156. e = WorkList.end(); i != e; ++i) {
  157. Inst = *i;
  158. Builder->SetInsertPoint(Inst);
  159. if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
  160. MadeChange |= instrument(LI->getPointerOperand(), LI, DL);
  161. } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
  162. MadeChange |=
  163. instrument(SI->getPointerOperand(), SI->getValueOperand(), DL);
  164. } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(Inst)) {
  165. MadeChange |=
  166. instrument(AI->getPointerOperand(), AI->getCompareOperand(), DL);
  167. } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(Inst)) {
  168. MadeChange |=
  169. instrument(AI->getPointerOperand(), AI->getValOperand(), DL);
  170. } else {
  171. llvm_unreachable("unknown Instruction type");
  172. }
  173. }
  174. return MadeChange;
  175. }
  176. FunctionPass *llvm::createBoundsCheckingPass() {
  177. return new BoundsChecking();
  178. }