StateFunctionTransform.cpp 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797
  1. #include "StateFunctionTransform.h"
  2. #include "llvm/IR/CFG.h"
  3. #include "llvm/IR/Constants.h"
  4. #include "llvm/IR/InstIterator.h"
  5. #include "llvm/IR/Instructions.h"
  6. #include "llvm/IR/LegacyPassManager.h"
  7. #include "llvm/IR/PassManager.h"
  8. #include "llvm/IR/ValueMap.h"
  9. #include "llvm/IR/Verifier.h"
  10. #include "llvm/Support/FileSystem.h"
  11. #include "llvm/Transforms/Scalar.h"
  12. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  13. #include "llvm/Transforms/Utils/Cloning.h"
  14. #include "llvm/Transforms/Utils/Local.h"
  15. #include "FunctionBuilder.h"
  16. #include "LiveValues.h"
  17. #include "LLVMUtils.h"
  18. #include "Reducibility.h"
  19. #define DBGS dbgs
  20. //#define DBGS errs
  21. using namespace llvm;
  22. static const char* CALL_INDIRECT_NAME = "\x1?Fallback_CallIndirect@@YAXH@Z";
  23. static const char* SET_PENDING_ATTR_PREFIX = "\x1?Fallback_SetPendingAttr@@";
  24. // Create a string with printf-like arguments
  25. inline std::string stringf(const char* fmt, ...)
  26. {
  27. va_list args;
  28. va_start(args, fmt);
  29. #ifdef WIN32
  30. int size = _vscprintf(fmt, args);
  31. #else
  32. int size = vsnprintf(0, 0, fmt, args);
  33. #endif
  34. va_end(args);
  35. std::string ret;
  36. if (size > 0)
  37. {
  38. ret.resize(size);
  39. va_start(args, fmt);
  40. vsnprintf((char*)ret.data(), size + 1, fmt, args);
  41. va_end(args);
  42. }
  43. return ret;
  44. }
  45. // Remove ELF mangling
  46. static std::string cleanName(StringRef name)
  47. {
  48. if (!name.startswith("\x1?"))
  49. return name;
  50. size_t pos = name.find("@@");
  51. if (pos == name.npos)
  52. return name;
  53. std::string newName = name.substr(2, pos - 2);
  54. return newName;
  55. }
  56. // Utility to append the suffix to the name of the value, but returns
  57. // an empty string if name is empty. This is to avoid names like ".ptr".
  58. static std::string addSuffix(StringRef valueName, StringRef suffix)
  59. {
  60. if (!valueName.empty())
  61. {
  62. if (valueName.back() == '.' && suffix.front() == '.') // avoid double dots
  63. return (valueName + suffix.substr(1)).str();
  64. else
  65. return (valueName + suffix).str();
  66. }
  67. else
  68. return valueName.str();
  69. }
  70. // Remove suffix from name.
  71. static std::string stripSuffix(StringRef name, StringRef suffix)
  72. {
  73. size_t pos = name.rfind(suffix);
  74. if (pos != name.npos)
  75. return name.substr(0, pos).str();
  76. else
  77. return name.str();
  78. }
  79. static std::string stripAfter(StringRef name, StringRef suffixStart)
  80. {
  81. size_t pos = name.find(suffixStart);
  82. if (pos != name.npos)
  83. return name.substr(0, pos).str();
  84. else
  85. return name.str();
  86. }
  87. // Insert str before the final "." in filename.
  88. static std::string insertBeforeExtension(const std::string& filename, const std::string& str)
  89. {
  90. std::string ret = filename;
  91. size_t pos = filename.rfind('.');
  92. if (pos != std::string::npos)
  93. ret.insert(pos, str);
  94. else
  95. ret += str;
  96. return ret;
  97. }
  98. // Inserts <functionName>-<id>-<suffix> before the extension in baseName
  99. static std::string createDumpPath(
  100. const std::string& baseName,
  101. unsigned id,
  102. const std::string& suffix,
  103. const std::string& functionName)
  104. {
  105. std::string s;
  106. if (!functionName.empty())
  107. s = "-" + functionName;
  108. s += stringf("-%02d-", id) + suffix;
  109. return insertBeforeExtension(baseName, s);
  110. }
  111. // Return byte offset aligned to the alignment required by inst.
  112. static uint64_t align(uint64_t offset, Instruction* inst, DataLayout& DL)
  113. {
  114. unsigned alignment = 0;
  115. if (AllocaInst* ai = dyn_cast<AllocaInst>(inst))
  116. alignment = ai->getAlignment();
  117. if (alignment == 0)
  118. alignment = DL.getPrefTypeAlignment(inst->getType());
  119. return RoundUpToAlignment(offset, alignment);
  120. }
  121. template <class T> // T can be Value* or Instruction*
  122. T createCastForStack(T ptr, llvm::Type* targetPtrElemType, llvm::Instruction* insertBefore)
  123. {
  124. llvm::PointerType* requiredType = llvm::PointerType::get(targetPtrElemType, ptr->getType()->getPointerAddressSpace());
  125. if (ptr->getType() == requiredType)
  126. return ptr;
  127. return new llvm::BitCastInst(ptr, requiredType, ptr->getName(), insertBefore);
  128. }
  129. static Value* createCastToInt(Value* val, Instruction* insertBefore)
  130. {
  131. Type* i32Ty = Type::getInt32Ty(val->getContext());
  132. if (val->getType() == i32Ty)
  133. return val;
  134. if (val->getType() == Type::getInt1Ty(val->getContext()))
  135. return new ZExtInst(val, i32Ty, addSuffix(val->getName(), ".int"), insertBefore);
  136. Value* intVal = new BitCastInst(val, i32Ty, addSuffix(val->getName(), ".int"), insertBefore);
  137. return intVal;
  138. }
// Convert an i32 (as loaded from the stack) back to type ty: i32 passes
// through, i1 is recovered with a signed "> 0" compare, anything else is
// bitcast. Renames intVal with a ".int" suffix and gives the converted
// result the original name.
static Value* createCastFromInt(Value* intVal, Type* ty, Instruction* insertBefore)
{
Type* i32Ty = Type::getInt32Ty(intVal->getContext());
if (ty == i32Ty)
return intVal;
// Copy the name before renaming: the raw i32 gets the ".int" suffix and
// the cast result takes over the original name.
std::string name = intVal->getName();
intVal->setName(addSuffix(name, ".int"));
// Create boolean with compare
if (ty == Type::getInt1Ty(intVal->getContext()))
return new ICmpInst(insertBefore, CmpInst::ICMP_SGT, intVal, makeInt32(0, intVal->getContext()), name);
return new BitCastInst(intVal, ty, name, insertBefore);
}
  151. // Gives every value in the given function a name. This can aid in debugging.
  152. static void dbgNameUnnamedVals(Function* func)
  153. {
  154. Type* voidTy = Type::getVoidTy(func->getContext());
  155. for (auto& I : inst_range(func))
  156. {
  157. if (!I.hasName() && I.getType() != voidTy)
  158. I.setName("v"); // LLVM will uniquify the name by adding a numeric suffix
  159. }
  160. }
  161. // Returns an iterator for the instruction after the last alloca in the entry block
  162. // (assuming that allocas are at the top of the entry block).
  163. static BasicBlock::iterator afterEntryBlockAllocas(Function* function)
  164. {
  165. BasicBlock::iterator insertBefore = function->getEntryBlock().begin();
  166. while (isa<AllocaInst>(insertBefore))
  167. ++insertBefore;
  168. return insertBefore;
  169. }
  170. // Return all the blocks reachable from entryBlock.
  171. static BasicBlockVector getReachableBlocks(BasicBlock* entryBlock)
  172. {
  173. BasicBlockVector blocks;
  174. std::deque<BasicBlock*> stack = { entryBlock };
  175. ::BasicBlockSet visited = { entryBlock };
  176. while (!stack.empty())
  177. {
  178. BasicBlock* block = stack.front();
  179. stack.pop_front();
  180. blocks.push_back(block);
  181. TerminatorInst* termInst = block->getTerminator();
  182. for (unsigned int succ = 0, succEnd = termInst->getNumSuccessors(); succ != succEnd; ++succ)
  183. {
  184. BasicBlock* succBlock = termInst->getSuccessor(succ);
  185. if (visited.insert(succBlock).second)
  186. stack.push_front(succBlock);
  187. }
  188. }
  189. return blocks;
  190. }
