// DxrFallbackCompiler.cpp
  1. #include "dxc/DxrFallback/DxrFallbackCompiler.h"
  2. #include "dxc/Support/Global.h"
  3. #include "dxc/Support/Unicode.h"
  4. #include "dxc/Support/WinIncludes.h"
  5. #include "dxc/Support/FileIOHelper.h"
  6. #include "dxc/dxcapi.h"
  7. #include "dxc/dxcdxrfallbackcompiler.h"
  8. #include "dxc/Support/dxcapi.use.h"
  9. #include "dxc/Support/dxcapi.impl.h"
  10. #include "dxc/DXIL/DxilModule.h"
  11. #include "dxc/HLSL/DxilLinker.h"
  12. #include "dxc/DXIL/DxilFunctionProps.h"
  13. #include "dxc/DXIL/DxilOperations.h"
  14. #include "dxc/DXIL/DxilInstructions.h"
  15. #include "llvm/Analysis/CallGraph.h"
  16. #include "llvm/IR/InstIterator.h"
  17. #include "llvm/IR/Instructions.h"
  18. #include "llvm/IR/IRBuilder.h"
  19. #include "llvm/IR/LegacyPassManager.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/Linker/Linker.h"
  22. #include "llvm/Transforms/IPO.h"
  23. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  24. #include "llvm/Transforms/Utils/Cloning.h"
  25. #include "FunctionBuilder.h"
  26. #include "LLVMUtils.h"
  27. #include "runtime.h"
  28. #include "StateFunctionTransform.h"
  29. #include <queue>
  30. using namespace hlsl;
  31. using namespace llvm;
  32. static std::vector<Function*> getFunctionsWithPrefix(Module* mod, const std::string& prefix)
  33. {
  34. std::vector<Function*> functions;
  35. for (auto F = mod->begin(), E = mod->end(); F != E; ++F)
  36. {
  37. StringRef name = F->getName();
  38. if (name.startswith(prefix))
  39. functions.push_back(F);
  40. }
  41. return functions;
  42. }
  43. static bool inlineFunc(CallInst* call, Function* Fimpl)
  44. {
  45. // Note. LLVM inlining may not be sufficient if the function references DX
  46. // resources because the corresponding metadata is not created if the function
  47. // comes from another module.
  48. // Make sure that we have a definition for the called function in this module
  49. Function* F = call->getCalledFunction();
  50. Module* dstM = F->getParent();
  51. if (F->isDeclaration())
  52. {
  53. // Map called functions in impl module to functions in this one (because the
  54. // cloning step doesn't do this automatically)
  55. ValueToValueMapTy VMap;
  56. for (auto& I : inst_range(Fimpl))
  57. {
  58. if (CallInst* c = dyn_cast<CallInst>(&I))
  59. {
  60. Function* calledFimpl = c->getCalledFunction();
  61. if (VMap.count(calledFimpl))
  62. continue;
  63. Constant* calledF = dstM->getOrInsertFunction(calledFimpl->getName(), calledFimpl->getFunctionType(), calledFimpl->getAttributes());
  64. VMap[calledFimpl] = calledF;
  65. }
  66. }
  67. // Map arguments
  68. for (auto SI = Fimpl->arg_begin(), SE = Fimpl->arg_end(), DI = F->arg_begin(); SI != SE; ++SI, ++DI)
  69. VMap[SI] = DI;
  70. SmallVector<ReturnInst*, 4> returns;
  71. CloneFunctionInto(F, Fimpl, VMap, true, returns);
  72. F->setLinkage(GlobalValue::InternalLinkage);
  73. }
  74. InlineFunctionInfo IFI;
  75. return InlineFunction(call, IFI, false);
  76. }
  77. // Remove ELF mangling
  78. static std::string cleanName(StringRef name)
  79. {
  80. if (!name.startswith("\x1?"))
  81. return name;
  82. size_t pos = name.find("@@");
  83. if (pos == name.npos)
  84. return name;
  85. std::string newName = name.substr(2, pos - 2);
  86. return newName;
  87. }
  88. static inline Function* getOrInsertFunction(Module* mod, Function* F)
  89. {
  90. return dyn_cast<Function>(mod->getOrInsertFunction(F->getName(), F->getFunctionType()));
  91. }
  // Looks up `key` in `theMap` and returns the mapped value, or `defaultVal`
  // when the key is absent. Unlike std::map::operator[], never inserts.
  // The default for `defaultVal` is a value-initialized V (nullptr for
  // pointer types, 0 for arithmetic types).
  template<typename K, typename V>
  V get(const std::map<K, V>& theMap, const K& key, V defaultVal = V())
  {
    auto it = theMap.find(key);
    return it == theMap.end() ? defaultVal : it->second;
  }
  101. DxrFallbackCompiler::DxrFallbackCompiler(llvm::Module* mod, const std::vector<std::string>& shaderNames, unsigned maxAttributeSize, unsigned stackSizeInBytes, bool findCalledShaders /*= false*/)
  102. : m_module(mod)
  103. , m_entryShaderNames(shaderNames)
  104. , m_stackSizeInBytes(stackSizeInBytes)
  105. , m_maxAttributeSize(maxAttributeSize)
  106. , m_findCalledShaders(findCalledShaders)
  107. {}
  108. void DxrFallbackCompiler::compile(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap)
  109. {
  110. std::vector<std::string> shaderNames = m_entryShaderNames;
  111. initShaderMap(shaderNames);
  112. // Bring in runtime so we can get the runtime data type
  113. linkRuntime();
  114. Type* runtimeDataArgTy = getRuntimeDataArgType();
  115. // Make sure all calls to intrinsics and shaders are at function scope and
  116. // fix up control flow.
  117. lowerAnyHitControlFlowFuncs();
  118. lowerReportHit();
  119. lowerTraceRay(runtimeDataArgTy);
  120. // Create state functions
  121. IntToFuncMap stateFunctionMap; // stateID -> state function
  122. const int baseStateId = 1000; // could be anything but this makes stateIds more recognizable
  123. createStateFunctions(stateFunctionMap, shaderEntryStateIds, shaderStackSizes, baseStateId, shaderNames, runtimeDataArgTy);
  124. if (pCachedMap)
  125. {
  126. for (auto &entry : stateFunctionMap)
  127. {
  128. (*pCachedMap)[entry.first] = entry.second->getName().str();
  129. }
  130. }
  131. }
  132. void DxrFallbackCompiler::link(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap)
  133. {
  134. IntToFuncMap stateFunctionMap; // stateID -> state function
  135. if (pCachedMap)
  136. {
  137. for (auto entry : *pCachedMap)
  138. {
  139. stateFunctionMap[entry.first] = m_module->getFunction(entry.second);
  140. }
  141. }
  142. else
  143. {
  144. for (UINT i = 0; i < shaderEntryStateIds.size(); i++)
  145. {
  146. UINT substateIndex = 0;
  147. UINT baseStateId = shaderEntryStateIds[i];
  148. while (true)
  149. {
  150. auto substateName = m_entryShaderNames[i] + ".ss_" + std::to_string(substateIndex);
  151. auto function = m_module->getFunction(substateName);
  152. if (!function) break;
  153. stateFunctionMap[baseStateId + substateIndex] = m_module->getFunction(substateName);
  154. substateIndex++;
  155. }
  156. }
  157. }
  158. // Fix up scheduler
  159. Function* schedulerFunc = m_module->getFunction("fb_Fallback_Scheduler");
  160. createLaunchParams(schedulerFunc);
  161. Type* runtimeDataArgTy = getRuntimeDataArgType();
  162. createStateDispatch(schedulerFunc, stateFunctionMap, runtimeDataArgTy);
  163. createStack(schedulerFunc);
  164. lowerIntrinsics();
  165. }
  166. void DxrFallbackCompiler::setDebugOutputLevel(int val)
  167. {
  168. m_debugOutputLevel = val;
  169. }
  170. static bool isShader(Function* F)
  171. {
  172. if (F->hasFnAttribute("exp-shader"))
  173. return true;
  174. DxilModule& DM = F->getParent()->GetDxilModule();
  175. return (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay());
  176. }
  177. DXIL::ShaderKind getRayShaderKind(Function* F)
  178. {
  179. if (F->hasFnAttribute("exp-shader"))
  180. return DXIL::ShaderKind::RayGeneration;
  181. DxilModule& DM = F->getParent()->GetDxilModule();
  182. if (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay())
  183. return DM.GetDxilFunctionProps(F).shaderKind;
  184. return DXIL::ShaderKind::Invalid;
  185. }
  186. // Some shaders should use the "pending" values of intrinsics instead of the
  187. // committed ones. In particular anyhit and intersection shaders use the
  188. // pending values with the exception that the committed rayTCurrent should be
  189. // used in intersection.
  190. static bool shouldUsePendingValue(Function* F, StringRef instrinsicName)
  191. {
  192. DxilModule& DM = F->getParent()->GetDxilModule();
  193. if (!DM.HasDxilFunctionProps(F))
  194. return false;
  195. const hlsl::DxilFunctionProps& props = DM.GetDxilFunctionProps(F);
  196. return props.IsAnyHit() || (props.IsIntersection() && instrinsicName != "rayTCurrent");
  197. }
  198. void DxrFallbackCompiler::initShaderMap(std::vector<std::string>& shaderNames)
  199. {
  200. // Clean names and initialize shaderMap
  201. StringToFuncMap allShadersMap;
  202. for (Function& F : m_module->functions())
  203. {
  204. if (isShader(&F))
  205. {
  206. if (!F.isDeclaration())
  207. allShadersMap[cleanName(F.getName())] = &F;
  208. }
  209. F.removeFnAttr(Attribute::NoInline);
  210. }
  211. for (auto& name : shaderNames)
  212. m_shaderMap[name] = allShadersMap[name];
  213. if (!m_findCalledShaders)
  214. return;
  215. // Create a map from shader name to CallGraphNode
  216. CallGraph callGraph(*m_module);
  217. std::map<std::string, CallGraphNode*> allShaderNodes;
  218. for (auto& kv : m_shaderMap)
  219. {
  220. const std::string& name = kv.first;
  221. Function* func = kv.second;
  222. allShaderNodes[name] = callGraph[func];
  223. }
  224. // Start traversing the call graph from given shaderNames
  225. std::deque<CallGraphNode*> workList;
  226. for (auto& name : shaderNames)
  227. workList.push_back(allShaderNodes[name]);
  228. while (!workList.empty())
  229. {
  230. CallGraphNode* cur = workList.front();
  231. workList.pop_front();
  232. for (size_t i = 0; i < cur->size(); ++i)
  233. {
  234. Function* nextFunc = (*cur)[i]->getFunction();
  235. if (!nextFunc)
  236. continue;
  237. if (isShader(nextFunc))
  238. {
  239. const std::string nextName = cleanName(nextFunc->getName());
  240. if (m_shaderMap.count(nextName) == 0) // not in the shaderMap yet?
  241. {
  242. workList.push_back(allShaderNodes[nextName]);
  243. shaderNames.push_back(nextName);
  244. m_shaderMap[nextName] = workList.back()->getFunction();
  245. }
  246. }
  247. }
  248. }
  249. }
  250. void DxrFallbackCompiler::linkRuntime()
  251. {
  252. Linker linker(m_module);
  253. std::unique_ptr<Module> runtimeModule = loadModuleFromAsmString(m_module->getContext(), getRuntimeString());
  254. bool linkErr = linker.linkInModule(runtimeModule.get());
  255. assert(!linkErr && "Error linking runtime");
  256. UNREFERENCED_PARAMETER(linkErr);
  257. }
  258. static void inlineFuncAndAddRet(CallInst* call, Function*F)
  259. {
  260. // Add a return after the function call.
  261. // Should be followed immediately by "unreachable". Turn that into a "ret void".
  262. Instruction* ret = ReturnInst::Create(call->getContext());
  263. ReplaceInstWithInst(call->getParent()->getTerminator(), ret);
  264. bool success = inlineFunc(call, F);
  265. assert(success);
  266. UNREFERENCED_PARAMETER(success);
  267. }
  268. void DxrFallbackCompiler::lowerAnyHitControlFlowFuncs()
  269. {
  270. std::vector<CallInst*> callsToIgnoreHit = getCallsInShadersToFunction("dx.op.ignoreHit");
  271. if (!callsToIgnoreHit.empty())
  272. {
  273. Function* ignoreHitFunc = m_module->getFunction("\x1?Fallback_IgnoreHit@@YAXXZ");
  274. assert(ignoreHitFunc && "IgnoreHit() implementation not found");
  275. for (CallInst* call : callsToIgnoreHit)
  276. inlineFuncAndAddRet(call, ignoreHitFunc);
  277. }
  278. std::vector<CallInst*> callsToAcceptHitAndEndSearch = getCallsInShadersToFunction("dx.op.acceptHitAndEndSearch");
  279. if (!callsToAcceptHitAndEndSearch.empty())
  280. {
  281. Function* acceptHitAndEndSearchFunc = m_module->getFunction("\x1?Fallback_AcceptHitAndEndSearch@@YAXXZ");
  282. assert(acceptHitAndEndSearchFunc && "AcceptHitAndEndSearch() implementation not found");
  283. for (CallInst* call : callsToAcceptHitAndEndSearch)
  284. inlineFuncAndAddRet(call, acceptHitAndEndSearchFunc);
  285. }
  286. }
  287. void DxrFallbackCompiler::lowerReportHit()
  288. {
  289. std::vector<CallInst*> callsToReportHit = getCallsInShadersToFunctionWithPrefix("dx.op.reportHit");
  290. if (callsToReportHit.empty())
  291. return;
  292. Function* reportHitFunc = m_module->getFunction("\x1?Fallback_ReportHit@@YAHMI@Z");
  293. assert(reportHitFunc && "ReportHit() implementation not found");
  294. LLVMContext& C = m_module->getContext();
  295. for (CallInst* call : callsToReportHit)
  296. {
  297. // Wrap attribute arguments in Fallback_SetPendingAttr() call
  298. Instruction* insertBefore = call;
  299. hlsl::DxilInst_ReportHit reportHitCall(call);
  300. Value* attr = reportHitCall.get_Attributes();
  301. Function* setPendingAttrFunc = FunctionBuilder(m_module, "\x1?Fallback_SetPendingAttr@@").voidTy().type(attr->getType(), "attr").build();
  302. CallInst::Create(setPendingAttrFunc, { attr }, "", insertBefore);
  303. // Make call to implementation and load result
  304. CallInst* callImpl = CallInst::Create(reportHitFunc, { reportHitCall.get_THit(), reportHitCall.get_HitKind() }, "reportHit.result", insertBefore);
  305. Value* result = callImpl;
  306. // Result < 0 ==> ret
  307. Value* zero = makeInt32(0, C);
  308. Value* ltz = new ICmpInst(insertBefore, CmpInst::ICMP_SLT, result, zero, "endSearch");
  309. BasicBlock* prevBlock = call->getParent();
  310. BasicBlock* retBlock = prevBlock->splitBasicBlock(call, "endSearch");
  311. BasicBlock* nextBlock = retBlock->splitBasicBlock(call, "afterReportHit");
  312. ReplaceInstWithInst(prevBlock->getTerminator(), BranchInst::Create(retBlock, nextBlock, ltz));
  313. ReplaceInstWithInst(retBlock->getTerminator(), ReturnInst::Create(C));
  314. // Compare result to zero and store into original result
  315. Value* gtz = new ICmpInst(insertBefore, CmpInst::ICMP_SGT, result, zero, "accepted");
  316. call->replaceAllUsesWith(gtz);
  317. bool success = inlineFunc(callImpl, reportHitFunc);
  318. assert(success);
  319. (void)success;
  320. call->eraseFromParent();
  321. }
  322. }
  323. void DxrFallbackCompiler::lowerTraceRay(Type* runtimeDataArgTy)
  324. {
  325. std::vector<CallInst*> callsToTraceRay = getCallsInShadersToFunctionWithPrefix("dx.op.traceRay");
  326. if (callsToTraceRay.empty())
  327. {
  328. // TODO: It might be worth dropping this from the tests eventually
  329. callsToTraceRay = getCallsInShadersToFunctionWithPrefix("\x1?TraceRayTest@@");
  330. if (callsToTraceRay.empty())
  331. return;
  332. }
  333. std::vector<Function*> traceRayImpl = getFunctionsWithPrefix(m_module, "\x1?Fallback_TraceRay@@");
  334. assert(traceRayImpl.size() == 1 && "Could not find Fallback_TraceRay() implementation");
  335. enum { CLOSEST_HIT = 0, MISS = 1 };
  336. Function* traceRaySave[] = { m_module->getFunction("traceRaySave_ClosestHit"), m_module->getFunction("traceRaySave_Miss") };
  337. Function* traceRayRestore[] = { m_module->getFunction("traceRayRestore_ClosestHit"), m_module->getFunction("traceRayRestore_Miss") };
  338. assert(traceRaySave[CLOSEST_HIT] && traceRayRestore[CLOSEST_HIT] && traceRaySave[MISS] && traceRayRestore[MISS] &&
  339. "Could not find TraceRay spill functions");
  340. Function* dummyRuntimeDataArgFunc = StateFunctionTransform::createDummyRuntimeDataArgFunc(m_module, runtimeDataArgTy);
  341. assert(dummyRuntimeDataArgFunc && "dummyRuntimeDataArg function could not be created.");
  342. // Process calls
  343. LLVMContext& C = m_module->getContext();
  344. Type* int32Ty = Type::getInt32Ty(C);
  345. std::map<FunctionType*, Function*> movePayloadToStackFuncs;
  346. std::map<Function*, AllocaInst*> funcToSpillAlloca;
  347. for (CallInst* call : callsToTraceRay)
  348. {
  349. Instruction* insertBefore = call;
  350. // Spill runtime data values, if necessary (closesthit and miss shaders)
  351. Function* caller = call->getParent()->getParent();
  352. DXIL::ShaderKind kind = getRayShaderKind(caller);
  353. if (kind == DXIL::ShaderKind::ClosestHit || kind == DXIL::ShaderKind::Miss)
  354. {
  355. int sh = (kind == DXIL::ShaderKind::ClosestHit) ? CLOSEST_HIT : MISS;
  356. AllocaInst* spillAlloca = get(funcToSpillAlloca, caller);
  357. if (!spillAlloca)
  358. {
  359. Argument* spillAllocaArg = (++traceRaySave[sh]->arg_begin());
  360. Type* spillAllocaTy = spillAllocaArg->getType()->getPointerElementType();
  361. spillAlloca = new AllocaInst(spillAllocaTy, "spill.alloca", caller->getEntryBlock().begin());
  362. funcToSpillAlloca[caller] = spillAlloca;
  363. }
  364. // Create calls. SFT will inline them.
  365. Value* runtimeDataArg = CallInst::Create(dummyRuntimeDataArgFunc, "runtimeData", insertBefore);
  366. CallInst::Create(traceRaySave[sh], {runtimeDataArg, spillAlloca}, "", insertBefore);
  367. CallInst::Create(traceRayRestore[sh], {runtimeDataArg, spillAlloca}, "", getInstructionAfter(call));
  368. }
  369. // Get the payload offset to pass to trace implementation
  370. //hlsl::DxilInst_TraceRay traceRayCall(call);
  371. // TODO: Avoiding the intrinsic to support the test's use of TraceRayTest
  372. Value* payload = call->getOperand(call->getNumArgOperands() - 1);
  373. FunctionType* funcType = FunctionType::get(int32Ty, { payload->getType() }, false);
  374. Function* movePayloadToStackFunc = getOrCreateFunction("movePayloadToStack", m_module, funcType, movePayloadToStackFuncs);
  375. Value* newPayloadOffset = CallInst::Create(movePayloadToStackFunc, { payload }, "new.payload.offset", insertBefore);
  376. // Call implementation
  377. unsigned i = 0;
  378. if (call->getCalledFunction()->getName().startswith("dx.op"))
  379. i += 2; // skip intrinsic number and acceleration structure (for now)
  380. std::vector<Value*> args;
  381. for (; i < call->getNumArgOperands() - 1; ++i)
  382. args.push_back(call->getArgOperand(i));
  383. args.push_back(newPayloadOffset);
  384. CallInst::Create(traceRayImpl[0], args, "", insertBefore);
  385. call->eraseFromParent();
  386. }
  387. }
  388. static std::vector<StateFunctionTransform::ParameterSemanticType> getParameterTypes(Function* F, DXIL::ShaderKind shaderKind)
  389. {
  390. std::vector<StateFunctionTransform::ParameterSemanticType> paramTypes;
  391. if (shaderKind == DXIL::ShaderKind::AnyHit || shaderKind == DXIL::ShaderKind::ClosestHit)
  392. {
  393. paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD);
  394. paramTypes.push_back(StateFunctionTransform::PST_ATTRIBUTE);
  395. }
  396. else if (shaderKind == DXIL::ShaderKind::Miss)
  397. {
  398. paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD);
  399. }
  400. else
  401. {
  402. paramTypes.assign(F->getNumOperands(), StateFunctionTransform::PST_NONE);
  403. }
  404. return paramTypes;
  405. }
  406. static void collectResources(DxilModule& DM, std::set<Value*>& resources)
  407. {
  408. for (auto& r : DM.GetCBuffers())
  409. resources.insert(r->GetGlobalSymbol());
  410. for (auto& r : DM.GetUAVs())
  411. resources.insert(r->GetGlobalSymbol());
  412. for (auto& r : DM.GetSRVs())
  413. resources.insert(r->GetGlobalSymbol());
  414. for (auto& r : DM.GetSamplers())
  415. resources.insert(r->GetGlobalSymbol());
  416. }
  417. void DxrFallbackCompiler::createStateFunctions(
  418. IntToFuncMap& stateFunctionMap,
  419. std::vector<int>& shaderEntryStateIds,
  420. std::vector<unsigned int>& shaderStackSizes,
  421. int baseStateId,
  422. const std::vector<std::string>& shaderNames,
  423. Type* runtimeDataArgTy
  424. )
  425. {
  426. for (auto& kv : m_shaderMap)
  427. {
  428. if (kv.second == nullptr)
  429. errs() << "Function not found for shader " << kv.first << "\n";
  430. }
  431. DxilModule& DM = m_module->GetOrCreateDxilModule();
  432. std::set<Value*> resources;
  433. collectResources(DM, resources);
  434. shaderEntryStateIds.clear();
  435. shaderStackSizes.clear();
  436. int stateId = baseStateId;
  437. for (auto& shader : shaderNames)
  438. {
  439. std::vector<Function*> stateFunctions;
  440. Function* F = m_shaderMap[shader];
  441. StateFunctionTransform sft(F, shaderNames, runtimeDataArgTy);
  442. if (m_debugOutputLevel >= 2)
  443. sft.setVerbose(true);
  444. if (m_debugOutputLevel >= 3)
  445. sft.setDumpFilename("dump.ll");
  446. if (shader == "Fallback_TraceRay")
  447. sft.setAttributeSize(m_maxAttributeSize);
  448. DXIL::ShaderKind shaderKind = getRayShaderKind(F);
  449. if (shaderKind != DXIL::ShaderKind::Invalid)
  450. sft.setParameterInfo(getParameterTypes(F, shaderKind), shaderKind == DXIL::ShaderKind::ClosestHit);
  451. sft.setResourceGlobals(resources);
  452. UINT shaderStackSize = 0;
  453. sft.run(stateFunctions, shaderStackSize);
  454. shaderEntryStateIds.push_back(stateId);
  455. shaderStackSizes.push_back(shaderStackSize);
  456. for (Function* stateF : stateFunctions)
  457. {
  458. stateFunctionMap[stateId++] = stateF;
  459. if (DM.HasDxilFunctionProps(F)) {
  460. DM.CloneDxilEntryProps(F, stateF);
  461. }
  462. }
  463. }
  464. StateFunctionTransform::finalizeStateIds(m_module, shaderEntryStateIds);
  465. }
  466. void DxrFallbackCompiler::createLaunchParams(Function* func)
  467. {
  468. Module* mod = func->getParent();
  469. Function* rewrite_setLaunchParams = mod->getFunction("rewrite_setLaunchParams");
  470. CallInst* call = dyn_cast<CallInst>(*rewrite_setLaunchParams->user_begin());
  471. LLVMContext& context = mod->getContext();
  472. Instruction* insertBefore = call;
  473. Function* DTidFunc = FunctionBuilder(mod, "dx.op.threadId.i32").i32().i32().i32().build();
  474. Value* DTidx = CallInst::Create(DTidFunc, { makeInt32((int)hlsl::OP::OpCode::ThreadId, context), makeInt32(0, context) }, "DTidx", insertBefore);
  475. Value* DTidy = CallInst::Create(DTidFunc, { makeInt32((int)hlsl::OP::OpCode::ThreadId, context), makeInt32(1, context) }, "DTidy", insertBefore);
  476. Value* dimx = call->getArgOperand(1);
  477. Value* dimy = call->getArgOperand(2);
  478. Function* groupIndexFunc = FunctionBuilder(mod, "dx.op.flattenedThreadIdInGroup.i32").i32().i32().build();
  479. Value* groupIndex = CallInst::Create(groupIndexFunc, { makeInt32(96, context) }, "groupIndex", insertBefore);
  480. Function* fb_setLaunchParams = mod->getFunction("fb_Fallback_SetLaunchParams");
  481. Value* runtimeDataArg = call->getArgOperand(0);
  482. CallInst::Create(fb_setLaunchParams, { runtimeDataArg, DTidx, DTidy, dimx, dimy, groupIndex }, "", insertBefore);
  483. call->eraseFromParent();
  484. rewrite_setLaunchParams->eraseFromParent();
  485. }
  486. void DxrFallbackCompiler::createStateDispatch(Function* func, const IntToFuncMap& stateFunctionMap, Type* runtimeDataArgTy)
  487. {
  488. Module* mod = func->getParent();
  489. Function* dispatchFunc = createDispatchFunction(stateFunctionMap, runtimeDataArgTy);
  490. Function* rewrite_dispatchFunc = mod->getFunction("rewrite_dispatch");
  491. rewrite_dispatchFunc->replaceAllUsesWith(dispatchFunc);
  492. rewrite_dispatchFunc->eraseFromParent();
  493. }
  494. void DxrFallbackCompiler::createStack(Function* func)
  495. {
  496. LLVMContext& context = func->getContext();
  497. // We would like to allocate the properly sized stack here, but DXIL doesn't
  498. // allow bitcasts between objects of different sizes. So we have to use the
  499. // default size from the runtime and replace all the accesses later.
  500. Function* rewrite_createStack = m_module->getFunction("rewrite_createStack");
  501. CallInst* call = dyn_cast<CallInst>(*rewrite_createStack->user_begin());
  502. AllocaInst* stack = new AllocaInst(call->getType()->getPointerElementType(), "theStack", call);
  503. stack->setAlignment(sizeof(int));
  504. call->replaceAllUsesWith(stack);
  505. call->eraseFromParent();
  506. rewrite_createStack->eraseFromParent();
  507. if (m_stackSizeInBytes == 0) // Take the default
  508. m_stackSizeInBytes = stack->getType()->getPointerElementType()->getArrayNumElements() * sizeof(int);
  509. Function* rewrite_getStackSize = m_module->getFunction("rewrite_getStackSize");
  510. call = dyn_cast<CallInst>(*rewrite_getStackSize->user_begin());
  511. Value* stackSizeVal = makeInt32(m_stackSizeInBytes, context);
  512. call->replaceAllUsesWith(stackSizeVal);
  513. call->eraseFromParent();
  514. rewrite_getStackSize->eraseFromParent();
  515. }
  516. // WAR to avoid crazy <3 x float> code emitted by vanilla clang in the runtime
  517. static bool expandFloat3(std::vector<Value*>& args, Value* arg, Instruction* insertBefore)
  518. {
  519. VectorType* argTy = dyn_cast<VectorType>(arg->getType());
  520. if (!argTy || argTy->getVectorNumElements() != 3)
  521. return false;
  522. LLVMContext& C = arg->getContext();
  523. args.push_back(ExtractElementInst::Create(arg, makeInt32(0, C), "vec.x", insertBefore));
  524. args.push_back(ExtractElementInst::Create(arg, makeInt32(1, C), "vec.y", insertBefore));
  525. args.push_back(ExtractElementInst::Create(arg, makeInt32(2, C), "vec.z", insertBefore));
  526. return true;
  527. }
  528. static bool float3x4ToFloat12(std::vector<Value*>& args, Value* arg, Instruction* insertBefore)
  529. {
  530. StructType* STy = dyn_cast<StructType>(arg->getType());
  531. if (!STy || STy->getName() != "class.matrix.float.3.4")
  532. return false;
  533. BasicBlock& entryBlock = insertBefore->getParent()->getParent()->getEntryBlock();
  534. AllocaInst* alloca = new AllocaInst(arg->getType(), "tmp", entryBlock.begin());
  535. new StoreInst(arg, alloca, insertBefore);
  536. VectorType* VTy = VectorType::get(Type::getFloatTy(arg->getContext()), 12);
  537. Value* vec12Ptr = new BitCastInst(alloca, VTy->getPointerTo(), "vec12.ptr", insertBefore);
  538. Value* vec12 = new LoadInst(vec12Ptr, "vec12.", insertBefore);
  539. args.push_back(vec12);
  540. return true;
  541. }
  542. void DxrFallbackCompiler::lowerIntrinsics()
  543. {
  544. std::vector<Function*> intrinsics = getFunctionsWithPrefix(m_module, "fb_");
  545. assert(intrinsics.size() > 0);
  546. // Replace intrinsics in anyhit shaders with their pending versions
  547. LLVMContext& C = m_module->getContext();
  548. std::map<std::string, Function*> pendingIntrinsics;
  549. std::string pendingPrefixes[] = { "fb_dxop_pending_", "fb_Fallback_Pending" };
  550. for (auto& F : intrinsics)
  551. {
  552. std::string intrinsicName;
  553. if (F->getName().startswith(pendingPrefixes[0]))
  554. intrinsicName = F->getName().substr(pendingPrefixes[0].length());
  555. else if (F->getName().startswith(pendingPrefixes[1]))
  556. intrinsicName = "Fallback_" + F->getName().substr(pendingPrefixes[1].length()).str();
  557. else
  558. continue;
  559. pendingIntrinsics[intrinsicName] = F;
  560. }
  561. for (Function* func : intrinsics)
  562. {
  563. StringRef intrinsicName;
  564. std::string name;
  565. bool isDxilOp = false;
  566. if (func->getName().startswith("fb_Fallback_"))
  567. {
  568. intrinsicName = func->getName().substr(3); // after the "fb_" prefix
  569. name = "\x1?" + intrinsicName.str();
  570. }
  571. else if (func->getName().startswith("fb_dxop_"))
  572. {
  573. intrinsicName = func->getName().substr(8);
  574. name = "dx.op." + intrinsicName.str();
  575. isDxilOp = true;
  576. }
  577. else
  578. {
  579. assert(0 && "Bad intrinsic");
  580. }
  581. std::vector<Function*> calledFunc = getFunctionsWithPrefix(m_module, name);
  582. if (calledFunc.empty())
  583. continue;
  584. std::vector<CallInst*> calls = getCallsToFunction(calledFunc[0]);
  585. if (calls.empty())
  586. continue;
  587. bool needsRuntimeDataArg = (intrinsicName != "Fallback_Scheduler");
  588. Function* pendingFunc = get(pendingIntrinsics, intrinsicName.str());
  589. Function* funcInModule = nullptr;
  590. Function* pendingFuncInModule = nullptr;
  591. for (CallInst* call : calls)
  592. {
  593. Function* caller = call->getParent()->getParent();
  594. if (needsRuntimeDataArg && !caller->hasFnAttribute("state_function"))
  595. continue;
  596. Function* F = nullptr;
  597. if (pendingFunc && shouldUsePendingValue(caller, intrinsicName))
  598. {
  599. if (!pendingFuncInModule)
  600. pendingFuncInModule = getOrInsertFunction(m_module, pendingFunc);
  601. F = pendingFuncInModule;
  602. }
  603. else
  604. {
  605. if (!funcInModule)
  606. funcInModule = getOrInsertFunction(m_module, func);
  607. F = funcInModule;
  608. }
  609. // insert runtime data and the rest of the arguments
  610. std::vector<Value*> args;
  611. if (needsRuntimeDataArg)
  612. args.push_back(caller->arg_begin());
  613. int argIdx = 0;
  614. for (auto& arg : call->arg_operands())
  615. {
  616. if (argIdx++ == 0 && isDxilOp)
  617. continue; // skip the intrinsic number
  618. if (!expandFloat3(args, arg, call) && !float3x4ToFloat12(args, arg, call))
  619. args.push_back(arg);
  620. }
  621. CallInst* newCall = CallInst::Create(F, args, "", call);
  622. if (F->getFunctionType()->getReturnType() != Type::getVoidTy(C))
  623. {
  624. newCall->takeName(call);
  625. call->replaceAllUsesWith(newCall);
  626. }
  627. call->eraseFromParent();
  628. }
  629. }
  630. }
  631. Type* DxrFallbackCompiler::getRuntimeDataArgType()
  632. {
  633. // Get the first argument from a known runtime function (assuming the runtime
  634. // has already been linked in).
  635. Function* F = m_module->getFunction("stackIntPtr");
  636. return F->arg_begin()->getType();
  637. }
  638. Function* DxrFallbackCompiler::createDispatchFunction(const IntToFuncMap &stateFunctionMap, Type* runtimeDataArgTy)
  639. {
  640. LLVMContext& context = m_module->getContext();
  641. FunctionType* stateFuncTy = FunctionType::get(Type::getInt32Ty(context), { runtimeDataArgTy }, false);
  642. Function* dispatchFunc = FunctionBuilder(m_module, "dispatch").i32().type(runtimeDataArgTy, "runtimeData").i32("stateID").build();
  643. Value* runtimeDataArg = dispatchFunc->arg_begin();
  644. Value* stateIdArg = ++dispatchFunc->arg_begin();
  645. BasicBlock* entryBlock = BasicBlock::Create(context, "entry", dispatchFunc);
  646. BasicBlock* badBlock = BasicBlock::Create(context, "badStateID", dispatchFunc);
  647. IRBuilder<> builder(badBlock);
  648. builder.SetInsertPoint(badBlock);
  649. builder.CreateRet(makeInt32(-3, context)); // return an error value
  650. builder.SetInsertPoint(entryBlock);
  651. SwitchInst* switchInst = builder.CreateSwitch(stateIdArg, badBlock, stateFunctionMap.size());
  652. BasicBlock* endBlock = badBlock;
  653. for (auto& kv : stateFunctionMap)
  654. {
  655. int stateId = kv.first;
  656. Function* stateFunc = kv.second;
  657. Value* stateFuncInModule = m_module->getOrInsertFunction(stateFunc->getName(), stateFuncTy);
  658. BasicBlock* block = BasicBlock::Create(context, "state_" + Twine(stateId) + "." + stateFunc->getName(), dispatchFunc, endBlock);
  659. builder.SetInsertPoint(block);
  660. Value* nextStateId = builder.CreateCall(stateFuncInModule, { runtimeDataArg }, "nextStateId");
  661. builder.CreateRet(nextStateId);
  662. switchInst->addCase(makeInt32(stateId, context), block);
  663. }
  664. return dispatchFunc;
  665. }
  666. std::vector<CallInst*> DxrFallbackCompiler::getCallsInShadersToFunction(const std::string& funcName)
  667. {
  668. std::vector<CallInst*> calls;
  669. Function* F = m_module->getFunction(funcName);
  670. if (!F)
  671. return calls;
  672. for (User* U : F->users())
  673. {
  674. CallInst* call = dyn_cast<CallInst>(U);
  675. if (!call)
  676. continue;
  677. Function* caller = call->getParent()->getParent();
  678. auto it = m_shaderMap.find(cleanName(caller->getName()));
  679. if (it != m_shaderMap.end())
  680. calls.push_back(call);
  681. }
  682. return calls;
  683. }
  684. std::vector<CallInst*> DxrFallbackCompiler::getCallsInShadersToFunctionWithPrefix(const std::string& funcNamePrefix)
  685. {
  686. std::vector<CallInst*> calls;
  687. for (Function* F : getFunctionsWithPrefix(m_module, funcNamePrefix))
  688. {
  689. for (User* U : F->users())
  690. {
  691. CallInst* call = dyn_cast<CallInst>(U);
  692. if (!call)
  693. continue;
  694. Function* caller = call->getParent()->getParent();
  695. if (m_shaderMap.count(cleanName(caller->getName())))
  696. calls.push_back(call);
  697. }
  698. }
  699. return calls;
  700. }
  701. void DxrFallbackCompiler::resizeStack(Function* F, unsigned sizeInBytes)
  702. {
  703. // Find the stack
  704. AllocaInst* stack = nullptr;
  705. for (auto& I : F->getEntryBlock().getInstList())
  706. {
  707. AllocaInst* alloc = dyn_cast<AllocaInst>(&I);
  708. if (alloc && alloc->getName().startswith("theStack"))
  709. {
  710. stack = alloc;
  711. break;
  712. }
  713. }
  714. if (!stack)
  715. return;
  716. // Create a new stack
  717. LLVMContext& C = F->getContext();
  718. ArrayType* newStackTy = ArrayType::get(Type::getInt32Ty(C), sizeInBytes / sizeof(int));
  719. AllocaInst* newStack = new AllocaInst(newStackTy, "", stack);
  720. newStack->takeName(stack);
  721. // Remap all GEPs - replaceAllUsesWith() won't change types
  722. for (auto U = stack->user_begin(), UE = stack->user_end(); U != UE; )
  723. {
  724. GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(*U++);
  725. assert(gep && "theStack has non-gep user.");
  726. std::vector<Value*> idxList(gep->idx_begin(), gep->idx_end());
  727. GetElementPtrInst* newGep = GetElementPtrInst::CreateInBounds(newStack, idxList, "", gep);
  728. newGep->takeName(gep);
  729. gep->replaceAllUsesWith(newGep);
  730. gep->eraseFromParent();
  731. }
  732. stack->eraseFromParent();
  733. }