///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// DxilGenerationPass.cpp                                                    //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// DxilGenerationPass implementation.                                        //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include "dxc/HLSL/DxilGenerationPass.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/DXIL/DxilInstructions.h"
#include "dxc/HlslIntrinsicOp.h"
#include "dxc/Support/Global.h"
#include "dxc/DXIL/DxilTypeSystem.h"
#include "dxc/HLSL/HLOperationLower.h"
#include "HLSignatureLower.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/Support/exception.h"
#include "dxc/DXIL/DxilEntryProps.h"

#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/PassManager.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/IR/Dominators.h"

#include <memory>
#include <unordered_set>
#include <iterator>
using namespace llvm;
using namespace hlsl;

// TODO: use hlsl namespace for the most of this file.

namespace {

// Collect unused phi of resources and remove them.
class ResourceRemover : public LoadAndStorePromoter {
  AllocaInst *AI;
  mutable std::unordered_set<PHINode *> unusedPhis;

public:
  ResourceRemover(ArrayRef<Instruction *> Insts, SSAUpdater &S)
      : LoadAndStorePromoter(Insts, S), AI(nullptr) {}

  void run(AllocaInst *AI, const SmallVectorImpl<Instruction *> &Insts) {
    // Remember which alloca we're promoting (for isInstInList).
    this->AI = AI;
    LoadAndStorePromoter::run(Insts);
    for (PHINode *P : unusedPhis) {
      P->eraseFromParent();
    }
  }

  bool
  isInstInList(Instruction *I,
               const SmallVectorImpl<Instruction *> &Insts) const override {
    if (LoadInst *LI = dyn_cast<LoadInst>(I))
      return LI->getOperand(0) == AI;
    return cast<StoreInst>(I)->getPointerOperand() == AI;
  }

  void replaceLoadWithValue(LoadInst *LI, Value *V) const override {
    if (PHINode *PHI = dyn_cast<PHINode>(V)) {
      if (PHI->user_empty())
        unusedPhis.insert(PHI);
    }
    LI->replaceAllUsesWith(UndefValue::get(LI->getType()));
  }
};

void SimplifyGlobalSymbol(GlobalVariable *GV) {
  Type *Ty = GV->getType()->getElementType();
  if (!Ty->isArrayTy()) {
    // Make sure only 1 load of GV in each function.
    std::unordered_map<Function *, Instruction *> handleMapOnFunction;
    for (User *U : GV->users()) {
      if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
        Function *F = LI->getParent()->getParent();
        auto it = handleMapOnFunction.find(F);
        if (it == handleMapOnFunction.end()) {
          handleMapOnFunction[F] = LI;
        } else {
          LI->replaceAllUsesWith(it->second);
        }
      }
    }
    for (auto it : handleMapOnFunction) {
      Function *F = it.first;
      Instruction *I = it.second;
      IRBuilder<> Builder(dxilutil::FirstNonAllocaInsertionPt(F));
      Value *headLI = Builder.CreateLoad(GV);
      I->replaceAllUsesWith(headLI);
    }
  }
}
void InitResourceBase(const DxilResourceBase *pSource,
                      DxilResourceBase *pDest) {
  DXASSERT_NOMSG(pSource->GetClass() == pDest->GetClass());
  pDest->SetKind(pSource->GetKind());
  pDest->SetID(pSource->GetID());
  pDest->SetSpaceID(pSource->GetSpaceID());
  pDest->SetLowerBound(pSource->GetLowerBound());
  pDest->SetRangeSize(pSource->GetRangeSize());
  pDest->SetGlobalSymbol(pSource->GetGlobalSymbol());
  pDest->SetGlobalName(pSource->GetGlobalName());
  pDest->SetHandle(pSource->GetHandle());

  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(pSource->GetGlobalSymbol()))
    SimplifyGlobalSymbol(GV);
}

void InitResource(const DxilResource *pSource, DxilResource *pDest) {
  pDest->SetCompType(pSource->GetCompType());
  pDest->SetSampleCount(pSource->GetSampleCount());
  pDest->SetElementStride(pSource->GetElementStride());
  pDest->SetGloballyCoherent(pSource->IsGloballyCoherent());
  pDest->SetHasCounter(pSource->HasCounter());
  pDest->SetRW(pSource->IsRW());
  pDest->SetROV(pSource->IsROV());
  InitResourceBase(pSource, pDest);
}
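
// Populate an empty DxilModule from the HLModule: copy the validator version
// and shader model, the entry function (for non-library targets), the
// CBuffer/UAV/SRV/sampler resource tables, the serialized root signature,
// subobjects, the DXIL type system, and the DXIL op table, then emit llvm.used.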
void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, bool HasDebugInfo) {

  // Subsystems.
  unsigned ValMajor, ValMinor;
  H.GetValidatorVersion(ValMajor, ValMinor);
  M.SetValidatorVersion(ValMajor, ValMinor);
  M.SetShaderModel(H.GetShaderModel(), H.GetHLOptions().bUseMinPrecision);

  // Entry function.
  if (!M.GetShaderModel()->IsLib()) {
    Function *EntryFn = H.GetEntryFunction();
    M.SetEntryFunction(EntryFn);
    M.SetEntryFunctionName(H.GetEntryFunctionName());
  }

  std::vector<GlobalVariable *> &LLVMUsed = M.GetLLVMUsed();

  // Resources
  for (auto &&C : H.GetCBuffers()) {
    auto b = llvm::make_unique<DxilCBuffer>();
    InitResourceBase(C.get(), b.get());
    b->SetSize(C->GetSize());
    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(b->GetGlobalSymbol()))
      LLVMUsed.emplace_back(GV);
    M.AddCBuffer(std::move(b));
  }
  for (auto &&C : H.GetUAVs()) {
    auto b = llvm::make_unique<DxilResource>();
    InitResource(C.get(), b.get());
    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(b->GetGlobalSymbol()))
      LLVMUsed.emplace_back(GV);
    M.AddUAV(std::move(b));
  }
  for (auto &&C : H.GetSRVs()) {
    auto b = llvm::make_unique<DxilResource>();
    InitResource(C.get(), b.get());
    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(b->GetGlobalSymbol()))
      LLVMUsed.emplace_back(GV);
    M.AddSRV(std::move(b));
  }
  for (auto &&C : H.GetSamplers()) {
    auto b = llvm::make_unique<DxilSampler>();
    InitResourceBase(C.get(), b.get());
    b->SetSamplerKind(C->GetSamplerKind());
    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(b->GetGlobalSymbol()))
      LLVMUsed.emplace_back(GV);
    M.AddSampler(std::move(b));
  }

  // Signatures.
  M.ResetSerializedRootSignature(H.GetSerializedRootSignature());

  // Subobjects.
  M.ResetSubobjects(H.ReleaseSubobjects());

  // Shader properties.
  //bool m_bDisableOptimizations;
  M.SetDisableOptimization(H.GetHLOptions().bDisableOptimizations);
  M.SetLegacyResourceReservation(H.GetHLOptions().bLegacyResourceReservation);
  //bool m_bDisableMathRefactoring;
  //bool m_bEnableDoublePrecision;
  //bool m_bEnableDoubleExtensions;
  //M.CollectShaderFlags();

  //bool m_bForceEarlyDepthStencil;
  //bool m_bEnableRawAndStructuredBuffers;
  //bool m_bEnableMSAD;
  //M.m_ShaderFlags.SetAllResourcesBound(H.GetHLOptions().bAllResourcesBound);

  // DXIL type system.
  M.ResetTypeSystem(H.ReleaseTypeSystem());
  // Dxil OP.
  M.ResetOP(H.ReleaseOP());
  // Keep llvm used.
  M.EmitLLVMUsed();

  M.SetAllResourcesBound(H.GetHLOptions().bAllResourcesBound);

  M.SetAutoBindingSpace(H.GetAutoBindingSpace());

  // Update Validator Version
  M.UpgradeToMinValidatorVersion();
}
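
// DxilGenerationPass lowers the high-level (HL) module produced by HLSL
// codegen into a DXIL module: it lowers entry-point signatures, translates HL
// intrinsics into DXIL operations, generates resource handles, rewrites
// precise markers on allocas into HL function calls (propagated by a later
// pass), and finally replaces the HLModule metadata with DxilModule metadata.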
class DxilGenerationPass : public ModulePass {
  HLModule *m_pHLModule;
  bool m_HasDbgInfo;
  HLSLExtensionsCodegenHelper *m_extensionsCodegenHelper;

public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilGenerationPass(bool NoOpt = false)
      : ModulePass(ID), m_pHLModule(nullptr),
        m_extensionsCodegenHelper(nullptr), NotOptimized(NoOpt) {}

  const char *getPassName() const override { return "DXIL Generator"; }

  void SetExtensionsHelper(HLSLExtensionsCodegenHelper *helper) {
    m_extensionsCodegenHelper = helper;
  }

  bool runOnModule(Module &M) override {
    m_pHLModule = &M.GetOrCreateHLModule();
    const ShaderModel *SM = m_pHLModule->GetShaderModel();

    // Load up debug information, to cross-reference values and the
    // instructions used to load them.
    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;

    // EntrySig for shader functions.
    DxilEntryPropsMap EntryPropsMap;
    if (!SM->IsLib()) {
      Function *EntryFn = m_pHLModule->GetEntryFunction();
      if (!m_pHLModule->HasDxilFunctionProps(EntryFn)) {
        M.getContext().emitError("Entry function don't have property.");
        return false;
      }
      DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(EntryFn);
      std::unique_ptr<DxilEntryProps> pProps =
          llvm::make_unique<DxilEntryProps>(
              props, m_pHLModule->GetHLOptions().bUseMinPrecision);
      HLSignatureLower sigLower(m_pHLModule->GetEntryFunction(), *m_pHLModule,
                                pProps->sig);
      sigLower.Run();
      EntryPropsMap[EntryFn] = std::move(pProps);
    } else {
      for (auto It = M.begin(); It != M.end();) {
        Function &F = *(It++);
        // Lower signature for each graphics or compute entry function.
        if (m_pHLModule->HasDxilFunctionProps(&F)) {
          DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(&F);
          std::unique_ptr<DxilEntryProps> pProps =
              llvm::make_unique<DxilEntryProps>(
                  props, m_pHLModule->GetHLOptions().bUseMinPrecision);
          if (m_pHLModule->IsGraphicsShader(&F) ||
              m_pHLModule->IsComputeShader(&F)) {
            HLSignatureLower sigLower(&F, *m_pHLModule, pProps->sig);
            // TODO: BUG: This will lower patch constant function sigs twice if
            // used by two hull shaders!
            sigLower.Run();
          }
          EntryPropsMap[&F] = std::move(pProps);
        }
      }
    }

    std::unordered_set<LoadInst *> UpdateCounterSet;

    GenerateDxilOperations(M, UpdateCounterSet);

    GenerateDxilCBufferHandles();
    MarkUpdateCounter(UpdateCounterSet);
    LowerHLCreateHandle();

    // LowerHLCreateHandle() should have translated HLCreateHandle to
    // CreateHandleForLib. Clean up HLCreateHandle functions.
    for (auto It = M.begin(); It != M.end();) {
      Function &F = *(It++);
      if (!F.isDeclaration()) {
        if (hlsl::GetHLOpcodeGroupByName(&F) == HLOpcodeGroup::HLCreateHandle) {
          if (F.user_empty()) {
            F.eraseFromParent();
          } else {
            M.getContext().emitError("Fail to lower createHandle.");
          }
        }
      }
    }

    // Translate precise on allocas into function calls to keep the information
    // after mem2reg. The function calls will be removed after the precise
    // attribute is propagated.
    TranslatePreciseAttribute();

    // High-level metadata should now be turned into low-level metadata.
    const bool SkipInit = true;
    hlsl::DxilModule &DxilMod = M.GetOrCreateDxilModule(SkipInit);
    auto pProps = &EntryPropsMap.begin()->second->props;
    InitDxilModuleFromHLModule(*m_pHLModule, DxilMod, m_HasDbgInfo);
    DxilMod.ResetEntryPropsMap(std::move(EntryPropsMap));
    if (!SM->IsLib()) {
      DxilMod.SetShaderProperties(pProps);
    }

    HLModule::ClearHLMetadata(M);
    M.ResetHLModule();

    // We now have a DXIL representation - record this.
    SetPauseResumePasses(M, "hlsl-dxilemit", "hlsl-dxilload");

    (void)NotOptimized; // Dummy out unused member to silence warnings

    return true;
  }

private:
  void MarkUpdateCounter(std::unordered_set<LoadInst *> &UpdateCounterSet);
  // Generate DXIL cbuffer handles.
  void GenerateDxilCBufferHandles();
  // Change built-in functions into DXIL operations.
  void GenerateDxilOperations(Module &M,
                              std::unordered_set<LoadInst *> &UpdateCounterSet);
  void LowerHLCreateHandle();
  // Translate precise attribute into HL function call.
  void TranslatePreciseAttribute();
  // Input module is not optimized.
  bool NotOptimized;
};
}

namespace {
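// Replace every call to the HL createHandle intrinsic on F with a
// dx.op.createHandleForLib call taking the resource value (a load, phi, or
// select) directly; the resource access itself is resolved later by
// DxilLowerCreateHandleForLib.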
void TranslateHLCreateHandle(Function *F, hlsl::OP &hlslOP) {
  Value *opArg =
      hlslOP.GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
  for (auto U = F->user_begin(); U != F->user_end();) {
    Value *user = *(U++);
    if (!isa<Instruction>(user))
      continue;
    // must be call inst
    CallInst *CI = cast<CallInst>(user);
    Value *res = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
    Value *newHandle = nullptr;
    IRBuilder<> Builder(CI);
    // Res could be ld/phi/select. Will be removed in
    // DxilLowerCreateHandleForLib.
    Function *createHandle =
        hlslOP.GetOpFunc(DXIL::OpCode::CreateHandleForLib, res->getType());
    newHandle = Builder.CreateCall(createHandle, {opArg, res});

    CI->replaceAllUsesWith(newHandle);
    if (res->user_empty()) {
      if (Instruction *I = dyn_cast<Instruction>(res))
        I->eraseFromParent();
    }

    CI->eraseFromParent();
  }
}
} // namespace
void DxilGenerationPass::LowerHLCreateHandle() {
  Module *M = m_pHLModule->GetModule();
  hlsl::OP &hlslOP = *m_pHLModule->GetOP();
  // generate dxil operation
  for (iplist<Function>::iterator F : M->getFunctionList()) {
    if (F->user_empty())
      continue;
    if (!F->isDeclaration()) {
      hlsl::HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(F);
      if (group == HLOpcodeGroup::HLCreateHandle) {
        // Will lower in later pass.
        TranslateHLCreateHandle(F, hlslOP);
      }
    }
  }
}
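
// Walk the users of a UAV's global symbol, looking through GEPs, and mark the
// resource as having a counter when one of its loads was recorded in
// UpdateCounterSet while translating the HL intrinsics (i.e. the loaded
// resource feeds a counter update operation).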
static void
MarkUavUpdateCounter(Value *LoadOrGEP, DxilResource &res,
                     std::unordered_set<LoadInst *> &UpdateCounterSet) {
  if (LoadInst *ldInst = dyn_cast<LoadInst>(LoadOrGEP)) {
    if (UpdateCounterSet.count(ldInst)) {
      DXASSERT_NOMSG(res.GetClass() == DXIL::ResourceClass::UAV);
      res.SetHasCounter(true);
    }
  } else {
    DXASSERT(dyn_cast<GEPOperator>(LoadOrGEP) != nullptr,
             "else AddOpcodeParamForIntrinsic in CodeGen did not patch uses "
             "to only have ld/st refer to temp object");
    GEPOperator *GEP = cast<GEPOperator>(LoadOrGEP);
    for (auto GEPU : GEP->users()) {
      MarkUavUpdateCounter(GEPU, res, UpdateCounterSet);
    }
  }
}

static void
MarkUavUpdateCounter(DxilResource &res,
                     std::unordered_set<LoadInst *> &UpdateCounterSet) {
  Value *V = res.GetGlobalSymbol();
  for (auto U = V->user_begin(), E = V->user_end(); U != E;) {
    User *user = *(U++);
    // Skip unused user.
    if (user->user_empty())
      continue;
    MarkUavUpdateCounter(user, res, UpdateCounterSet);
  }
}

void DxilGenerationPass::MarkUpdateCounter(
    std::unordered_set<LoadInst *> &UpdateCounterSet) {
  for (size_t i = 0; i < m_pHLModule->GetUAVs().size(); i++) {
    HLResource &UAV = m_pHLModule->GetUAV(i);
    MarkUavUpdateCounter(UAV, UpdateCounterSet);
  }
}
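
// For each cbuffer global, replace its HLCreateHandle calls with a load of the
// global (or a GEP + load for cbuffer arrays) followed by a
// dx.op.createHandleForLib call; handles with constant indices are hoisted to
// the entry block.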
void DxilGenerationPass::GenerateDxilCBufferHandles() {
  // For CBuffer, handles are mapped to HLCreateHandle.
  OP *hlslOP = m_pHLModule->GetOP();
  Value *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::CreateHandleForLib);
  LLVMContext &Ctx = hlslOP->GetCtx();
  Value *zeroIdx = hlslOP->GetU32Const(0);

  for (size_t i = 0; i < m_pHLModule->GetCBuffers().size(); i++) {
    DxilCBuffer &CB = m_pHLModule->GetCBuffer(i);
    GlobalVariable *GV = dyn_cast<GlobalVariable>(CB.GetGlobalSymbol());
    if (GV == nullptr)
      continue;
    // Remove GEP created in HLObjectOperationLowerHelper::UniformCbPtr.
    GV->removeDeadConstantUsers();
    std::string handleName = std::string(GV->getName());

    DIVariable *DIV = nullptr;
    DILocation *DL = nullptr;
    if (m_HasDbgInfo) {
      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
      DIV = HLModule::FindGlobalVariableDebugInfo(GV, Finder);
      if (DIV)
        // TODO: how to get col?
        DL = DILocation::get(Ctx, DIV->getLine(), 1, DIV->getScope());
    }

    if (CB.GetRangeSize() == 1) {
      Function *createHandle = hlslOP->GetOpFunc(
          OP::OpCode::CreateHandleForLib, GV->getType()->getElementType());
      for (auto U = GV->user_begin(); U != GV->user_end();) {
        // Must be HLCreateHandle.
        CallInst *CI = cast<CallInst>(*(U++));
        // Put createHandle to entry block.
        IRBuilder<> Builder(dxilutil::FirstNonAllocaInsertionPt(CI));
        Value *V = Builder.CreateLoad(GV);
        CallInst *handle =
            Builder.CreateCall(createHandle, {opArg, V}, handleName);
        if (m_HasDbgInfo) {
          // TODO: add debug info.
          //handle->setDebugLoc(DL);
          (void)(DL);
        }
        CI->replaceAllUsesWith(handle);
        CI->eraseFromParent();
      }
    } else {
      PointerType *Ty = GV->getType();
      Type *EltTy = Ty->getElementType()->getArrayElementType()->getPointerTo(
          Ty->getAddressSpace());
      Function *createHandle = hlslOP->GetOpFunc(
          OP::OpCode::CreateHandleForLib, EltTy->getPointerElementType());

      for (auto U = GV->user_begin(); U != GV->user_end();) {
        // Must be HLCreateHandle.
        CallInst *CI = cast<CallInst>(*(U++));
        IRBuilder<> Builder(CI);
        Value *CBIndex =
            CI->getArgOperand(HLOperandIndex::kCreateHandleIndexOpIdx);
        if (isa<ConstantInt>(CBIndex)) {
          // Put createHandle to entry block for const index.
          Builder.SetInsertPoint(dxilutil::FirstNonAllocaInsertionPt(CI));
        }
        // Add GEP for cbv array use.
        Value *GEP = Builder.CreateGEP(GV, {zeroIdx, CBIndex});
        Value *V = Builder.CreateLoad(GEP);
        CallInst *handle =
            Builder.CreateCall(createHandle, {opArg, V}, handleName);
        CI->replaceAllUsesWith(handle);
        CI->eraseFromParent();
      }
    }
  }
}
void DxilGenerationPass::GenerateDxilOperations(
    Module &M, std::unordered_set<LoadInst *> &UpdateCounterSet) {
  // remove all functions except entry function
  Function *entry = m_pHLModule->GetEntryFunction();
  const ShaderModel *pSM = m_pHLModule->GetShaderModel();
  Function *patchConstantFunc = nullptr;
  if (pSM->IsHS()) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(entry);
    patchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc;
  }

  if (!pSM->IsLib()) {
    for (auto F = M.begin(); F != M.end();) {
      Function *func = F++;

      if (func->isDeclaration())
        continue;
      if (func == entry)
        continue;
      if (func == patchConstantFunc)
        continue;
      if (func->user_empty())
        func->eraseFromParent();
    }
  }

  TranslateBuiltinOperations(*m_pHLModule, m_extensionsCodegenHelper,
                             UpdateCounterSet);

  // Remove unused HL Operation functions.
  std::vector<Function *> deadList;
  for (iplist<Function>::iterator F : M.getFunctionList()) {
    hlsl::HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
    if (group != HLOpcodeGroup::NotHL || F->isIntrinsic())
      if (F->user_empty())
        deadList.emplace_back(F);
  }
  for (Function *F : deadList)
    F->eraseFromParent();
}
static void TranslatePreciseAttributeOnFunction(Function &F, Module &M) {
  BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function

  // Find allocas that have the precise attribute, by looking at all
  // instructions in the entry node.
  for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
    Instruction *Inst = (I++);
    if (AllocaInst *AI = dyn_cast<AllocaInst>(Inst)) {
      if (HLModule::HasPreciseAttributeWithMetadata(AI)) {
        HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(AI, M);
      }
    } else {
      DXASSERT(!HLModule::HasPreciseAttributeWithMetadata(Inst),
               "Only alloca can have precise metadata.");
    }
  }

  FastMathFlags FMF;
  FMF.setUnsafeAlgebra();
  // Set fast math for all FPMathOperators.
  // FastMath is already set in the options, but that only enables it on
  // instructions like fadd. Every instruction whose type is float can be cast
  // to FPMathOperator.
  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
    BasicBlock *BB = BBI;
    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
      if (dyn_cast<FPMathOperator>(I)) {
        // Set precise fast math on those instructions that support it.
        if (DxilModule::PreservesFastMathFlags(I))
          I->copyFastMathFlags(FMF);
      }
    }
  }
}
void DxilGenerationPass::TranslatePreciseAttribute() {
  bool bIEEEStrict = m_pHLModule->GetHLOptions().bIEEEStrict;
  // If IEEE strict, everything is precise; no need to mark it.
  if (bIEEEStrict)
    return;

  Module &M = *m_pHLModule->GetModule();
  // TODO: If we do not inline every function, then for a function that has
  // both call sites with precise arguments and call sites without, we need to
  // clone the function to propagate precise only for the precise call sites.
  // This should be done at CGMSHLSLRuntime::FinishCodeGen.
  if (m_pHLModule->GetShaderModel()->IsLib()) {
    // TODO: If all functions have been inlined, and unreferenced functions
    // removed, it should make sense to run on all function bodies, even when
    // not processing a library.
    for (Function &F : M.functions()) {
      if (!F.isDeclaration())
        TranslatePreciseAttributeOnFunction(F, M);
    }
  } else {
    Function *EntryFn = m_pHLModule->GetEntryFunction();
    TranslatePreciseAttributeOnFunction(*EntryFn, M);
    if (m_pHLModule->GetShaderModel()->IsHS()) {
      DxilFunctionProps &EntryQual = m_pHLModule->GetDxilFunctionProps(EntryFn);
      Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
      TranslatePreciseAttributeOnFunction(*patchConstantFunc, M);
    }
  }
}

char DxilGenerationPass::ID = 0;

ModulePass *llvm::createDxilGenerationPass(
    bool NotOptimized, hlsl::HLSLExtensionsCodegenHelper *extensionsHelper) {
  DxilGenerationPass *dxilPass = new DxilGenerationPass(NotOptimized);
  dxilPass->SetExtensionsHelper(extensionsHelper);
  return dxilPass;
}
INITIALIZE_PASS(DxilGenerationPass, "dxilgen", "HLSL DXIL Generation", false,
                false)

///////////////////////////////////////////////////////////////////////////////

namespace {
class HLEmitMetadata : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLEmitMetadata() : ModulePass(ID) {}

  const char *getPassName() const override {
    return "HLSL High-Level Metadata Emit";
  }

  bool runOnModule(Module &M) override {
    if (M.HasHLModule()) {
      HLModule::ClearHLMetadata(M);
      M.GetHLModule().EmitHLMetadata();
      return true;
    }

    return false;
  }
};
}

char HLEmitMetadata::ID = 0;

ModulePass *llvm::createHLEmitMetadataPass() { return new HLEmitMetadata(); }

INITIALIZE_PASS(HLEmitMetadata, "hlsl-hlemit", "HLSL High-Level Metadata Emit",
                false, false)

///////////////////////////////////////////////////////////////////////////////

namespace {
class HLEnsureMetadata : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLEnsureMetadata() : ModulePass(ID) {}

  const char *getPassName() const override {
    return "HLSL High-Level Metadata Ensure";
  }

  bool runOnModule(Module &M) override {
    if (!M.HasHLModule()) {
      M.GetOrCreateHLModule();
      return true;
    }

    return false;
  }
};
}

char HLEnsureMetadata::ID = 0;

ModulePass *llvm::createHLEnsureMetadataPass() { return new HLEnsureMetadata(); }

INITIALIZE_PASS(HLEnsureMetadata, "hlsl-hlensure",
                "HLSL High-Level Metadata Ensure", false, false)
///////////////////////////////////////////////////////////////////////////////
// Precise propagate.

namespace {
class DxilPrecisePropagatePass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilPrecisePropagatePass() : ModulePass(ID) {}

  const char *getPassName() const override { return "DXIL Precise Propagate"; }

  bool runOnModule(Module &M) override {
    DxilModule &dxilModule = M.GetOrCreateDxilModule();
    DxilTypeSystem &typeSys = dxilModule.GetTypeSystem();
    std::unordered_set<Instruction *> processedSet;
    std::vector<Function *> deadList;
    for (Function &F : M.functions()) {
      if (HLModule::HasPreciseAttribute(&F)) {
        PropagatePreciseOnFunctionUser(F, typeSys, processedSet);
        deadList.emplace_back(&F);
      }
    }
    for (Function *F : deadList)
      F->eraseFromParent();
    return true;
  }

private:
  void PropagatePreciseOnFunctionUser(
      Function &F, DxilTypeSystem &typeSys,
      std::unordered_set<Instruction *> &processedSet);
};

char DxilPrecisePropagatePass::ID = 0;
}
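
// The helpers below propagate "precise" backwards from a marked value to
// everything that contributes to it: precise fast-math flags are applied to FP
// instructions that support them, calls are tagged with precise metadata, and
// for pointers every store (and every call that can write through an out/inout
// parameter) is followed.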
static void PropagatePreciseAttribute(
    Instruction *I, DxilTypeSystem &typeSys,
    std::unordered_set<Instruction *> &processedSet);

static void PropagatePreciseAttributeOnOperand(
    Value *V, DxilTypeSystem &typeSys, LLVMContext &Context,
    std::unordered_set<Instruction *> &processedSet) {
  Instruction *I = dyn_cast<Instruction>(V);
  // Skip non-instruction values.
  if (!I)
    return;

  FPMathOperator *FPMath = dyn_cast<FPMathOperator>(I);
  // Skip non-FPMath instructions.
  if (!FPMath)
    return;

  // Skip instructions already marked.
  if (processedSet.count(I) > 0)
    return;
  // TODO: skip precise on integer type, sample instruction...
  processedSet.insert(I);

  // Set precise fast math on those instructions that support it.
  if (DxilModule::PreservesFastMathFlags(I))
    DxilModule::SetPreciseFastMathFlags(I);

  // Fast math does not work on calls, so use metadata.
  if (CallInst *CI = dyn_cast<CallInst>(I))
    HLModule::MarkPreciseAttributeWithMetadata(CI);
  PropagatePreciseAttribute(I, typeSys, processedSet);
}
static void PropagatePreciseAttributeOnPointer(
    Value *Ptr, DxilTypeSystem &typeSys, LLVMContext &Context,
    std::unordered_set<Instruction *> &processedSet) {
  // Find all stores and propagate on the value operand of the store.
  // For CallInst, if Ptr is used as an out parameter, mark it.
  for (User *U : Ptr->users()) {
    Instruction *user = cast<Instruction>(U);
    if (StoreInst *stInst = dyn_cast<StoreInst>(user)) {
      Value *val = stInst->getValueOperand();
      PropagatePreciseAttributeOnOperand(val, typeSys, Context, processedSet);
    } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
      bool bReadOnly = true;

      Function *F = CI->getCalledFunction();
      const DxilFunctionAnnotation *funcAnnotation =
          typeSys.GetFunctionAnnotation(F);
      for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
        if (Ptr != CI->getArgOperand(i))
          continue;

        const DxilParameterAnnotation &paramAnnotation =
            funcAnnotation->GetParameterAnnotation(i);
        // OutputPatch and OutputStream will be checked after scalar repl.
        // Here only check out/inout.
        if (paramAnnotation.GetParamInputQual() == DxilParamInputQual::Out ||
            paramAnnotation.GetParamInputQual() == DxilParamInputQual::Inout) {
          bReadOnly = false;
          break;
        }
      }

      if (!bReadOnly)
        PropagatePreciseAttributeOnOperand(CI, typeSys, Context, processedSet);
    }
  }
}
static void
PropagatePreciseAttribute(Instruction *I, DxilTypeSystem &typeSys,
                          std::unordered_set<Instruction *> &processedSet) {
  LLVMContext &Context = I->getContext();
  if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
    PropagatePreciseAttributeOnPointer(AI, typeSys, Context, processedSet);
  } else if (dyn_cast<CallInst>(I)) {
    // Propagate every argument.
    // TODO: only propagate precise argument.
    for (Value *src : I->operands())
      PropagatePreciseAttributeOnOperand(src, typeSys, Context, processedSet);
  } else if (dyn_cast<FPMathOperator>(I)) {
    // TODO: only propagate precise argument.
    for (Value *src : I->operands())
      PropagatePreciseAttributeOnOperand(src, typeSys, Context, processedSet);
  } else if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
    Value *Ptr = ldInst->getPointerOperand();
    PropagatePreciseAttributeOnPointer(Ptr, typeSys, Context, processedSet);
  } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I))
    PropagatePreciseAttributeOnPointer(GEP, typeSys, Context, processedSet);
  // TODO: support more cases which need it.
}

void DxilPrecisePropagatePass::PropagatePreciseOnFunctionUser(
    Function &F, DxilTypeSystem &typeSys,
    std::unordered_set<Instruction *> &processedSet) {
  LLVMContext &Context = F.getContext();
  for (auto U = F.user_begin(), E = F.user_end(); U != E;) {
    CallInst *CI = cast<CallInst>(*(U++));
    Value *V = CI->getArgOperand(0);
    PropagatePreciseAttributeOnOperand(V, typeSys, Context, processedSet);
    CI->eraseFromParent();
  }
}

ModulePass *llvm::createDxilPrecisePropagatePass() {
  return new DxilPrecisePropagatePass();
}

INITIALIZE_PASS(DxilPrecisePropagatePass, "hlsl-dxil-precise",
                "DXIL precise attribute propagate", false, false)
///////////////////////////////////////////////////////////////////////////////

namespace {
class HLDeadFunctionElimination : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLDeadFunctionElimination() : ModulePass(ID) {}

  const char *getPassName() const override {
    return "Remove all unused function except entry from HLModule";
  }

  bool runOnModule(Module &M) override {
    if (M.HasHLModule()) {
      HLModule &HLM = M.GetHLModule();

      bool IsLib = HLM.GetShaderModel()->IsLib();
      // Remove unused functions except entry and patch constant func.
      // For library profile, only remove unused external functions.
      Function *EntryFunc = HLM.GetEntryFunction();
      Function *PatchConstantFunc = HLM.GetPatchConstantFunction();

      return dxilutil::RemoveUnusedFunctions(M, EntryFunc, PatchConstantFunc,
                                             IsLib);
    }

    return false;
  }
};
}

char HLDeadFunctionElimination::ID = 0;

ModulePass *llvm::createHLDeadFunctionEliminationPass() {
  return new HLDeadFunctionElimination();
}

INITIALIZE_PASS(HLDeadFunctionElimination, "hl-dfe",
                "Remove all unused function except entry from HLModule", false,
                false)
///////////////////////////////////////////////////////////////////////////////
// Legalize resource use.
// Map local or static global resources to global resources.
// Requires inlining for static global resources.

namespace {

static const StringRef kStaticResourceLibErrorMsg =
    "static global resource use is disallowed in library exports.";

class DxilPromoteStaticResources : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilPromoteStaticResources() : ModulePass(ID) {}

  const char *getPassName() const override {
    return "DXIL Legalize Static Resource Use";
  }

  bool runOnModule(Module &M) override {
    // Promote static global variables.
    return PromoteStaticGlobalResources(M);
  }

private:
  bool PromoteStaticGlobalResources(Module &M);
};

char DxilPromoteStaticResources::ID = 0;

class DxilPromoteLocalResources : public FunctionPass {
  void getAnalysisUsage(AnalysisUsage &AU) const override;

public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilPromoteLocalResources() : FunctionPass(ID) {}

  const char *getPassName() const override {
    return "DXIL Legalize Resource Use";
  }

  bool runOnFunction(Function &F) override {
    // Promote local resources first.
    return PromoteLocalResource(F);
  }

private:
  bool PromoteLocalResource(Function &F);
};

char DxilPromoteLocalResources::ID = 0;
}
void DxilPromoteLocalResources::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<AssumptionCacheTracker>();
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesAll();
}
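
// Repeatedly run mem2reg over resource-typed allocas in the entry block until
// none remain; if an iteration makes no progress, the local resource cannot be
// reduced to a unique global resource and kResourceMapErrorMsg is emitted.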
bool DxilPromoteLocalResources::PromoteLocalResource(Function &F) {
  bool bModified = false;
  std::vector<AllocaInst *> Allocas;
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  AssumptionCache &AC =
      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);

  BasicBlock &BB = F.getEntryBlock();
  unsigned allocaSize = 0;
  while (1) {
    Allocas.clear();

    // Find allocas that are safe to promote, by looking at all instructions in
    // the entry node
    for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
      if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
        if (dxilutil::IsHLSLObjectType(
                dxilutil::GetArrayEltTy(AI->getAllocatedType()))) {
          if (isAllocaPromotable(AI))
            Allocas.push_back(AI);
        }
      }
    if (Allocas.empty())
      break;

    // No update.
    // Report error and break.
    if (allocaSize == Allocas.size()) {
      F.getContext().emitError(dxilutil::kResourceMapErrorMsg);
      break;
    }
    allocaSize = Allocas.size();

    PromoteMemToReg(Allocas, *DT, nullptr, &AC);
    bModified = true;
  }

  return bModified;
}

FunctionPass *llvm::createDxilPromoteLocalResources() {
  return new DxilPromoteLocalResources();
}

INITIALIZE_PASS_BEGIN(DxilPromoteLocalResources,
                      "hlsl-dxil-promote-local-resources",
                      "DXIL promote local resource use", false, true)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(DxilPromoteLocalResources,
                    "hlsl-dxil-promote-local-resources",
                    "DXIL promote local resource use", false, true)
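
// For library targets, any use of an internal (static) resource global is an
// error. Otherwise, use LoadAndStorePromoter to forward stored resource values
// to their loads until every static resource global is unused; if a pass over
// the remaining globals makes no progress, emit kResourceMapErrorMsg.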
bool DxilPromoteStaticResources::PromoteStaticGlobalResources(Module &M) {
  if (M.GetOrCreateHLModule().GetShaderModel()->IsLib()) {
    // Read/write to global static resource is disallowed for libraries:
    // Resource use needs to be resolved to a single real global resource,
    // but it may not be possible since any external function call may re-enter
    // at any other library export, which could modify the global static
    // between write and read.
    // While it could work for certain cases, describing the boundary at
    // the HLSL level is difficult, so at this point it's better to disallow.
    // example of what could work:
    //  After inlining, exported functions must have writes to static globals
    //  before reads, and must not have any external function calls between
    //  writes and subsequent reads, such that the static global may be
    //  optimized away for the exported function.
    for (auto &GV : M.globals()) {
      if (GV.getLinkage() == GlobalVariable::LinkageTypes::InternalLinkage &&
          dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(GV.getType()))) {
        if (!GV.user_empty()) {
          if (Instruction *I = dyn_cast<Instruction>(*GV.user_begin())) {
            dxilutil::EmitErrorOnInstruction(I, kStaticResourceLibErrorMsg);
            break;
          }
        }
      }
    }
    return false;
  }

  bool bModified = false;
  std::set<GlobalVariable *> staticResources;
  for (auto &GV : M.globals()) {
    if (GV.getLinkage() == GlobalVariable::LinkageTypes::InternalLinkage &&
        dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(GV.getType()))) {
      staticResources.insert(&GV);
    }
  }
  SSAUpdater SSA;
  SmallVector<Instruction *, 4> Insts;
  // Make sure every resource load has been mapped to a global variable.
  while (!staticResources.empty()) {
    bool bUpdated = false;
    for (auto it = staticResources.begin(); it != staticResources.end();) {
      GlobalVariable *GV = *(it++);
      // Build list of instructions to promote.
      for (User *U : GV->users()) {
        Instruction *I = cast<Instruction>(U);
        Insts.emplace_back(I);
      }

      LoadAndStorePromoter(Insts, SSA).run(Insts);
      if (GV->user_empty()) {
        bUpdated = true;
        staticResources.erase(GV);
      }

      Insts.clear();
    }
    if (!bUpdated) {
      M.getContext().emitError(dxilutil::kResourceMapErrorMsg);
      break;
    }
    bModified = true;
  }
  return bModified;
}

ModulePass *llvm::createDxilPromoteStaticResources() {
  return new DxilPromoteStaticResources();
}

INITIALIZE_PASS(DxilPromoteStaticResources,
                "hlsl-dxil-promote-static-resources",
                "DXIL promote static resource use", false, false)
///////////////////////////////////////////////////////////////////////////////
// Legalize EvalOperations.
// Make sure the sources of EvalOperations come from function parameters.
// This is needed in order to translate EvaluateAttribute operations that trace
// back to LoadInput operations during the translation stage. Promoting
// load/store instructions beforehand will allow us to easily trace back to
// loadInput from the function call.
namespace {
class DxilLegalizeEvalOperations : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilLegalizeEvalOperations() : ModulePass(ID) {}

  const char *getPassName() const override {
    return "DXIL Legalize EvalOperations";
  }

  bool runOnModule(Module &M) override {
    for (Function &F : M.getFunctionList()) {
      hlsl::HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(&F);
      if (group != HLOpcodeGroup::NotHL) {
        std::vector<CallInst *> EvalFunctionCalls;
        // Find all EvaluateAttribute calls.
        for (User *U : F.users()) {
          if (CallInst *CI = dyn_cast<CallInst>(U)) {
            IntrinsicOp evalOp =
                static_cast<IntrinsicOp>(hlsl::GetHLOpcode(CI));
            if (evalOp == IntrinsicOp::IOP_EvaluateAttributeAtSample ||
                evalOp == IntrinsicOp::IOP_EvaluateAttributeCentroid ||
                evalOp == IntrinsicOp::IOP_EvaluateAttributeSnapped ||
                evalOp == IntrinsicOp::IOP_GetAttributeAtVertex) {
              EvalFunctionCalls.push_back(CI);
            }
          }
        }
        if (EvalFunctionCalls.empty()) {
          continue;
        }
        // Start from the call instruction, find all allocas that this call
        // uses.
        std::unordered_set<AllocaInst *> allocas;
        for (CallInst *CI : EvalFunctionCalls) {
          FindAllocasForEvalOperations(CI, allocas);
        }
        SSAUpdater SSA;
        SmallVector<Instruction *, 4> Insts;
        for (AllocaInst *AI : allocas) {
          for (User *user : AI->users()) {
            if (isa<LoadInst>(user) || isa<StoreInst>(user)) {
              Insts.emplace_back(cast<Instruction>(user));
            }
          }
          LoadAndStorePromoter(Insts, SSA).run(Insts);
          Insts.clear();
        }
      }
    }
    return true;
  }

private:
  void FindAllocasForEvalOperations(Value *val,
                                    std::unordered_set<AllocaInst *> &allocas);
};

char DxilLegalizeEvalOperations::ID = 0;

// Find allocas for EvaluateAttribute operations.
void DxilLegalizeEvalOperations::FindAllocasForEvalOperations(
    Value *val, std::unordered_set<AllocaInst *> &allocas) {
  Value *CurVal = val;
  while (!isa<AllocaInst>(CurVal)) {
    if (CallInst *CI = dyn_cast<CallInst>(CurVal)) {
      CurVal = CI->getOperand(HLOperandIndex::kUnaryOpSrc0Idx);
    } else if (InsertElementInst *IE = dyn_cast<InsertElementInst>(CurVal)) {
      Value *arg0 =
          IE->getOperand(0); // Could be another insertelement or undef
      Value *arg1 = IE->getOperand(1);
      FindAllocasForEvalOperations(arg0, allocas);
      CurVal = arg1;
    } else if (ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(CurVal)) {
      Value *arg0 = SV->getOperand(0);
      Value *arg1 = SV->getOperand(1);
      FindAllocasForEvalOperations(
          arg0, allocas); // Shuffle vector could come from different allocas
      CurVal = arg1;
    } else if (ExtractElementInst *EE = dyn_cast<ExtractElementInst>(CurVal)) {
      CurVal = EE->getOperand(0);
    } else if (LoadInst *LI = dyn_cast<LoadInst>(CurVal)) {
      CurVal = LI->getOperand(0);
    } else {
      break;
    }
  }
  if (AllocaInst *AI = dyn_cast<AllocaInst>(CurVal)) {
    allocas.insert(AI);
  }
}
} // namespace

ModulePass *llvm::createDxilLegalizeEvalOperationsPass() {
  return new DxilLegalizeEvalOperations();
}

INITIALIZE_PASS(DxilLegalizeEvalOperations,
                "hlsl-dxil-legalize-eval-operations",
                "DXIL legalize eval operations", false, false)
///////////////////////////////////////////////////////////////////////////////
// Translate RawBufferLoad/RawBufferStore.
// This pass makes sure that we generate correct buffer loads for DXIL:
// - For DXIL < 1.2, rawBufferLoad is translated to a BufferLoad instruction
//   without a mask.
// - For DXIL >= 1.2, if min precision is enabled, the generation pass
//   currently produces i16/f16 return types for min-precision values. For raw
//   buffers we change this so that min-precision values use their actual
//   storage scalar type (i32/f32), with conversion to/from the min-precision
//   type after loading / before storing.
namespace {

// Create { v0, v1 } from { v0.lo, v0.hi, v1.lo, v1.hi }
void Make64bitResultForLoad(Type *EltTy, ArrayRef<Value *> resultElts32,
                            unsigned size, MutableArrayRef<Value *> resultElts,
                            hlsl::OP *hlslOP, IRBuilder<> &Builder) {
  Type *i64Ty = Builder.getInt64Ty();
  Type *doubleTy = Builder.getDoubleTy();
  if (EltTy == doubleTy) {
    Function *makeDouble =
        hlslOP->GetOpFunc(DXIL::OpCode::MakeDouble, doubleTy);
    Value *makeDoubleOpArg =
        Builder.getInt32((unsigned)DXIL::OpCode::MakeDouble);
    for (unsigned i = 0; i < size; i++) {
      Value *lo = resultElts32[2 * i];
      Value *hi = resultElts32[2 * i + 1];
      Value *V = Builder.CreateCall(makeDouble, {makeDoubleOpArg, lo, hi});
      resultElts[i] = V;
    }
  } else {
    for (unsigned i = 0; i < size; i++) {
      Value *lo = resultElts32[2 * i];
      Value *hi = resultElts32[2 * i + 1];
      lo = Builder.CreateZExt(lo, i64Ty);
      hi = Builder.CreateZExt(hi, i64Ty);
      hi = Builder.CreateShl(hi, 32);
      resultElts[i] = Builder.CreateOr(lo, hi);
    }
  }
}

// Split { v0, v1 } into { v0.lo, v0.hi, v1.lo, v1.hi }
void Split64bitValForStore(Type *EltTy, ArrayRef<Value *> vals, unsigned size,
                           MutableArrayRef<Value *> vals32, hlsl::OP *hlslOP,
                           IRBuilder<> &Builder) {
  Type *i32Ty = Builder.getInt32Ty();
  Type *doubleTy = Builder.getDoubleTy();
  Value *undefI32 = UndefValue::get(i32Ty);

  if (EltTy == doubleTy) {
    Function *dToU = hlslOP->GetOpFunc(DXIL::OpCode::SplitDouble, doubleTy);
    Value *dToUOpArg = Builder.getInt32((unsigned)DXIL::OpCode::SplitDouble);
    for (unsigned i = 0; i < size; i++) {
      if (isa<UndefValue>(vals[i])) {
        vals32[2 * i] = undefI32;
        vals32[2 * i + 1] = undefI32;
      } else {
        Value *retVal = Builder.CreateCall(dToU, {dToUOpArg, vals[i]});
        Value *lo = Builder.CreateExtractValue(retVal, 0);
        Value *hi = Builder.CreateExtractValue(retVal, 1);
        vals32[2 * i] = lo;
        vals32[2 * i + 1] = hi;
      }
    }
  } else {
    for (unsigned i = 0; i < size; i++) {
      if (isa<UndefValue>(vals[i])) {
        vals32[2 * i] = undefI32;
        vals32[2 * i + 1] = undefI32;
      } else {
        Value *lo = Builder.CreateTrunc(vals[i], i32Ty);
        Value *hi = Builder.CreateLShr(vals[i], 32);
        hi = Builder.CreateTrunc(hi, i32Ty);
        vals32[2 * i] = lo;
        vals32[2 * i + 1] = hi;
      }
    }
  }
}
class DxilTranslateRawBuffer : public ModulePass {
public:
  static char ID;
  explicit DxilTranslateRawBuffer() : ModulePass(ID) {}
  bool runOnModule(Module &M) {
    unsigned major, minor;
    DxilModule &DM = M.GetDxilModule();
    DM.GetDxilVersion(major, minor);
    OP *hlslOP = DM.GetOP();
    // Split 64bit for shader model less than 6.3.
    if (major == 1 && minor <= 2) {
      for (auto F = M.functions().begin(); F != M.functions().end();) {
        Function *func = &*(F++);
        DXIL::OpCodeClass opClass;
        if (hlslOP->GetOpCodeClass(func, opClass)) {
          if (opClass == DXIL::OpCodeClass::RawBufferLoad) {
            Type *ETy =
                hlslOP->GetOverloadType(DXIL::OpCode::RawBufferLoad, func);
            bool is64 =
                ETy->isDoubleTy() || ETy == Type::getInt64Ty(ETy->getContext());
            if (is64) {
              ReplaceRawBufferLoad64Bit(func, ETy, M);
              func->eraseFromParent();
            }
          } else if (opClass == DXIL::OpCodeClass::RawBufferStore) {
            Type *ETy =
                hlslOP->GetOverloadType(DXIL::OpCode::RawBufferStore, func);
            bool is64 =
                ETy->isDoubleTy() || ETy == Type::getInt64Ty(ETy->getContext());
            if (is64) {
              ReplaceRawBufferStore64Bit(func, ETy, M);
              func->eraseFromParent();
            }
          }
        }
      }
    }
    if (major == 1 && minor < 2) {
      for (auto F = M.functions().begin(), E = M.functions().end(); F != E;) {
        Function *func = &*(F++);
        if (func->hasName()) {
          if (func->getName().startswith("dx.op.rawBufferLoad")) {
            ReplaceRawBufferLoad(func, M);
            func->eraseFromParent();
          } else if (func->getName().startswith("dx.op.rawBufferStore")) {
            ReplaceRawBufferStore(func, M);
            func->eraseFromParent();
          }
        }
      }
    } else if (M.GetDxilModule().GetUseMinPrecision()) {
      for (auto F = M.functions().begin(), E = M.functions().end(); F != E;) {
        Function *func = &*(F++);
        if (func->hasName()) {
          if (func->getName().startswith("dx.op.rawBufferLoad")) {
            ReplaceMinPrecisionRawBufferLoad(func, M);
          } else if (func->getName().startswith("dx.op.rawBufferStore")) {
            ReplaceMinPrecisionRawBufferStore(func, M);
          }
        }
      }
    }
    return true;
  }

private:
  // Replace RawBufferLoad/Store to BufferLoad/Store for DXIL < 1.2
  void ReplaceRawBufferLoad(Function *F, Module &M);
  void ReplaceRawBufferStore(Function *F, Module &M);
  void ReplaceRawBufferLoad64Bit(Function *F, Type *EltTy, Module &M);
  void ReplaceRawBufferStore64Bit(Function *F, Type *EltTy, Module &M);
  // Replace RawBufferLoad/Store of min-precision types to have its actual
  // storage size
  void ReplaceMinPrecisionRawBufferLoad(Function *F, Module &M);
  void ReplaceMinPrecisionRawBufferStore(Function *F, Module &M);
  void ReplaceMinPrecisionRawBufferLoadByType(Function *F, Type *FromTy,
                                              Type *ToTy, OP *Op,
                                              const DataLayout &DL);
};
} // namespace
void DxilTranslateRawBuffer::ReplaceRawBufferLoad(Function *F, Module &M) {
  OP *op = M.GetDxilModule().GetOP();
  Type *RTy = F->getReturnType();
  if (StructType *STy = dyn_cast<StructType>(RTy)) {
    Type *ETy = STy->getElementType(0);
    Function *newFunction = op->GetOpFunc(hlsl::DXIL::OpCode::BufferLoad, ETy);
    for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
      User *user = *(U++);
      if (CallInst *CI = dyn_cast<CallInst>(user)) {
        IRBuilder<> Builder(CI);
        SmallVector<Value *, 4> args;
        args.emplace_back(op->GetI32Const((unsigned)DXIL::OpCode::BufferLoad));
        for (unsigned i = 1; i < 4; ++i) {
          args.emplace_back(CI->getArgOperand(i));
        }
        CallInst *newCall = Builder.CreateCall(newFunction, args);
        CI->replaceAllUsesWith(newCall);
        CI->eraseFromParent();
      } else {
        DXASSERT(false, "function can only be used with call instructions.");
      }
    }
  } else {
    DXASSERT(false, "RawBufferLoad should return struct type.");
  }
}
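
// A 64-bit rawBufferLoad is emulated with i32 loads: each 64-bit element
// occupies two 32-bit lanes, so loading N elements needs 2*N lanes. The first
// rawBufferLoad covers lanes 0-3 (maskLo); when more than two 64-bit elements
// are requested, a second load at offset/index + 16 bytes covers lanes 4-7
// (maskHi). The 32-bit halves are then recombined by Make64bitResultForLoad.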
void DxilTranslateRawBuffer::ReplaceRawBufferLoad64Bit(Function *F, Type *EltTy,
                                                       Module &M) {
  OP *hlslOP = M.GetDxilModule().GetOP();
  Function *bufLd = hlslOP->GetOpFunc(DXIL::OpCode::RawBufferLoad,
                                      Type::getInt32Ty(M.getContext()));
  for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
    User *user = *(U++);
    if (CallInst *CI = dyn_cast<CallInst>(user)) {
      IRBuilder<> Builder(CI);
      SmallVector<Value *, 4> args(CI->arg_operands());
      Value *offset = CI->getArgOperand(
          DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx);
      unsigned size = 0;
      bool bNeedStatus = false;
      for (User *U : CI->users()) {
        ExtractValueInst *Elt = cast<ExtractValueInst>(U);
        DXASSERT(Elt->getNumIndices() == 1, "else invalid use for resRet");
        unsigned idx = Elt->getIndices()[0];
        if (idx == 4) {
          bNeedStatus = true;
        } else {
          size = std::max(size, idx + 1);
        }
      }
      unsigned maskHi = 0;
      unsigned maskLo = 0;
      switch (size) {
      case 1:
        maskLo = 3;
        break;
      case 2:
        maskLo = 0xf;
        break;
      case 3:
        maskLo = 0xf;
        maskHi = 3;
        break;
      case 4:
        maskLo = 0xf;
        maskHi = 0xf;
        break;
      }
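      // maskLo selects the 32-bit components fetched by the first load (two
      // slots per 64-bit element); a non-zero maskHi means a second load is
      // needed for the remaining halves, issued below once i reaches 2.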
      args[DXIL::OperandIndex::kRawBufferLoadMaskOpIdx] =
          Builder.getInt8(maskLo);
      Value *resultElts[5] = {nullptr, nullptr, nullptr, nullptr, nullptr};
      CallInst *newLd = Builder.CreateCall(bufLd, args);
      Value *resultElts32[8];
      unsigned eltBase = 0;
      for (unsigned i = 0; i < size; i++) {
        if (i == 2) {
          // Advance the address by 4 * 4 bytes for the second half of the load.
          if (isa<UndefValue>(offset)) {
            // [RW]ByteAddressBuffer has an undef element offset -> update the
            // index instead.
            Value *index = CI->getArgOperand(
                DXIL::OperandIndex::kRawBufferLoadIndexOpIdx);
            args[DXIL::OperandIndex::kRawBufferLoadIndexOpIdx] =
                Builder.CreateAdd(index, Builder.getInt32(4 * 4));
          } else {
            // [RW]StructuredBuffer -> update the element offset.
            args[DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx] =
                Builder.CreateAdd(offset, Builder.getInt32(4 * 4));
          }
          args[DXIL::OperandIndex::kRawBufferLoadMaskOpIdx] =
              Builder.getInt8(maskHi);
          newLd = Builder.CreateCall(bufLd, args);
          eltBase = 4;
        }
        unsigned resBase = 2 * i;
        resultElts32[resBase] =
            Builder.CreateExtractValue(newLd, resBase - eltBase);
        resultElts32[resBase + 1] =
            Builder.CreateExtractValue(newLd, resBase + 1 - eltBase);
      }
      Make64bitResultForLoad(EltTy, resultElts32, size, resultElts, hlslOP,
                             Builder);
      if (bNeedStatus) {
        resultElts[4] = Builder.CreateExtractValue(newLd, 4);
      }
      for (auto it = CI->user_begin(); it != CI->user_end();) {
        ExtractValueInst *Elt = cast<ExtractValueInst>(*(it++));
        DXASSERT(Elt->getNumIndices() == 1, "else invalid use for resRet");
        unsigned idx = Elt->getIndices()[0];
        if (!Elt->user_empty()) {
          Value *newElt = resultElts[idx];
          Elt->replaceAllUsesWith(newElt);
        }
        Elt->eraseFromParent();
      }
      CI->eraseFromParent();
    } else {
      DXASSERT(false, "function can only be used with call instructions.");
    }
  }
}
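
// Rewrites dx.op.rawBufferStore calls as dx.op.bufferStore for targets below
// DXIL 1.2. Operands 1-8 (handle, coordinates, the four values, and the write
// mask) carry over unchanged; the trailing alignment operand of the raw form
// is dropped.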
void DxilTranslateRawBuffer::ReplaceRawBufferStore(Function *F, Module &M) {
  OP *op = M.GetDxilModule().GetOP();
  DXASSERT(F->getReturnType()->isVoidTy(),
           "rawBufferStore should return a void type.");
  Type *ETy = F->getFunctionType()->getParamType(4); // value
  Function *newFunction = op->GetOpFunc(hlsl::DXIL::OpCode::BufferStore, ETy);
  for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
    User *user = *(U++);
    if (CallInst *CI = dyn_cast<CallInst>(user)) {
      IRBuilder<> Builder(CI);
      SmallVector<Value *, 4> args;
      args.emplace_back(op->GetI32Const((unsigned)DXIL::OpCode::BufferStore));
      for (unsigned i = 1; i < 9; ++i) {
        args.emplace_back(CI->getArgOperand(i));
      }
      Builder.CreateCall(newFunction, args);
      CI->eraseFromParent();
    } else {
      DXASSERT(false, "function can only be used with call instructions.");
    }
  }
}
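
// Lowers a 64-bit (double/int64) rawBufferStore into one or two i32-overloaded
// rawBufferStore calls: Split64bitValForStore breaks each 64-bit value into a
// low/high pair of i32s, the first call writes up to four of those halves, and
// a second call 16 bytes further along writes the rest when more than two
// 64-bit components are stored.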
void DxilTranslateRawBuffer::ReplaceRawBufferStore64Bit(Function *F, Type *ETy,
                                                        Module &M) {
  OP *hlslOP = M.GetDxilModule().GetOP();
  Function *newFunction = hlslOP->GetOpFunc(hlsl::DXIL::OpCode::RawBufferStore,
                                            Type::getInt32Ty(M.getContext()));
  for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
    User *user = *(U++);
    if (CallInst *CI = dyn_cast<CallInst>(user)) {
      IRBuilder<> Builder(CI);
      SmallVector<Value *, 4> args(CI->arg_operands());
      Value *vals[4] = {
          CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal0OpIdx),
          CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal1OpIdx),
          CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal2OpIdx),
          CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal3OpIdx)};
      ConstantInt *cMask = cast<ConstantInt>(
          CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreMaskOpIdx));
      Value *undefI32 = UndefValue::get(Builder.getInt32Ty());
      Value *vals32[8] = {undefI32, undefI32, undefI32, undefI32,
                          undefI32, undefI32, undefI32, undefI32};
      unsigned maskLo = 0;
      unsigned maskHi = 0;
      unsigned size = 0;
      unsigned mask = cMask->getLimitedValue();
      switch (mask) {
      case 1:
        maskLo = 3;
        size = 1;
        break;
      case 3:
        maskLo = 15;
        size = 2;
        break;
      case 7:
        maskLo = 15;
        maskHi = 3;
        size = 3;
        break;
      case 15:
        maskLo = 15;
        maskHi = 15;
        size = 4;
        break;
      default:
        DXASSERT(0, "invalid mask");
      }
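      // The incoming mask counts 64-bit components; maskLo/maskHi are the
      // equivalent 32-bit masks (e.g. a two-component store, mask 0x3, becomes
      // a single i32 store with mask 0xf; a four-component store needs two
      // stores with mask 0xf each).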
      Split64bitValForStore(ETy, vals, size, vals32, hlslOP, Builder);
      args[DXIL::OperandIndex::kRawBufferStoreMaskOpIdx] =
          Builder.getInt8(maskLo);
      args[DXIL::OperandIndex::kRawBufferStoreVal0OpIdx] = vals32[0];
      args[DXIL::OperandIndex::kRawBufferStoreVal1OpIdx] = vals32[1];
      args[DXIL::OperandIndex::kRawBufferStoreVal2OpIdx] = vals32[2];
      args[DXIL::OperandIndex::kRawBufferStoreVal3OpIdx] = vals32[3];
      Builder.CreateCall(newFunction, args);
      if (maskHi) {
        // Advance the address by 4 * 4 bytes for the second store.
        Value *offset = args[DXIL::OperandIndex::kBufferStoreCoord1OpIdx];
        if (isa<UndefValue>(offset)) {
          // [RW]ByteAddressBuffer has an undef element offset -> update the
          // index instead.
          Value *index = args[DXIL::OperandIndex::kBufferStoreCoord0OpIdx];
          index = Builder.CreateAdd(index, Builder.getInt32(4 * 4));
          args[DXIL::OperandIndex::kRawBufferStoreIndexOpIdx] = index;
        } else {
          // [RW]StructuredBuffer -> update the element offset.
          offset = Builder.CreateAdd(offset, Builder.getInt32(4 * 4));
          args[DXIL::OperandIndex::kRawBufferStoreElementOffsetOpIdx] = offset;
        }
        args[DXIL::OperandIndex::kRawBufferStoreMaskOpIdx] =
            Builder.getInt8(maskHi);
        args[DXIL::OperandIndex::kRawBufferStoreVal0OpIdx] = vals32[4];
        args[DXIL::OperandIndex::kRawBufferStoreVal1OpIdx] = vals32[5];
        args[DXIL::OperandIndex::kRawBufferStoreVal2OpIdx] = vals32[6];
        args[DXIL::OperandIndex::kRawBufferStoreVal3OpIdx] = vals32[7];
        Builder.CreateCall(newFunction, args);
      }
      CI->eraseFromParent();
    } else {
      DXASSERT(false, "function can only be used with call instructions.");
    }
  }
}
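
// Under min precision, half/i16 values still occupy a full 32 bits of buffer
// storage, so a rawBufferLoad overloaded on the 16-bit type is rewritten to
// the corresponding 32-bit overload (half -> float, i16 -> i32) and the
// extracted results are truncated back; see
// ReplaceMinPrecisionRawBufferLoadByType below.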
void DxilTranslateRawBuffer::ReplaceMinPrecisionRawBufferLoad(Function *F,
                                                              Module &M) {
  OP *Op = M.GetDxilModule().GetOP();
  Type *RetTy = F->getReturnType();
  if (StructType *STy = dyn_cast<StructType>(RetTy)) {
    Type *EltTy = STy->getElementType(0);
    if (EltTy->isHalfTy()) {
      ReplaceMinPrecisionRawBufferLoadByType(F, Type::getHalfTy(M.getContext()),
                                             Type::getFloatTy(M.getContext()),
                                             Op, M.getDataLayout());
    } else if (EltTy == Type::getInt16Ty(M.getContext())) {
      ReplaceMinPrecisionRawBufferLoadByType(
          F, Type::getInt16Ty(M.getContext()), Type::getInt32Ty(M.getContext()),
          Op, M.getDataLayout());
    }
  } else {
    DXASSERT(false, "RawBufferLoad should return struct type.");
  }
}
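
// Store-side counterpart: values stored as half are fpext'ed to float, and
// i16 values are extended to i32. For the integer case, the struct type
// annotation of the UAV is consulted to decide between zero- and
// sign-extension, since the IR alone does not record signedness.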
void DxilTranslateRawBuffer::ReplaceMinPrecisionRawBufferStore(Function *F,
                                                               Module &M) {
  DXASSERT(F->getReturnType()->isVoidTy(),
           "rawBufferStore should return a void type.");
  Type *ETy = F->getFunctionType()->getParamType(4); // value
  Type *NewETy;
  if (ETy->isHalfTy()) {
    NewETy = Type::getFloatTy(M.getContext());
  } else if (ETy == Type::getInt16Ty(M.getContext())) {
    NewETy = Type::getInt32Ty(M.getContext());
  } else {
    return; // not a min-precision type
  }
  Function *newFunction = M.GetDxilModule().GetOP()->GetOpFunc(
      DXIL::OpCode::RawBufferStore, NewETy);
  // For each call site:
  //   - replace arguments 4-7 with their upconverted values
  //   - replace the function call
  for (auto FuncUser = F->user_begin(), FuncEnd = F->user_end();
       FuncUser != FuncEnd;) {
    CallInst *CI = dyn_cast<CallInst>(*(FuncUser++));
    DXASSERT(CI, "function user must be a call instruction.");
    IRBuilder<> CIBuilder(CI);
    SmallVector<Value *, 9> Args;
    for (unsigned i = 0; i < 4; ++i) {
      Args.emplace_back(CI->getArgOperand(i));
    }
    // The values to store must be converted to their higher-precision types.
    if (ETy->isHalfTy()) {
      for (unsigned i = 4; i < 8; ++i) {
        Value *NewV = CIBuilder.CreateFPExt(CI->getArgOperand(i),
                                            Type::getFloatTy(M.getContext()));
        Args.emplace_back(NewV);
      }
    } else if (ETy == Type::getInt16Ty(M.getContext())) {
      // This case only applies to typed buffers, since the Store operation on
      // a byte-address buffer for min precision is handled by an implicit
      // conversion on the intrinsic call. Because we are extending an integer,
      // we have to know whether to sign-extend or zero-extend. We determine
      // this by walking the struct type and checking the component type in the
      // type annotation for the element at the store offset.
      CallInst *handleCI = dyn_cast<CallInst>(CI->getArgOperand(1));
      DXASSERT(handleCI, "otherwise handle was not an argument to buffer store.");
      ConstantInt *resClass = dyn_cast<ConstantInt>(handleCI->getArgOperand(1));
      DXASSERT_LOCALVAR(resClass,
                        resClass && resClass->getSExtValue() ==
                                        (unsigned)DXIL::ResourceClass::UAV,
                        "otherwise buffer store called on non uav kind.");
      ConstantInt *rangeID =
          dyn_cast<ConstantInt>(handleCI->getArgOperand(2)); // range id or idx?
      DXASSERT(rangeID, "wrong createHandle call.");
      DxilResource dxilRes = M.GetDxilModule().GetUAV(rangeID->getSExtValue());
      StructType *STy = dyn_cast<StructType>(dxilRes.GetRetType());
      DxilStructAnnotation *SAnnot =
          M.GetDxilModule().GetTypeSystem().GetStructAnnotation(STy);
      ConstantInt *offsetInt = dyn_cast<ConstantInt>(CI->getArgOperand(3));
      unsigned offset = offsetInt->getSExtValue();
      unsigned currentOffset = 0;
      for (DxilStructTypeIterator iter = begin(STy, SAnnot),
                                  ItEnd = end(STy, SAnnot);
           iter != ItEnd; ++iter) {
        std::pair<Type *, DxilFieldAnnotation *> pair = *iter;
        currentOffset += M.getDataLayout().getTypeAllocSize(pair.first);
        if (currentOffset > offset) {
          if (pair.second->GetCompType().IsUIntTy()) {
            for (unsigned i = 4; i < 8; ++i) {
              Value *NewV = CIBuilder.CreateZExt(
                  CI->getArgOperand(i), Type::getInt32Ty(M.getContext()));
              Args.emplace_back(NewV);
            }
            break;
          } else if (pair.second->GetCompType().IsIntTy()) {
            for (unsigned i = 4; i < 8; ++i) {
              Value *NewV = CIBuilder.CreateSExt(
                  CI->getArgOperand(i), Type::getInt32Ty(M.getContext()));
              Args.emplace_back(NewV);
            }
            break;
          } else {
            DXASSERT(false, "Invalid comp type");
          }
        }
      }
    }
    // mask
    Args.emplace_back(CI->getArgOperand(8));
    // alignment
    Args.emplace_back(M.GetDxilModule().GetOP()->GetI32Const(
        M.getDataLayout().getTypeAllocSize(NewETy)));
    CIBuilder.CreateCall(newFunction, Args);
    CI->eraseFromParent();
  }
}
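
// Shared helper for the min-precision load rewrite: re-emits the call with the
// ToTy overload and a matching alignment, then truncates every extracted data
// element back to FromTy, leaving the status element (index 4) untouched.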
void DxilTranslateRawBuffer::ReplaceMinPrecisionRawBufferLoadByType(
    Function *F, Type *FromTy, Type *ToTy, OP *Op, const DataLayout &DL) {
  Function *newFunction = Op->GetOpFunc(DXIL::OpCode::RawBufferLoad, ToTy);
  for (auto FUser = F->user_begin(), FEnd = F->user_end(); FUser != FEnd;) {
    User *UserCI = *(FUser++);
    if (CallInst *CI = dyn_cast<CallInst>(UserCI)) {
      IRBuilder<> CIBuilder(CI);
      SmallVector<Value *, 5> newFuncArgs;
      // opcode, handle, index, elementOffset, mask
      // The compiler already generates the correct element offset even for
      // min-precision types, so there is no need to recalculate it here.
      for (unsigned i = 0; i < 5; ++i) {
        newFuncArgs.emplace_back(CI->getArgOperand(i));
      }
      // New alignment for the new type.
      newFuncArgs.emplace_back(Op->GetI32Const(DL.getTypeAllocSize(ToTy)));
      CallInst *newCI = CIBuilder.CreateCall(newFunction, newFuncArgs);
      for (auto CIUser = CI->user_begin(), CIEnd = CI->user_end();
           CIUser != CIEnd;) {
        User *UserEV = *(CIUser++);
        if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(UserEV)) {
          IRBuilder<> EVBuilder(EV);
          ArrayRef<unsigned> Indices = EV->getIndices();
          DXASSERT(Indices.size() == 1, "Otherwise we have wrong extract value.");
          Value *newEV = EVBuilder.CreateExtractValue(newCI, Indices);
          Value *newTruncV = nullptr;
          if (4 == Indices[0]) {
            // Don't truncate the status value.
            newTruncV = newEV;
          } else if (FromTy->isHalfTy()) {
            newTruncV = EVBuilder.CreateFPTrunc(newEV, FromTy);
          } else if (FromTy->isIntegerTy()) {
            newTruncV = EVBuilder.CreateTrunc(newEV, FromTy);
          } else {
            DXASSERT(false, "unexpected type conversion");
          }
          EV->replaceAllUsesWith(newTruncV);
          EV->eraseFromParent();
        }
      }
      CI->eraseFromParent();
    }
  }
  F->eraseFromParent();
}

char DxilTranslateRawBuffer::ID = 0;

ModulePass *llvm::createDxilTranslateRawBuffer() {
  return new DxilTranslateRawBuffer();
}
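
// A minimal usage sketch (assuming the usual legacy pass-manager setup; not
// part of this file's pass ordering):
//   legacy::PassManager PM;
//   PM.add(createDxilTranslateRawBuffer());
//   PM.run(M);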
INITIALIZE_PASS(DxilTranslateRawBuffer, "hlsl-translate-dxil-raw-buffer",
                "Translate raw buffer load", false, false)