// Creates a new (body-less) function with the same signature, argument
// names, and attributes as oldFunction. Old-to-new argument mappings are
// recorded in VMap so that a subsequently cloned body can be remapped.
static Function* cloneFunctionPrototype(const Function* oldFunction, ValueToValueMapTy& VMap)
{
std::vector<Type*> argTypes;
for (auto I = oldFunction->arg_begin(), E = oldFunction->arg_end(); I != E; ++I)
argTypes.push_back(I->getType());
FunctionType* FTy = FunctionType::get(oldFunction->getFunctionType()->getReturnType(), argTypes,
oldFunction->getFunctionType()->isVarArg());
Function* newFunction = Function::Create(FTy, oldFunction->getLinkage(), oldFunction->getName());
// Copy argument names and record the old->new argument mapping.
Function::arg_iterator destI = newFunction->arg_begin();
for (auto I = oldFunction->arg_begin(), E = oldFunction->arg_end(); I != E; ++I, ++destI)
{
destI->setName(I->getName());
VMap[I] = destI;
}
// Copy per-parameter attributes. In this LLVM version's AttributeSet,
// index 0 is the return value, so parameter attributes live at ArgNo + 1.
AttributeSet oldAttrs = oldFunction->getAttributes();
for (auto I = oldFunction->arg_begin(), E = oldFunction->arg_end(); I != E; ++I)
{
if (Argument* Anew = dyn_cast<Argument>(VMap[I]))
{
AttributeSet attrs = oldAttrs.getParamAttributes(I->getArgNo() + 1);
if (attrs.getNumSlots() > 0)
Anew->addAttr(attrs);
}
}
// Copy return-value and function-level attributes.
newFunction->setAttributes(newFunction->getAttributes().addAttributes(newFunction->getContext(), AttributeSet::ReturnIndex,
oldAttrs.getRetAttributes()));
newFunction->setAttributes(newFunction->getAttributes().addAttributes(newFunction->getContext(), AttributeSet::FunctionIndex,
oldAttrs.getFnAttributes()));
return newFunction;
}
// Creates a new function containing clones of only the blocks reachable from
// entryBlock (entryBlock need not be the original entry block). After
// cloning, instructions are remapped into the new function and phi operands
// coming from now-unreachable predecessors are pruned.
static Function* cloneBlocksReachableFrom(BasicBlock* entryBlock, ValueToValueMapTy& VMap)
{
Function* oldFunction = entryBlock->getParent();
Function* newFunction = cloneFunctionPrototype(oldFunction, VMap);
// Insert a clone of the entry block into the function.
BasicBlock* newEntry = CloneBasicBlock(entryBlock, VMap, "", newFunction);
VMap[entryBlock] = newEntry;
// Clone all other blocks.
BasicBlockVector blocks = getReachableBlocks(entryBlock);
for (auto block : blocks)
{
if (block == entryBlock)
continue;
BasicBlock* clonedBlock = CloneBasicBlock(block, VMap, "", newFunction);
VMap[block] = clonedBlock;
}
// Remap new instructions to reference blocks and instructions of the new function.
// RF_IgnoreMissingEntries: operands defined outside the cloned region have
// no VMap entry and are left as-is.
for (auto block : blocks)
{
auto clonedBlock = cast<BasicBlock>(VMap[block]);
for (BasicBlock::iterator I = clonedBlock->begin(); I != clonedBlock->end(); ++I)
{
RemapInstruction(I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
}
}
// Remove phi operands incoming from blocks that are not present in the new function anymore.
for (auto& block : *newFunction)
{
PHINode* firstPHI = dyn_cast<PHINode>(block.begin());
if (firstPHI == nullptr)
continue; // phi instructions only at beginning
// Create set of actual predecessors
BasicBlockSet preds(pred_begin(&block), pred_end(&block));
// Fast path: incoming count already matches the real predecessor count.
if (preds.size() == firstPHI->getNumIncomingValues())
continue;
// Remove phi incoming blocks not in preds
for (auto iter = block.begin(); isa<PHINode>(iter); ++iter)
{
std::vector<unsigned int> toRemove;
PHINode* phi = cast<PHINode>(iter);
for (unsigned int op = 0, opEnd = phi->getNumIncomingValues(); op != opEnd; ++op)
{
BasicBlock* pred = phi->getIncomingBlock(op);
if (preds.count(pred) == 0)
{
toRemove.push_back(op);
}
}
// Erase in reverse so earlier indices stay valid as we remove.
for (auto I = toRemove.rbegin(), E = toRemove.rend(); I != E; ++I)
phi->removeIncomingValue(*I, false);
}
}
return newFunction;
}
// Replace every call (within caller) to the dummy function behind oldVal
// with newVal, erase those calls, and delete the dummy function itself once
// nothing references it anymore.
static void replaceValAndRemoveUnusedDummyFunc(Value* oldVal, Value* newVal, Function* caller)
{
CallInst* call = dyn_cast<CallInst>(oldVal);
assert(call != nullptr && "Must be a call");
Function* func = call->getCalledFunction();
for (CallInst* CI : getCallsToFunction(func, caller))
{
CI->replaceAllUsesWith(newVal);
CI->eraseFromParent();
}
// Drop the dummy declaration when the last user is gone (other callers
// may still reference it, hence the use-count check).
if (func->getNumUses() == 0)
func->eraseFromParent();
}
  291. // Get the integer value of val. If val is not a ConstantInt return false.
  292. static bool getConstantValue(int& constant, const Value* val)
  293. {
  294. const ConstantInt* CI = dyn_cast<ConstantInt>(val);
  295. if (!CI)
  296. return false;
  297. if (CI->getBitWidth() > 32)
  298. return false;
  299. constant = static_cast<int>(CI->getSExtValue());
  300. return true;
  301. }
  302. static int getConstantValue(const Value* val)
  303. {
  304. const ConstantInt* CI = dyn_cast<ConstantInt>(val);
  305. assert(CI && CI->getBitWidth() <= 32);
  306. return static_cast<int>(CI->getSExtValue());
  307. }
// State threaded through the recursive store() helper below.
struct StoreInfo
{
Function* stackIntPtrFunc; // stackIntPtr(runtimeData, baseOffset, offset) -> i32* (see init())
Value* runtimeDataArg; // runtime-data value passed to stackIntPtrFunc
Value* baseOffset; // base offset of the stack region being written
Instruction* insertBefore; // insertion point for all generated instructions
Value* val; // value currently being stored (may be overwritten during vector recursion)
std::vector<Value*> idxList; // GEP index path into val while walking aggregates
};
// Takes the offset at which to store the next value.
// Returns the next available offset.
// Recursively flattens ty into consecutive i32 stack slots: structs and
// arrays descend element-by-element while extending the GEP index path,
// pointers descend into their pointee, vectors are stored lane-by-lane, and
// scalars are cast to i32 and stored through stackIntPtrFunc.
static int store(int offset, StoreInfo& SI, Type* ty)
{
if (StructType* STy = dyn_cast<StructType>(ty))
{
// Visit each field with its index appended to the GEP path.
SI.idxList.push_back(nullptr);
int elIdx = 0;
for (auto& elTy : STy->elements())
{
SI.idxList.back() = makeInt32(elIdx++, ty->getContext());
offset = store(offset, SI, elTy);
}
SI.idxList.pop_back();
}
else if (ArrayType* ATy = dyn_cast<ArrayType>(ty))
{
// Same as the struct case, but all elements share one type.
Type* elTy = ATy->getArrayElementType();
SI.idxList.push_back(nullptr);
for (int elIdx = 0; elIdx < (int)ATy->getArrayNumElements(); ++elIdx)
{
SI.idxList.back() = makeInt32(elIdx, ty->getContext());
offset = store(offset, SI, elTy);
}
SI.idxList.pop_back();
}
else if (PointerType* PTy = dyn_cast<PointerType>(ty))
{
// Pointers store what they point to; index 0 dereferences via GEP.
SI.idxList.push_back(makeInt32(0, ty->getContext()));
offset = store(offset, SI, PTy->getPointerElementType());
SI.idxList.pop_back();
}
else
{
// Leaf of the aggregate walk: materialize the value to store.
Value* val = SI.val;
if (!SI.idxList.empty())
{
// Aggregate element: load it through the accumulated GEP path.
Value* gep = GetElementPtrInst::CreateInBounds(SI.val, SI.idxList, "", SI.insertBefore);
val = new LoadInst(gep, "", SI.insertBefore);
}
if (VectorType* VTy = dyn_cast<VectorType>(ty))
{
// Store one lane at a time. The index path is moved out while
// recursing so each extracted lane is treated as a plain scalar
// (SI.val is temporarily repointed at the lane).
std::vector<Value*>idxList = std::move(SI.idxList);
Type* elTy = VTy->getVectorElementType();
for (int elIdx = 0; elIdx < (int)VTy->getVectorNumElements(); ++elIdx)
{
Value* idxVal = makeInt32(elIdx, ty->getContext());
Value* el = ExtractElementInst::Create(val, idxVal, "", SI.insertBefore);
SI.val = el;
offset = store(offset, SI, elTy);
}
SI.idxList = std::move(idxList);
}
else
{
// Scalar: cast to i32 and store into the slot at baseOffset + offset.
Value* idxVal = makeInt32(offset, val->getContext());
Value* intVal = createCastToInt(val, SI.insertBefore);
Value* intPtr = CallInst::Create(SI.stackIntPtrFunc, { SI.runtimeDataArg, SI.baseOffset, idxVal }, addSuffix(val->getName(), ".ptr"), SI.insertBefore);
new StoreInst(intVal, intPtr, SI.insertBefore);
offset += 1;
}
}
return offset;
}
  381. // Store value to the stack at given baseOffset + offset. Will flatten aggregates and vectors.
  382. // Returns the offset where writing left off. For pointer vals stores what is pointed to.
  383. static int store(Value* val, Function* stackIntPtrFunc, Value* runtimeDataArg, Value* baseOffset, int offset, Instruction* insertBefore)
  384. {
  385. StoreInfo SI;
  386. SI.stackIntPtrFunc = stackIntPtrFunc;
  387. SI.runtimeDataArg = runtimeDataArg;
  388. SI.baseOffset = baseOffset;
  389. SI.insertBefore = insertBefore;
  390. SI.val = val;
  391. return store(offset, SI, val->getType());
  392. }
// Load a value of type ty from the stack at offset + idx. Vectors are
// reassembled lane-by-lane from consecutive i32 slots (idx must be a
// constant in that case); anything else is read from a single slot and cast
// back from i32 via createCastFromInt.
static Value* load(llvm::Function* m_stackIntPtrFunc, Value* runtimeDataArg, Value* offset, Value* idx, const std::string& name, Type* ty, Instruction* insertBefore)
{
if (VectorType* VTy = dyn_cast<VectorType>(ty))
{
LLVMContext& C = ty->getContext();
// idx must be constant here so per-lane slot indices can be computed.
int baseIdx = getConstantValue(idx);
Type* elTy = VTy->getVectorElementType();
// Build the vector by inserting one loaded lane at a time.
Value* vec = UndefValue::get(VTy);
for (int i = 0; i < (int)VTy->getVectorNumElements(); ++i)
{
std::string elName = stringf("el%d.", i);
Value* intPtr = CallInst::Create(m_stackIntPtrFunc, { runtimeDataArg, offset, makeInt32(baseIdx + i, C) }, elName + "ptr", insertBefore);
Value* intEl = new LoadInst(intPtr, elName, insertBefore);
Value* el = createCastFromInt(intEl, elTy, insertBefore);
vec = InsertElementInst::Create(vec, el, makeInt32(i, C), "tmpvec", insertBefore);
}
// Only the finished vector carries the requested name.
vec->setName(name);
return vec;
}
else
{
// Scalar (or other single-slot) case.
Value* intPtr = CallInst::Create(m_stackIntPtrFunc, { runtimeDataArg, offset, idx }, addSuffix(name, ".ptr"), insertBefore);
Value* intVal = new LoadInst(intPtr, name, insertBefore);
Value* val = createCastFromInt(intVal, ty, insertBefore);
return val;
}
}
  420. static void reg2Mem(DenseMap<Instruction*, AllocaInst*>& valToAlloca, DenseMap<AllocaInst*, Instruction*>& allocaToVal, Instruction* inst)
  421. {
  422. if (valToAlloca.count(inst))
  423. return;
  424. // Convert the value to an alloca
  425. AllocaInst* allocaPtr = DemoteRegToStack(*inst, false);
  426. if (allocaPtr)
  427. {
  428. valToAlloca[inst] = allocaPtr;
  429. allocaToVal[allocaPtr] = inst;
  430. }
  431. }
  432. // Utility class for rematerializing values at a callsite
  433. class Rematerializer
  434. {
  435. public:
  436. Rematerializer(
  437. DenseMap<AllocaInst*, Instruction*>& allocaToVal,
  438. const InstructionSetVector& liveHere,
  439. const std::set<Value*>& resources
  440. )
  441. : m_allocaToVal(allocaToVal)
  442. , m_liveHere(liveHere)
  443. , m_resources(resources)
  444. {}
  445. // Returns true if inst can be rematerialized.
  446. bool canRematerialize(Instruction* inst)
  447. {
  448. if (CallInst* call = dyn_cast<CallInst>(inst))
  449. {
  450. StringRef funcName = call->getCalledFunction()->getName();
  451. if (funcName.startswith("dummyStackFrameSize"))
  452. return true;
  453. if (funcName.startswith("stack.ptr"))
  454. return true;
  455. if (funcName.startswith("stack.load"))
  456. return true;
  457. if (funcName.startswith("dx.op.createHandle"))
  458. return true;
  459. }
  460. else if (LoadInst* load = dyn_cast<LoadInst>(inst))
  461. {
  462. Value* op = load->getOperand(0);
  463. if (GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(op)) // for descriptor tables
  464. op = gep->getOperand(0);
  465. if (m_resources.count(op))
  466. return true;
  467. }
  468. else if (GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(inst))
  469. {
  470. assert(gep->hasAllConstantIndices() && "Unhandled non-constant index"); // Should have been changed to stack.ptr
  471. return true;
  472. }
  473. return false;
  474. }
  475. // Rematerialize the given instruction and its dependency graph, adding
  476. // any nonrematerializable values that are live in the function, but not
  477. // at this callsite to the work list to insure that their values are restored.
  478. Instruction* rematerialize(Instruction* inst, std::vector<Instruction *> workList, Instruction* insertBefore, int depth = 0)
  479. {
  480. // Signal if we hit a complex case. Deep rematerialization needs more analysis.
  481. // To make this robust we would need to make it possible to run the current
  482. // value through the live value handling pipeline: figure out where it is live,
  483. // reg2mem, save/restore at appropriate callsites, etc.
  484. assert(depth < 8);
  485. // Reuse an already rematerialized value?
  486. auto it = m_rematMap.find(inst);
  487. if (it != m_rematMap.end())
  488. return it->second;
  489. // Handle allocas
  490. if (AllocaInst* alloc = dyn_cast<AllocaInst>(inst))
  491. {
  492. assert(depth > 0); // Should only be an operand to another rematerialized value
  493. auto it = m_allocaToVal.find(alloc);
  494. if (it != m_allocaToVal.end()) // Is it a value that is live at some callsite (and reg2mem'd)?
  495. {
  496. Instruction* val = it->second;
  497. if (canRematerialize(val))
  498. {
  499. // Rematerialize here and store to the alloca. We may have already rematerialized a load
  500. // from the alloca. Any future uses will use the rematerialized value directly.
  501. Instruction* remat = rematerialize(val, workList, insertBefore, depth + 1);
  502. new StoreInst(remat, alloc, insertBefore);
  503. }
  504. else
  505. {
  506. // Value has to be restored, but it rematerialization may have extended
  507. // the liveness of this value to this callsite. Make sure it gets restored.
  508. if (!m_liveHere.count(val))
  509. workList.push_back(val);
  510. }
  511. }
  512. // Allocas are not cloned.
  513. return inst;
  514. }
  515. Instruction* clone = inst->clone();
  516. clone->setName(addSuffix(inst->getName(), ".remat"));
  517. for (unsigned i = 0; i < inst->getNumOperands(); ++i)
  518. {
  519. Value* op = inst->getOperand(i);
  520. if (Instruction* opInst = dyn_cast<Instruction>(op))
  521. clone->setOperand(i, rematerialize(opInst, workList, insertBefore, depth + 1));
  522. else
  523. clone->setOperand(i, op);
  524. }
  525. clone->insertBefore(insertBefore); // insert after any instructions cloned for operands
  526. m_rematMap[inst] = clone;
  527. return clone;
  528. }
  529. Instruction* getRematerializedValueFor(Instruction* val)
  530. {
  531. auto it = m_rematMap.find(val);
  532. if (it != m_rematMap.end())
  533. return it->second;
  534. else
  535. return nullptr;
  536. }
  537. private:
  538. DenseMap<Instruction*, Instruction*> m_rematMap; // Map instructions to their rematerialized counterparts
  539. DenseMap<AllocaInst*, Instruction*>& m_allocaToVal; // Map allocas for reg2mem'd live values back to the value
  540. const InstructionSetVector& m_liveHere; // Values live at this callsite
  541. const std::set<Value*>& m_resources; // Values for resources like SRVs, UAVs, etc.
  542. };
// func: the shader function to transform. candidateFuncNames: names of all
// candidate functions; func's cleaned name must be among them, and its
// position in the list becomes this transform's function index.
StateFunctionTransform::StateFunctionTransform(Function* func, const std::vector<std::string>& candidateFuncNames, Type* runtimeDataArgTy)
: m_function(func)
, m_candidateFuncNames(candidateFuncNames)
, m_runtimeDataArgTy(runtimeDataArgTy)
{
m_functionName = cleanName(m_function->getName());
// Locate this function among the candidates; its index is used as a
// stable identifier (see finalizeStateIds()).
auto it = std::find(m_candidateFuncNames.begin(), m_candidateFuncNames.end(), m_functionName);
assert(it != m_candidateFuncNames.end());
m_functionIdx = it - m_candidateFuncNames.begin();
}
// Sets the size in bytes of the (intersection) attribute block.
void StateFunctionTransform::setAttributeSize(int size)
{
m_attributeSizeInBytes = size;
}
// Records the semantic type of each shader parameter and whether the
// committed (rather than pending) attributes should be used.
void StateFunctionTransform::setParameterInfo(const std::vector<ParameterSemanticType>& paramTypes, bool useCommittedAttr)
{
m_paramTypes = paramTypes;
m_useCommittedAttr = useCommittedAttr;
}
// Records the set of resource globals (SRVs, UAVs, etc.). Stored by
// pointer; the caller must keep the set alive for the transform's lifetime.
void StateFunctionTransform::setResourceGlobals(const std::set<llvm::Value*>& resources)
{
m_resources = &resources;
}
// Creates the declaration of the placeholder function whose call stands in
// for the runtime-data argument until lowering.
Function* StateFunctionTransform::createDummyRuntimeDataArgFunc(Module* module, Type* runtimeDataArgTy)
{
return FunctionBuilder(module, "dummyRuntimeDataArg").type(runtimeDataArgTy).build();
}
// Enables/disables verbose output.
void StateFunctionTransform::setVerbose(bool val)
{
m_verbose = val;
}
// Sets the base filename used when dumping intermediate IR.
void StateFunctionTransform::setDumpFilename(const std::string& dumpFilename)
{
m_dumpFilename = dumpFilename;
}
// Top-level driver. Transforms m_function into a set of state functions
// (returned via stateFunctions) and reports the shader's stack requirement
// via shaderStackSize. Each phase optionally dumps the IR for debugging.
void StateFunctionTransform::run(std::vector<Function*>& stateFunctions, _Out_ unsigned int &shaderStackSize)
{
printFunction("Initial");
// Set up runtime-data plumbing and helper declarations.
init();
printFunction("AfterInit");
changeCallingConvention();
printFunction("AfterCallingConvention");
// Save/restore (or rematerialize) values that are live across callsites;
// this also determines the stack size.
preserveLiveValuesAcrossCallsites(shaderStackSize);
printFunction("AfterPreserveLiveValues");
// Split the function into one state function per continuation point.
createSubstateFunctions(stateFunctions);
printFunctions(stateFunctions, "AfterSubstateFunctions");
// Replace the stack placeholder functions with real address arithmetic.
lowerStackFuncs();
printFunctions(stateFunctions, "AfterLowerStackFuncs");
}
// Replaces every dummyStateId(functionIdx, substate) placeholder call in the
// module with the concrete state id candidateFuncEntryStateIds[functionIdx]
// + substate, then erases the placeholder declaration. No-op if the module
// has no dummyStateId function.
void StateFunctionTransform::finalizeStateIds(llvm::Module* module, const std::vector<int>& candidateFuncEntryStateIds)
{
LLVMContext& context = module->getContext();
Function* func = module->getFunction("dummyStateId");
if (!func)
return;
// Collect calls first; erasing while walking users() would invalidate
// the iteration.
std::vector<Instruction*> toRemove;
for (User* U : func->users())
{
CallInst* call = dyn_cast<CallInst>(U);
if (!call)
continue;
int functionIdx = 0;
int substate = 0;
// NOTE(review): return values ignored — assumes both args are always
// ConstantInts placed by this transform (otherwise 0 is used). Confirm.
getConstantValue(functionIdx, call->getArgOperand(0));
getConstantValue(substate, call->getArgOperand(1));
int stateId = candidateFuncEntryStateIds[functionIdx] + substate;
call->replaceAllUsesWith(makeInt32(stateId, context));
toRemove.push_back(call);
}
for (Instruction* v : toRemove)
v->eraseFromParent();
func->eraseFromParent();
}
// Prepares the function for transformation: runs mem2reg, names values for
// debugging, records callsites/intrinsics/returns, declares the helper and
// placeholder functions (stackIntPtr, dummyRuntimeDataArg,
// dummyStackFrameSize, *Offset), materializes their calls after the entry
// allocas, and lowers SetPendingAttr() calls into stack stores.
void StateFunctionTransform::init()
{
Module* module = m_function->getParent();
m_function->setName(cleanName(m_function->getName()));
// Run preparatory passes
runPasses(m_function, {
//createBreakCriticalEdgesPass(),
//createLoopSimplifyPass(),
//createLCSSAPass(),
createPromoteMemoryToRegisterPass()
});
// Make debugging a little easier by giving things names
dbgNameUnnamedVals(m_function);
findCallSitesIntrinsicsAndReturns();
// Create a bunch of functions that we are going to need
// stackIntPtr(runtimeData, baseOffset, offset) -> i32*
m_stackIntPtrFunc = FunctionBuilder(module, "stackIntPtr").i32Ptr().type(m_runtimeDataArgTy, "runtimeData").i32("baseOffset").i32("offset").build();
// Placeholder calls go right after the entry-block allocas so they
// dominate all uses.
Instruction* insertBefore = afterEntryBlockAllocas(m_function);
Function* runtimeDataArgFunc = createDummyRuntimeDataArgFunc(module, m_runtimeDataArgTy);
m_runtimeDataArg = CallInst::Create(runtimeDataArgFunc, "runtimeData", insertBefore);
Function* stackFrameSizeFunc = FunctionBuilder(module, "dummyStackFrameSize").i32().build();
m_stackFrameSizeVal = CallInst::Create(stackFrameSizeFunc, "stackFrame.size", insertBefore);
// TODO only create the values that are actually needed
Function* payloadOffsetFunc = FunctionBuilder(module, "payloadOffset").i32().type(m_runtimeDataArgTy, "runtimeData").build();
m_payloadOffset = CallInst::Create(payloadOffsetFunc, { m_runtimeDataArg }, "payload.offset", insertBefore);
Function* committedAttrOffsetFunc = FunctionBuilder(module, "committedAttrOffset").i32().type(m_runtimeDataArgTy, "runtimeData").build();
m_committedAttrOffset = CallInst::Create(committedAttrOffsetFunc, { m_runtimeDataArg }, "committedAttr.offset", insertBefore);
Function* pendingAttrOffsetFunc = FunctionBuilder(module, "pendingAttrOffset").i32().type(m_runtimeDataArgTy, "runtimeData").build();
m_pendingAttrOffset = CallInst::Create(pendingAttrOffsetFunc, { m_runtimeDataArg }, "pendingAttr.offset", insertBefore);
Function* stackFrameOffsetFunc = FunctionBuilder(module, "stackFrameOffset").i32().type(m_runtimeDataArgTy, "runtimeData").build();
m_stackFrameOffset = CallInst::Create(stackFrameOffsetFunc, { m_runtimeDataArg }, "stackFrame.offset", insertBefore);
// lower SetPendingAttr() now
for (CallInst* call : m_setPendingAttrCalls)
{
// Get the current pending attribute offset. It can change when a hit is committed
Instruction* insertBefore = call;
Value* currentPendingAttrOffset = CallInst::Create(pendingAttrOffsetFunc, { m_runtimeDataArg }, "cur.pendingAttr.offset", insertBefore);
Value* attr = call->getArgOperand(0);
// Write the attribute struct onto the stack at the current offset,
// then drop the original intrinsic call.
createStackStore(currentPendingAttrOffset, attr, 0, insertBefore);
call->eraseFromParent();
}
}
  657. void StateFunctionTransform::findCallSitesIntrinsicsAndReturns()
  658. {
  659. // Create a map for log N lookup
  660. std::map<std::string, int> candidateFuncMap;
  661. for (int i = 0; i < (int)m_candidateFuncNames.size(); ++i)
  662. candidateFuncMap[m_candidateFuncNames[i]] = i;
  663. for (auto& I : inst_range(m_function))
  664. {
  665. if (CallInst* call = dyn_cast<CallInst>(&I))
  666. {
  667. StringRef calledFuncName = call->getCalledFunction()->getName();
  668. if (calledFuncName.startswith(SET_PENDING_ATTR_PREFIX))
  669. m_setPendingAttrCalls.push_back(call);
  670. else if (calledFuncName.startswith("movePayloadToStack"))
  671. m_movePayloadToStackCalls.push_back(call);
  672. else if (calledFuncName == CALL_INDIRECT_NAME)
  673. m_callSites.push_back(call);
  674. else
  675. {
  676. auto it = candidateFuncMap.find(cleanName(calledFuncName));
  677. if (it == candidateFuncMap.end())
  678. continue;
  679. assert(call->getCalledFunction()->getReturnType() == Type::getVoidTy(call->getContext()) && "Continuations with returns not supported");
  680. m_callSites.push_back(call);
  681. m_callSiteFunctionIdx.push_back(it->second);
  682. }
  683. }
  684. else if (ReturnInst* ret = dyn_cast<ReturnInst>(&I))
  685. {
  686. m_returns.push_back(ret);
  687. }
  688. }
  689. }
  690. void StateFunctionTransform::changeCallingConvention()
  691. {
  692. if (!m_callSites.empty() || m_attributeSizeInBytes >= 0)
  693. allocateStackFrame();
  694. if (m_attributeSizeInBytes >= 0)
  695. allocateTraceFrame();
  696. createArgFrames();
  697. changeFunctionSignature();
  698. }
  699. static bool isCallToStackPtr(Value* inst)
  700. {
  701. CallInst* call = dyn_cast<CallInst>(inst);
  702. if (call && call->getCalledFunction()->getName().startswith("stack.ptr"))
  703. return true;
  704. return false;
  705. }
// Widens alloca liveness so an alloca is considered live wherever any GEP
// derived from it is live; otherwise the save/restore code below would spill
// a pointer whose backing storage is dead. Live pointers must be allocas,
// stack.ptr calls, or GEPs rooted in one of those (asserted).
static void extendAllocaLifetimes(LiveValues& lv)
{
    for (Instruction* inst : lv.getAllLiveValues())
    {
        if (!inst->getType()->isPointerTy())
            continue;
        // Allocas and stack.ptr results need no adjustment themselves.
        if (isa<AllocaInst>(inst) || isCallToStackPtr(inst))
            continue;
        GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(inst);
        assert(gep && "Unhandled live pointer");
        Value* ptr = gep->getPointerOperand();
        if (isCallToStackPtr(ptr))
            continue;
        AllocaInst* alloc = dyn_cast<AllocaInst>(gep->getPointerOperand());
        assert(alloc && "GEP of non-alloca pointer");
        // TODO: We need to set indices of the uses of the gep, not the gep itself
        const LiveValues::Indices* gepIndices = lv.getIndicesWhereLive(gep);
        const LiveValues::Indices* allocIndices = lv.getIndicesWhereLive(alloc);
        // Only rewrite when the alloca's liveness actually differs from the gep's.
        if (!allocIndices || *allocIndices != *gepIndices)
            lv.setIndicesWhereLive(alloc, gepIndices);
    }
}
  728. void StateFunctionTransform::preserveLiveValuesAcrossCallsites(_Out_ unsigned int &shaderStackSize)
  729. {
  730. if (m_callSites.empty())
  731. {
  732. // No stack frame. Nothing to do.
  733. rewriteDummyStackSize(0);
  734. return;
  735. }
  736. SetVector<Instruction*> stackOffsets;
  737. stackOffsets.insert(m_stackFrameOffset);
  738. if (m_payloadOffset && !m_payloadOffset->user_empty())
  739. stackOffsets.insert(m_payloadOffset);
  740. if (m_committedAttrOffset && !m_committedAttrOffset->user_empty())
  741. stackOffsets.insert(m_committedAttrOffset);
  742. if (m_pendingAttrOffset && !m_pendingAttrOffset->user_empty())
  743. stackOffsets.insert(m_pendingAttrOffset);
  744. // Do liveness analysis
  745. ArrayRef<Instruction*> instructions((Instruction**)m_callSites.data(), m_callSites.size());
  746. LiveValues lv(instructions);
  747. lv.run();
  748. // Make sure alloca lifetimes match their uses
  749. extendAllocaLifetimes(lv);
  750. // Make sure stack offsets get included
  751. for (auto o : stackOffsets)
  752. lv.setLiveAtAllIndices(o, true);
  753. // Add payload allocas, if any
  754. for (CallInst* call : m_movePayloadToStackCalls)
  755. {
  756. if (AllocaInst* payloadAlloca = dyn_cast<AllocaInst>(call->getArgOperand(0)))
  757. lv.setLiveAtAllIndices(payloadAlloca, true);
  758. }
  759. printSet(lv.getAllLiveValues(), "live values");
  760. //
  761. // Carve up the stack frame.
  762. //
  763. uint64_t offsetInBytes = 0;
  764. // ... argument frame
  765. offsetInBytes += m_maxCallerArgFrameSizeInBytes;
  766. // ... live allocas.
  767. Module* module = m_function->getParent();
  768. DataLayout DL(module);
  769. DenseMap<Instruction*, Instruction*> allocaToStack;
  770. Instruction* insertBefore = getInstructionAfter(m_stackFrameOffset);
  771. for (Instruction* inst : lv.getAllLiveValues())
  772. {
  773. AllocaInst* alloc = dyn_cast<AllocaInst>(inst);
  774. if (!alloc)
  775. continue;
  776. // Allocate a slot in the stack frame for the alloca
  777. offsetInBytes = align(offsetInBytes, inst, DL);
  778. Instruction* stackAlloca = createStackPtr(m_stackFrameOffset, alloc, offsetInBytes, insertBefore);
  779. alloc->replaceAllUsesWith(stackAlloca);
  780. allocaToStack[inst] = stackAlloca;
  781. offsetInBytes += DL.getTypeAllocSize(alloc->getAllocatedType());
  782. }
  783. lv.remapLiveValues(allocaToStack); // replace old allocas with stackAllocas
  784. for (auto& kv : allocaToStack)
  785. kv.first->eraseFromParent(); // delete old allocas
  786. // Set payload offsets now that they are all on the stack
  787. for (CallInst* call : m_movePayloadToStackCalls)
  788. {
  789. CallInst* payloadStackPtr = dyn_cast<CallInst>(call->getArgOperand(0));
  790. assert(payloadStackPtr->getCalledFunction()->getName().startswith("stack.ptr"));
  791. Value* baseOffset = payloadStackPtr->getArgOperand(0);
  792. Value* idx = payloadStackPtr->getArgOperand(1);
  793. Value* payloadOffset = BinaryOperator::Create(Instruction::Add, baseOffset, idx, "", call);
  794. call->replaceAllUsesWith(payloadOffset);
  795. payloadOffset->takeName(call);
  796. call->eraseFromParent();
  797. }
  798. //printFunction("AfterStackAllocas");
  799. // ... saves/restores for each call site
  800. // Create allocas for live values. This makes it easier to generate code because
  801. // we don't have to maintain the use-def chains of SSA form. We can just
  802. // load/store from/to the alloca for a particular value. A subsequent mem2reg
  803. // pass will rebuild the SSA form.
  804. DenseMap<Instruction*, AllocaInst*> valToAlloca;
  805. DenseMap<AllocaInst*, Instruction*> allocaToVal;
  806. for (Instruction* inst : lv.getAllLiveValues())
  807. reg2Mem(valToAlloca, allocaToVal, inst);
  808. //printFunction("AfterReg2Mem");
  809. uint64_t baseOffsetInBytes = offsetInBytes;
  810. uint64_t maxOffsetInBytes = offsetInBytes;
  811. for (size_t i = 0; i < m_callSites.size(); ++i)
  812. {
  813. offsetInBytes = baseOffsetInBytes;
  814. const InstructionSetVector& liveHere = lv.getLiveValues(i);
  815. std::vector<Instruction*> workList(liveHere.begin(), liveHere.end());
  816. std::set<Instruction*> visited;
  817. Rematerializer R(allocaToVal, liveHere, *m_resources);
  818. Instruction* saveInsertBefore = m_callSites[i];
  819. Instruction* restoreInsertBefore = getInstructionAfter(m_callSites[i]);
  820. Instruction* rematInsertBefore = nullptr; // create only if needed
  821. // Rematerialize stack offsets after the continuation before other restores
  822. for (Instruction* inst : stackOffsets)
  823. {
  824. visited.insert(inst);
  825. Instruction* remat = R.rematerialize(inst, workList, restoreInsertBefore);
  826. new StoreInst(remat, valToAlloca[inst], restoreInsertBefore);
  827. }
  828. Instruction* saveStackFrameOffset = new LoadInst(valToAlloca[m_stackFrameOffset], "stackFrame.offset", saveInsertBefore);
  829. Instruction* restoreStackFrameOffset = R.getRematerializedValueFor(m_stackFrameOffset);
  830. while (!workList.empty())
  831. {
  832. Instruction* inst = workList.back();
  833. workList.pop_back();
  834. if (!visited.insert(inst).second)
  835. continue;
  836. if (!R.canRematerialize(inst))
  837. {
  838. assert(!inst->getType()->isPointerTy() && "Can not save pointers");
  839. offsetInBytes = align(offsetInBytes, inst, DL);
  840. AllocaInst* alloca = valToAlloca[inst];
  841. Value* saveVal = new LoadInst(alloca, addSuffix(inst->getName(), ".save"), saveInsertBefore);
  842. createStackStore(saveStackFrameOffset, saveVal, offsetInBytes, saveInsertBefore);
  843. Value* restoreVal = createStackLoad(restoreStackFrameOffset, inst, offsetInBytes, restoreInsertBefore);
  844. new StoreInst(restoreVal, alloca, restoreInsertBefore);
  845. offsetInBytes += DL.getTypeAllocSize(inst->getType());
  846. }
  847. else if (R.getRematerializedValueFor(inst) == nullptr)
  848. {
  849. if (!rematInsertBefore)
  850. {
  851. // Create a new block after restores for rematerialized values. This
  852. // ensures that we can use restored values (through their allocas) even
  853. // if we haven't generated the actual restore yet.
  854. rematInsertBefore = restoreInsertBefore->getParent()->splitBasicBlock(restoreInsertBefore, "remat_begin")->begin();
  855. restoreInsertBefore = m_callSites[i]->getParent()->getTerminator();
  856. }
  857. Instruction* remat = R.rematerialize(inst, workList, rematInsertBefore);
  858. new StoreInst(remat, valToAlloca[inst], rematInsertBefore);
  859. }
  860. }
  861. // Take the max offset over all call sites
  862. maxOffsetInBytes = std::max(maxOffsetInBytes, offsetInBytes);
  863. }
  864. // ... traceFrame (if any)
  865. maxOffsetInBytes += m_traceFrameSizeInBytes;
  866. // Set the stack size
  867. rewriteDummyStackSize(maxOffsetInBytes);
  868. shaderStackSize = maxOffsetInBytes;
  869. }
// Splits m_function into one function per substate (one more than the number
// of call sites), tags each with the "state_function" attribute, and deletes
// the original function. The results are written into stateFunctions.
void StateFunctionTransform::createSubstateFunctions(std::vector<Function*>& stateFunctions)
{
    // The runtime perf of split() depends on the number of blocks in the function.
    // Simplifying the CFG before the split helps reduce the cost of that operation.
    runPasses(m_function, {
        createCFGSimplificationPass()
    });
    stateFunctions.resize(m_callSites.size() + 1);
    // replaceCallSites() yields one entry block per substate; split() then
    // extracts each substate into its own function.
    BasicBlockVector substateEntryBlocks = replaceCallSites();
    for (size_t i = 0, e = stateFunctions.size(); i < e; ++i)
    {
        stateFunctions[i] = split(m_function, substateEntryBlocks[i], i);
        // Add an attribute so we can detect when an intrinsic is not being called
        // from a state function, and thus doesn't have access to the runtimeData pointer.
        stateFunctions[i]->addFnAttr("state_function", "true");
    }
    // Erase base function
    m_function->eraseFromParent();
    m_function = nullptr;
}
  890. void StateFunctionTransform::allocateStackFrame()
  891. {
  892. Module* module = m_function->getParent();
  893. // Push stack frame in entry block.
  894. Instruction* insertBefore = m_stackFrameOffset;
  895. Function* stackFramePushFunc = FunctionBuilder(module, "stackFramePush").voidTy().type(m_runtimeDataArgTy, "runtimeData").i32("size").build();
  896. m_stackFramePush = CallInst::Create(stackFramePushFunc, { m_runtimeDataArg, m_stackFrameSizeVal }, "", insertBefore);
  897. // Pop the stack frame just before returns.
  898. Function* stackFramePop = FunctionBuilder(module, "stackFramePop").voidTy().type(m_runtimeDataArgTy, "runtimeData").i32("size").build();
  899. for (Instruction* insertBefore : m_returns)
  900. CallInst::Create(stackFramePop, { m_runtimeDataArg, m_stackFrameSizeVal }, "", insertBefore);
  901. }
// Reserves the trace frame: committed + pending attribute storage plus the
// two saved attribute offsets. Emits traceFramePush right after the
// entry-block allocas and traceFramePop before every return.
void StateFunctionTransform::allocateTraceFrame()
{
    assert(m_attributeSizeInBytes >= 0 && "Attribute size has not been specified");
    m_traceFrameSizeInBytes =
        2 * m_attributeSizeInBytes // committed and pending attributes
        + 2 * sizeof(int); // old committed/pending attribute offsets
    // The runtime stack works in 4-byte (int) units.
    int attrSizeInInts = m_attributeSizeInBytes / sizeof(int);
    // Push the trace frame first thing so that the runtime
    // can do setup relative to the entry stack offset.
    Module* module = m_function->getParent();
    Instruction* insertBefore = afterEntryBlockAllocas(m_function);
    Value* attrSize = makeInt32(attrSizeInInts, module->getContext());
    Function* traceFramePushFunc = FunctionBuilder(module, "traceFramePush").voidTy().type(m_runtimeDataArgTy, "runtimeData").i32("attrSize").build();
    CallInst::Create(traceFramePushFunc, { m_runtimeDataArg, attrSize }, "", insertBefore);
    // Pop the trace frame just before returns.
    Function* traceFramePopFunc = FunctionBuilder(module, "traceFramePop").voidTy().type(m_runtimeDataArgTy, "runtimeData").build();
    for (Instruction* insertBefore : m_returns)
        CallInst::Create(traceFramePopFunc, { m_runtimeDataArg }, "", insertBefore);
}
  921. bool isTemporaryAlloca(Value* op)
  922. {
  923. // TODO: Need to some analysis to figure this out. We can put the alloca on
  924. // the caller stack if:
  925. // there is only a single callsite OR
  926. // if no callsite between stores/loads and this callsite
  927. return true;
  928. }
// Lowers argument passing onto the stack:
//  1) Replaces this function's formal arguments with stack pointers/loads
//     from the frame selected by m_paramTypes (caller arg frame, payload,
//     or committed/pending attributes).
//  2) For each call site, writes the return state ID and the outgoing
//     arguments into the caller arg frame, and records the largest arg frame
//     in m_maxCallerArgFrameSizeInBytes.
void StateFunctionTransform::createArgFrames()
{
    Module* module = m_function->getParent();
    DataLayout DL(module);
    Instruction* stackAllocaInsertBefore = getInstructionAfter(m_stackFrameOffset);
    // Retrieve this function's arguments from the stack
    if (m_function->getFunctionType()->getNumParams() > 0)
    {
        if (m_paramTypes.empty())
            m_paramTypes.assign(m_function->getFunctionType()->getNumParams(), PST_NONE); // assume standard argument types
        static_assert(PST_COUNT == 3, "Expected 3 parameter semantic types");
        int offsetInBytes[PST_COUNT] = { 0, 0, 0 };
        Value* baseOffset[PST_COUNT] = { nullptr, nullptr, nullptr };
        Instruction* insertBefore = stackAllocaInsertBefore;
        // Compute a base offset once for each parameter class that occurs.
        for (auto pst : m_paramTypes)
        {
            if (baseOffset[pst])
                continue;
            if (pst == PST_NONE)
            {
                // Caller arg frame sits right after this function's own frame.
                baseOffset[pst] = BinaryOperator::Create(Instruction::Add, m_stackFrameOffset, m_stackFrameSizeVal, "callerArgFrame.offset", insertBefore);
                offsetInBytes[pst] = sizeof(int); // skip the first element in caller arg frame (returnStateID)
            }
            else if (pst == PST_PAYLOAD)
            {
                baseOffset[pst] = m_payloadOffset;
            }
            else if (pst == PST_ATTRIBUTE)
            {
                baseOffset[pst] = (m_useCommittedAttr) ? m_committedAttrOffset : m_pendingAttrOffset;
            }
            else
            {
                assert(0 && "Bad parameter type");
            }
        }
        int argIdx = 0;
        for (auto& arg : m_function->args())
        {
            ParameterSemanticType pst = m_paramTypes[argIdx];
            Value* val = nullptr;
            if (arg.getType()->isPointerTy())
            {
                // Assume that pointed to memory is on the stack.
                val = createStackPtr(baseOffset[pst], &arg, offsetInBytes[pst], insertBefore);
                offsetInBytes[pst] += DL.getTypeAllocSize(arg.getType()->getPointerElementType());
            }
            else
            {
                val = createStackLoad(baseOffset[pst], &arg, offsetInBytes[pst], insertBefore);
                offsetInBytes[pst] += DL.getTypeAllocSize(arg.getType());
            }
            // Replace use of the argument with the loaded value
            if (arg.hasName())
                val->takeName(&arg);
            else
                val->setName("arg" + std::to_string(argIdx));
            arg.replaceAllUsesWith(val);
            argIdx++;
        }
    }
    // Process function arguments for each call site
    m_maxCallerArgFrameSizeInBytes = 0;
    for (size_t i = 0; i < m_callSites.size(); ++i)
    {
        int offsetInBytes = 0;
        CallInst* call = m_callSites[i];
        FunctionType* FT = call->getCalledFunction()->getFunctionType();
        StringRef calledFuncName = call->getCalledFunction()->getName();
        Instruction* insertBefore = call;
        // Set the return stateId (next substate of this function)
        int nextSubstate = i + 1;
        Value* nextStateId = getDummyStateId(m_functionIdx, nextSubstate, insertBefore);
        createStackStore(m_stackFrameOffset, nextStateId, offsetInBytes, insertBefore);
        offsetInBytes += DL.getTypeAllocSize(nextStateId->getType());
        if (FT->getNumParams() && calledFuncName != CALL_INDIRECT_NAME)
        {
            for (unsigned index = 0; index < FT->getNumParams(); ++index)
            {
                // Save the argument from the argFrame
                Value* op = call->getArgOperand(index);
                Type* opTy = op->getType();
                if (opTy->isPointerTy())
                {
                    // TODO: Until we have callable shaders we should not get here except
                    // in tests.
                    if (isTemporaryAlloca(op))
                    {
                        // We can just replace the alloca with space in the arg frame
                        assert(isa<AllocaInst>(op));
                        Value* stackAlloca = createStackPtr(m_stackFrameOffset, op, offsetInBytes, stackAllocaInsertBefore);
                        op->replaceAllUsesWith(stackAlloca);
                        cast<AllocaInst>(op)->eraseFromParent();
                    }
                    else
                    {
                        // copy in/out
                        assert(0);
                    }
                    offsetInBytes += DL.getTypeAllocSize(opTy->getPointerElementType());
                }
                else
                {
                    createStackStore(m_stackFrameOffset, op, offsetInBytes, insertBefore);
                    offsetInBytes += DL.getTypeAllocSize(opTy);
                }
                // Replace use of the argument with undef
                call->setArgOperand(index, UndefValue::get(opTy));
            }
        }
        if (offsetInBytes > m_maxCallerArgFrameSizeInBytes)
            m_maxCallerArgFrameSizeInBytes = offsetInBytes;
    }
}
// Rewrites the function to the state-function signature:
//   i32 func(runtimeData)
// by splicing the old body into a new function, wiring the real runtimeData
// argument in place of its dummy call, and converting every return to return
// the caller's state ID read from slot 0 of the stack frame.
void StateFunctionTransform::changeFunctionSignature()
{
    // Create a new function that takes a state object pointer and returns next state ID
    // and splice in the body of the old function into the new one.
    Function* newFunc = FunctionBuilder(m_function->getParent(), m_functionName + "_tmp").i32().type(m_runtimeDataArgTy, "runtimeData").build();
    newFunc->getBasicBlockList().splice(newFunc->begin(), m_function->getBasicBlockList());
    m_function = newFunc;
    // Set the runtime data pointer and remove the dummy function .
    Value* runtimeDataArg = m_function->arg_begin();
    replaceValAndRemoveUnusedDummyFunc(m_runtimeDataArg, runtimeDataArg, m_function);
    m_runtimeDataArg = runtimeDataArg;
    // Get return stateID from stack on each return.
    LLVMContext& context = m_function->getContext();
    Value* zero = makeInt32(0, context);
    CallInst* retStackFrameOffset = m_stackFrameOffset;
    for (ReturnInst*& ret : m_returns)
    {
        Instruction* insertBefore = ret;
        // If a frame was pushed, re-query the frame offset at the return;
        // the entry-block value is presumably stale after the push/pop
        // sequence — NOTE(review): confirm against stackFramePush semantics.
        if (m_stackFramePush)
            retStackFrameOffset = CallInst::Create(m_stackFrameOffset->getCalledFunction(), { m_runtimeDataArg }, "ret.stackFrame.offset", insertBefore);
        Instruction* returnStateIdPtr = CallInst::Create(m_stackIntPtrFunc, { m_runtimeDataArg, retStackFrameOffset, zero }, "ret.stateId.ptr", insertBefore);
        Value* returnStateId = new LoadInst(returnStateIdPtr, "ret.stateId", insertBefore);
        ReturnInst* newRet = ReturnInst::Create(context, returnStateId);
        ReplaceInstWithInst(ret, newRet);
        ret = newRet; // update reference
    }
}
  1070. void StateFunctionTransform::rewriteDummyStackSize(uint64_t frameSizeInBytes)
  1071. {
  1072. assert(frameSizeInBytes % sizeof(int) == 0);
  1073. Value* frameSizeVal = makeInt32(frameSizeInBytes / sizeof(int), m_function->getContext());
  1074. replaceValAndRemoveUnusedDummyFunc(m_stackFrameSizeVal, frameSizeVal, m_function);
  1075. m_stackFrameSizeVal = frameSizeVal;
  1076. }
  1077. static inline Value* toIntIndex(int offsetInBytes, Value* baseOffset, Instruction* insertBefore)
  1078. {
  1079. assert(offsetInBytes % sizeof(int) == 0);
  1080. Value* intIndex = makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());
  1081. if (baseOffset)
  1082. intIndex = BinaryOperator::Create(Instruction::Add, intIndex, baseOffset, "", insertBefore);
  1083. return intIndex;
  1084. }
  1085. void StateFunctionTransform::createStackStore(Value* baseOffset, Value* val, int offsetInBytes, Instruction* insertBefore)
  1086. {
  1087. assert(offsetInBytes % sizeof(int) == 0);
  1088. Value* intIndex = makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());
  1089. Value* args[] = { val, baseOffset, intIndex };
  1090. Type* argTypes[] = { args[0]->getType(), args[1]->getType(), args[2]->getType() };
  1091. FunctionType* FT = FunctionType::get(Type::getVoidTy(val->getContext()), argTypes, false);
  1092. Function* F = getOrCreateFunction("stack.store", insertBefore->getModule(), FT, m_stackStoreFuncs);
  1093. CallInst::Create(F, args, "", insertBefore);
  1094. }
  1095. Instruction* StateFunctionTransform::createStackLoad(Value* baseOffset, Value* val, int offsetInBytes, Instruction* insertBefore)
  1096. {
  1097. assert(offsetInBytes % sizeof(int) == 0);
  1098. Value* intIndex = makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());
  1099. Value* args[] = { baseOffset, intIndex };
  1100. Type* argTypes[] = { args[0]->getType(), args[1]->getType() };
  1101. FunctionType* FT = FunctionType::get(val->getType(), argTypes, false);
  1102. Function* F = getOrCreateFunction("stack.load", insertBefore->getModule(), FT, m_stackLoadFuncs);
  1103. return CallInst::Create(F, args, addSuffix(val->getName(), ".restore"), insertBefore);
  1104. }
  1105. Instruction* StateFunctionTransform::createStackPtr(Value* baseOffset, Type* valTy, Value* intIndex, Instruction* insertBefore)
  1106. {
  1107. Value* args[] = { baseOffset, intIndex };
  1108. Type* argTypes[] = { args[0]->getType(), args[1]->getType() };
  1109. FunctionType* FT = FunctionType::get(valTy, argTypes, false);
  1110. Function* F = getOrCreateFunction("stack.ptr", insertBefore->getModule(), FT, m_stackPtrFuncs);
  1111. CallInst* call = CallInst::Create(F, args, "", insertBefore);
  1112. return call;
  1113. }
  1114. Instruction* StateFunctionTransform::createStackPtr(Value* baseOffset, Value* val, int offsetInBytes, Instruction* insertBefore)
  1115. {
  1116. assert(offsetInBytes % sizeof(int) == 0);
  1117. Value* intIndex = makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());
  1118. Instruction* ptr = createStackPtr(baseOffset, val->getType(), intIndex, insertBefore);
  1119. ptr->takeName(val);
  1120. return ptr;
  1121. }
  1122. static bool isStackIntPtr(Value* val)
  1123. {
  1124. CallInst* call = dyn_cast<CallInst>(val);
  1125. return call && call->getCalledFunction()->getName().startswith("stack.ptr");
  1126. }
// This code adapted from GetElementPtrInst::accumulateConstantOffset().
// TODO: Use a single function for both constant and dynamic offsets? Could do
// some constant folding along the way for dynamic offsets.
// Emits instructions (before the gep) that compute the gep's total offset in
// 4-byte int units; struct-field and array/vector element offsets are both
// divided by sizeof(int).
Value* accumulateDynamicOffset(GetElementPtrInst* gep, const DataLayout &DL)
{
    LLVMContext& C = gep->getContext();
    Instruction* insertBefore = gep;
    Value* offset = makeInt32(0, C);
    for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep); GTI != GTE; ++GTI)
    {
        ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand());
        // A constant zero index contributes nothing; skip it.
        if (OpC && OpC->isZero())
            continue;
        // Handle a struct index, which adds its field offset to the pointer.
        Value* elementOffset = nullptr;
        if (StructType *STy = dyn_cast<StructType>(*GTI))
        {
            assert(OpC && "Structure indices must be constant");
            unsigned ElementIdx = OpC->getZExtValue();
            const StructLayout *SL = DL.getStructLayout(STy);
            elementOffset = makeInt32(SL->getElementOffset(ElementIdx) / sizeof(int), C);
        }
        else
        {
            // For array or vector indices, scale the index by the size of the type.
            Value* stride = makeInt32(DL.getTypeAllocSize(GTI.getIndexedType()) / sizeof(int), C);
            elementOffset = BinaryOperator::Create(Instruction::Mul, GTI.getOperand(), stride, "elOffs", insertBefore);
        }
        offset = BinaryOperator::Create(Instruction::Add, offset, elementOffset, "offs", insertBefore);
    }
    return offset;
}
// Adds gep offset to offsetVal and returns the result
// (as an int-unit index). Fully-constant geps fold to a single constant;
// otherwise accumulateDynamicOffset() emits the arithmetic. The final Add is
// inserted immediately before the gep.
static Value* accumulateGepOffset(GetElementPtrInst* gep, Value* offsetVal)
{
    Module* M = gep->getModule();
    const DataLayout& DL = M->getDataLayout();
    Value* elementOffsetVal = nullptr;
    APInt constOffset(DL.getPointerSizeInBits(), 0);
    if (gep->accumulateConstantOffset(DL, constOffset))
        elementOffsetVal = makeInt32((int)constOffset.getZExtValue() / sizeof(int), M->getContext());
    else
        elementOffsetVal = accumulateDynamicOffset(gep, DL);
    elementOffsetVal = BinaryOperator::Create(Instruction::Add, offsetVal, elementOffsetVal, "offs", gep);
    return elementOffsetVal;
}
// Turn GEPs on a stack.ptr of aggregate type into stack.ptrs of scalar type
// Walks val's users: call users are inlined to expose the GEPs inside them
// (then the scan restarts, since inlining mutates the use list); each GEP
// gets its offset folded into offsetVal and is recursively flattened
// (aggregate element), scalarized (vector element), or replaced by a scalar
// stack.ptr, after which the GEP is erased.
void StateFunctionTransform::flattenGepsOnValue(Value* val, Value* baseOffset, Value* offsetVal)
{
    // Advance the iterator before touching the user so erasure is safe.
    for (auto U = val->user_begin(), UE = val->user_end(); U != UE;)
    {
        User* user = *U++;
        if (CallInst* call = dyn_cast<CallInst>(user))
        {
            // inline the call to expose GEPs and restart the loop.
            InlineFunctionInfo IFI;
            bool success = InlineFunction(call, IFI, false);
            assert(success);
            (void)success;
            U = val->user_begin();
            UE = val->user_end();
            continue;
        }
        GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(user);
        if (!gep)
            continue;
        Value* elementOffsetVal = accumulateGepOffset(gep, offsetVal);
        Type* gepElTy = gep->getType()->getPointerElementType();
        if (gepElTy->isAggregateType())
        {
            // flatten geps on this gep
            flattenGepsOnValue(gep, baseOffset, elementOffsetVal);
        }
        else if (isa<VectorType>(gepElTy))
            scalarizeVectorStackAccess(gep, baseOffset, elementOffsetVal);
        else
        {
            Value* ptr = createStackPtr(baseOffset, gep->getType(), elementOffsetVal, gep);
            ptr->takeName(gep); // could use a name that encodes the gep type and indices
            gep->replaceAllUsesWith(ptr);
        }
        gep->eraseFromParent();
    }
}
// Replaces loads/stores through a vector-typed stack pointer with per-element
// scalar accesses: one stack.ptr per vector lane (consecutive int slots),
// loads rebuilt with insertelement, stores decomposed with extractelement.
// Any other user of vecPtr is unexpected and asserts.
void StateFunctionTransform::scalarizeVectorStackAccess(Instruction* vecPtr, Value* baseOffset, Value* offsetVal)
{
    std::vector<Value*> elPtrs;
    Type* VTy = vecPtr->getType()->getPointerElementType();
    Type* elTy = VTy->getVectorElementType();
    LLVMContext& C = vecPtr->getContext();
    Value* curOffsetVal = offsetVal;
    Value* one = makeInt32(1, C);
    offsetVal->setName("offs0.");
    // Build one element pointer per lane; lane i lives at offset + i ints.
    for (unsigned i = 0; i < VTy->getVectorNumElements(); ++i)
    {
        // TODO: If offsetVal is a constant we could just create constants instead of add instructions
        if (i > 0)
            curOffsetVal = BinaryOperator::Create(Instruction::Add, curOffsetVal, one, stringf("offs%d.", i), vecPtr);
        elPtrs.push_back(createStackPtr(baseOffset, elTy->getPointerTo(), curOffsetVal, vecPtr));
        elPtrs.back()->setName(addSuffix(vecPtr->getName(), stringf(".el%d.", i)));
    }
    // Scalarize load/stores
    // (iterator is advanced before each erase so the walk stays valid)
    for (auto U = vecPtr->user_begin(), UE = vecPtr->user_end(); U != UE;)
    {
        User* user = *U++;
        if (LoadInst* load = dyn_cast<LoadInst>(user))
        {
            Value* vec = UndefValue::get(VTy);
            for (size_t i = 0; i < elPtrs.size(); ++i)
            {
                Value* el = new LoadInst(elPtrs[i], stringf("el%d.", i), load);
                vec = InsertElementInst::Create(vec, el, makeInt32(i, C), "vec", load);
            }
            load->replaceAllUsesWith(vec);
            load->eraseFromParent();
        }
        else if (StoreInst* store = dyn_cast<StoreInst>(user))
        {
            Value* vec = store->getOperand(0);
            for (size_t i = 0; i < elPtrs.size(); ++i)
            {
                Value* el = ExtractElementInst::Create(vec, makeInt32(i, C), stringf("el%d.", i), store);
                new StoreInst(el, elPtrs[i], store);
            }
            store->eraseFromParent();
        }
        else
        {
            assert(0 && "Unhandled user");
        }
    }
}
// Lower the placeholder stack intrinsics (stack.store.*, stack.load.*,
// stack.ptr.*) into concrete accesses built on m_stackIntPtrFunc, which
// yields an int-sized slot pointer for a given (runtimeData, offset, index).
// Each placeholder function is erased once all of its call sites have been
// rewritten.
void StateFunctionTransform::lowerStackFuncs()
{
    LLVMContext& C = m_stackIntPtrFunc->getContext();
    const DataLayout& DL = m_stackIntPtrFunc->getParent()->getDataLayout();

    // stack.store functions: store(val, offset, idx) -> int-slot stores.
    for (auto& kv : m_stackStoreFuncs)
    {
        Function* F = kv.second;
        // Post-increment the user iterator before the call is erased so the
        // iteration survives removal of the current user.
        for (auto U = F->user_begin(); U != F->user_end(); )
        {
            CallInst* call = dyn_cast<CallInst>(*(U++));
            assert(call);
            // First argument of the enclosing (state) function is the
            // runtime-data pointer threaded through every stack access.
            Value* runtimeDataArg = call->getParent()->getParent()->arg_begin();
            Value* val = call->getArgOperand(0);
            Value* offset = call->getArgOperand(1);
            int idx = getConstantValue(call->getArgOperand(2));
            Instruction* insertBefore = call;
            if (isStackIntPtr(val))
            {
                // Copy from one part of the stack to another. The value being
                // stored is itself a stack.intptr call, so emit an int-by-int
                // load/store copy between the two stack regions.
                CallInst* valCall = dyn_cast<CallInst>(val);
                Value* srcOffset = valCall->getArgOperand(0);
                int srcIdx = getConstantValue(valCall->getArgOperand(1));
                Value* dstOffset = offset;
                int dstIdx = idx;
                // Number of int-sized slots covered by the pointee type.
                int intCount = (int)DL.getTypeAllocSize(val->getType()->getPointerElementType()) / sizeof(int);
                for (int i = 0; i < intCount; ++i)
                {
                    std::string idxStr = stringf("%d.", i);
                    Value* srcPtr = CallInst::Create(m_stackIntPtrFunc, { runtimeDataArg, srcOffset, makeInt32(srcIdx + i, C) }, addSuffix(val->getName(), ".ptr" + idxStr), insertBefore);
                    Value* dstPtr = CallInst::Create(m_stackIntPtrFunc, { runtimeDataArg, dstOffset, makeInt32(dstIdx + i, C) }, "dst.ptr" + idxStr, insertBefore);
                    Value* intVal = new LoadInst(srcPtr, "copy.val" + idxStr, insertBefore);
                    new StoreInst(intVal, dstPtr, insertBefore);
                }
            }
            else
            {
                // Ordinary value: let the store() helper scalarize it into
                // int slots starting at (offset, idx).
                store(val, m_stackIntPtrFunc, runtimeDataArg, offset, idx, insertBefore);
            }
            call->eraseFromParent();
        }
        F->eraseFromParent();
    }

    // stack.load functions: replace each call with a reconstructed value
    // loaded from the stack via the load() helper.
    for (auto& kv : m_stackLoadFuncs)
    {
        Function* F = kv.second;
        for (auto U = F->user_begin(); U != F->user_end(); )
        {
            CallInst* call = dyn_cast<CallInst>(*(U++));
            assert(call);
            // Drop the ".restore" suffix so the rebuilt value keeps the
            // original value's base name.
            std::string name = stripSuffix(call->getName(), ".restore");
            call->setName("");
            Value* runtimeDataArg = call->getParent()->getParent()->arg_begin();
            Value* offset = call->getArgOperand(0);
            Value* idx = call->getArgOperand(1);
            Instruction* insertBefore = call;
            Value* val = load(m_stackIntPtrFunc, runtimeDataArg, offset, idx, name, call->getType(), insertBefore);
            call->replaceAllUsesWith(val);
            call->eraseFromParent();
        }
        F->eraseFromParent();
    }

    // Scalarize accesses based on a stack.ptr func. Aggregate-typed stack
    // pointers cannot be bitcast to a single int pointer, so rewrite the
    // GEPs hanging off each call into flattened per-element accesses first.
    for (auto& kv : m_stackPtrFuncs)
    {
        Function* F = kv.second;
        if (!F->getReturnType()->getPointerElementType()->isAggregateType())
            continue;
        for (auto U = F->user_begin(), UE = F->user_end(); U != UE; )
        {
            CallInst* call = dyn_cast<CallInst>(*(U++));
            assert(call);
            Value* offset = call->getArgOperand(0);
            Value* idx = call->getArgOperand(1);
            flattenGepsOnValue(call, offset, idx);
            call->eraseFromParent();
        }
    }

    // stack.ptr functions: remaining (non-aggregate) calls become a direct
    // stack.intptr call, bitcast to the expected pointer type if needed.
    for (auto& kv : m_stackPtrFuncs)
    {
        Function* F = kv.second;
        for (auto U = F->user_begin(); U != F->user_end(); )
        {
            CallInst* call = dyn_cast<CallInst>(*(U++));
            assert(call);
            std::string name = call->getName();
            Value* runtimeDataArg = call->getParent()->getParent()->arg_begin();
            Value* offset = call->getArgOperand(0);
            Value* idx = call->getArgOperand(1);
            Instruction* insertBefore = call;
            Value* ptr = CallInst::Create(m_stackIntPtrFunc, { runtimeDataArg, offset, idx }, addSuffix(name, ".ptr"), insertBefore);
            if (ptr->getType() != call->getType())
                ptr = new BitCastInst(ptr, call->getType(), "", insertBefore);
            ptr->takeName(call);
            call->replaceAllUsesWith(ptr);
            call->eraseFromParent();
        }
        F->eraseFromParent();
    }
}
// Create substate function number 'substateIndex' by cloning everything
// reachable from 'substateEntryBlock' in baseFunc into a fresh function.
// Allocas living in baseFunc's entry block that are still referenced from the
// cloned blocks are re-created in the new function's entry block. Finally the
// clone is made reducible and mem2reg is re-run to undo the earlier reg2mem.
Function* StateFunctionTransform::split(Function* baseFunc, BasicBlock* substateEntryBlock, int substateIndex)
{
    ValueToValueMapTy VMap;
    Function* substateFunc = cloneBlocksReachableFrom(substateEntryBlock, VMap);
    Module* module = baseFunc->getParent();
    module->getFunctionList().push_back(substateFunc);
    substateFunc->setName(m_functionName + ".ss_" + std::to_string(substateIndex));

    // Substate 0 keeps the original entry block (and its allocas); only the
    // later substates need the entry-block allocas copied over.
    if (substateIndex != 0)
    {
        // Collect allocas from entry block
        SmallVector<Instruction*, 16> allocasToClone;
        for (auto& I : baseFunc->getEntryBlock().getInstList())
        {
            if (isa<AllocaInst>(&I))
                allocasToClone.push_back(&I);
        }

        // Clone collected allocas
        BasicBlock* newEntryBlock = &substateFunc->getEntryBlock();
        for (auto I : allocasToClone)
        {
            // Collect users of original instruction in substateFunc. Users
            // still parented in baseFunc (or other substates) are left alone.
            std::vector<Instruction*> users;
            for (auto U : I->users())
            {
                Instruction* inst = dyn_cast<Instruction>(U);
                if (inst->getParent()->getParent() == substateFunc)
                    users.push_back(inst);
            }
            if (users.empty())
                continue;

            // Clone instruction
            Instruction* clone = I->clone();
            if (I->hasName())
                clone->setName(I->getName());
            clone->insertBefore(newEntryBlock->getFirstInsertionPt()); // allocas first in entry block
            // Remap any operands that were cloned along with the blocks;
            // operands with no VMap entry are intentionally left untouched.
            RemapInstruction(clone, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);

            // Replaces uses
            for (auto user : users)
                user->replaceUsesOfWith(I, clone);
        }
    }

    //printFunction( substateFunc, substateFunc->getName().str() + "-BeforeSplittingOpt", m_dumpId++ );
    makeReducible(substateFunc);

    // Undo the reg2mem done in preserveLiveValuesAcrossCallSites()
    runPasses(substateFunc, {
        createVerifierPass(),
        createPromoteMemoryToRegisterPass()
    });
    //printFunction( substateFunc, substateFunc->getName().str() + "-AfterSplitting", m_dumpId++ );
    return substateFunc;
}
// Split m_function at every recorded call site, producing one substate entry
// block per continuation. Each call is replaced by a return of the state id
// of the callee's entry state (the caller's continuation id is handled
// elsewhere). Returns the list of substate entry blocks, with the original
// entry block first.
BasicBlockVector StateFunctionTransform::replaceCallSites()
{
    LLVMContext& context = m_function->getContext();
    BasicBlockVector substateEntryPoints{ &m_function->getEntryBlock() };
    substateEntryPoints[0]->setName(m_functionName + ".BB0");

    // Add other substates by splitting blocks at call sites.
    for (size_t i = 0; i < m_callSites.size(); ++i)
    {
        CallInst* call = m_callSites[i];
        BasicBlock* block = call->getParent();
        StringRef calledFuncName = call->getCalledFunction()->getName();
        // Everything after the call becomes the continuation block; splitting
        // leaves an unconditional branch as 'block''s terminator.
        BasicBlock* nextBlock =
            block->splitBasicBlock(call->getNextNode(), m_functionName + ".BB" + std::to_string(i + 1) + ".from."
                + cleanName(calledFuncName));
        substateEntryPoints.push_back(nextBlock);

        // Return state id for entry state of the function being called
        Instruction* insertBefore = call;
        Value* returnStateId = nullptr;
        if (calledFuncName == CALL_INDIRECT_NAME)
            returnStateId = call->getArgOperand(0); // indirect call: id is the first operand
        else
            returnStateId = getDummyStateId(m_callSiteFunctionIdx[i], 0, insertBefore);
        // Replace the branch created by splitBasicBlock with "return stateId",
        // then drop the now-dead call itself.
        ReplaceInstWithInst(call->getParent()->getTerminator(), ReturnInst::Create(context, returnStateId));
        call->eraseFromParent();
    }
    return substateEntryPoints;
}
  1439. llvm::Value* StateFunctionTransform::getDummyStateId(int functionIdx, int substate, llvm::Instruction* insertBefore)
  1440. {
  1441. if (!m_dummyStateIdFunc)
  1442. {
  1443. Module* M = m_function->getParent();
  1444. m_dummyStateIdFunc = FunctionBuilder(M, "dummyStateId").i32().i32("functionIdx").i32("substate").build();
  1445. }
  1446. LLVMContext& context = insertBefore->getContext();
  1447. Value* functionIdxVal = makeInt32(functionIdx, context);
  1448. Value* substateVal = makeInt32(substate, context);
  1449. return CallInst::Create(m_dummyStateIdFunc, { functionIdxVal, substateVal }, "stateId", insertBefore);
  1450. }
  1451. raw_ostream& StateFunctionTransform::getOutputStream(const std::string functionName, const std::string& suffix, unsigned int dumpId)
  1452. {
  1453. if (m_dumpFilename.empty())
  1454. return DBGS();
  1455. const std::string filename = createDumpPath(m_dumpFilename, dumpId, suffix, functionName);
  1456. std::error_code errorCode;
  1457. raw_ostream* out = new raw_fd_ostream(filename, errorCode, sys::fs::OpenFlags::F_None);
  1458. if (errorCode)
  1459. {
  1460. DBGS() << "Failed to open " << filename << " for writing sft output. " << errorCode.message() << "\n";
  1461. delete out;
  1462. return DBGS();
  1463. }
  1464. return *out;
  1465. }
  1466. void StateFunctionTransform::printFunction(const Function* function, const std::string& suffix, unsigned int dumpId)
  1467. {
  1468. if (!m_verbose)
  1469. return;
  1470. raw_ostream& out = getOutputStream(m_functionName, suffix, dumpId);
  1471. out << "; ########################### " << suffix << "\n";
  1472. out << *function << "\n";
  1473. if (&out != &DBGS())
  1474. delete &out;
  1475. }
  1476. void StateFunctionTransform::printFunction(const std::string& suffix)
  1477. {
  1478. printFunction(m_function, suffix, m_dumpId++);
  1479. }
  1480. void StateFunctionTransform::printFunctions(const std::vector<Function*>& funcs, const char* suffix)
  1481. {
  1482. if (!m_verbose)
  1483. return;
  1484. raw_ostream& out = getOutputStream(m_functionName, suffix, m_dumpId++);
  1485. out << "; ########################### " << suffix << "\n";
  1486. for (Function* F : funcs)
  1487. out << *F << "\n";
  1488. if (&out != &DBGS())
  1489. delete &out;
  1490. }
  1491. void StateFunctionTransform::printModule(const Module* module, const std::string& suffix)
  1492. {
  1493. if (!m_verbose)
  1494. return;
  1495. raw_ostream& out = getOutputStream("module", suffix, m_dumpId++);
  1496. out << "; ########################### " << suffix << "\n";
  1497. out << *module << "\n";
  1498. }
  1499. void StateFunctionTransform::printSet(const InstructionSetVector& vals, const char* msg)
  1500. {
  1501. if (!m_verbose)
  1502. return;
  1503. raw_ostream& out = DBGS();
  1504. if (msg)
  1505. out << msg << " --------------------\n";
  1506. uint64_t totalBytes = 0;
  1507. if (vals.size() > 0)
  1508. {
  1509. Module* module = m_function->getParent();
  1510. DataLayout DL(module);
  1511. for (InstructionSetVector::const_iterator I = vals.begin(), IE = vals.end(); I != IE; ++I)
  1512. {
  1513. const Instruction* inst = *I;
  1514. uint64_t size = DL.getTypeAllocSize(inst->getType());
  1515. out << stringf("%3dB: ", size) << *inst << '\n';
  1516. totalBytes += size;
  1517. }
  1518. }
  1519. out << "Count:" << vals.size() << " Bytes:" << totalBytes << "\n\n";
  1520. }