DxilPatchShaderRecordBindings.cpp 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPatchShaderRecordBindings.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides a pass used by the RayTracing Fallback Layer to modify //
  9. // bindings so that local root signature parameters are pulled from a global //
  10. // "shader table" buffer instead //
  11. // //
  12. ///////////////////////////////////////////////////////////////////////////////
  13. #include "dxc/HLSL/DxilGenerationPass.h"
  14. #include "dxc/HLSL/DxilFallbackLayerPass.h"
  15. #include "dxc/HLSL/DxilOperations.h"
  16. #include "dxc/HLSL/DxilSignatureElement.h"
  17. #include "dxc/HLSL/DxilFunctionProps.h"
  18. #include "dxc/HLSL/DxilModule.h"
  19. #include "dxc/Support/Global.h"
  20. #include "dxc/Support/Unicode.h"
  21. #include "dxc/HLSL/DxilTypeSystem.h"
  22. #include "dxc/HLSL/DxilConstants.h"
  23. #include "dxc/HLSL/DxilInstructions.h"
  24. #include "dxc/HLSL/DxilSpanAllocator.h"
  25. #include "dxc/HLSL/DxilRootSignature.h"
  26. #include "dxc/HLSL/DxilUtil.h"
  27. #include "llvm/Transforms/Utils/Cloning.h"
  28. #include "llvm/IR/Instructions.h"
  29. #include "llvm/IR/IntrinsicInst.h"
  30. #include "llvm/IR/InstIterator.h"
  31. #include "llvm/IR/Module.h"
  32. #include "llvm/IR/PassManager.h"
  33. #include "llvm/ADT/BitVector.h"
  34. #include "llvm/Pass.h"
  35. #include "llvm/Transforms/Utils/Local.h"
  36. #include "llvm/Transforms/Scalar.h"
  37. #include <memory>
  38. #include <unordered_set>
  39. #include <functional>
  40. #include <unordered_map>
  41. #include <array>
  42. struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
  43. #include "DxilPatchShaderRecordBindingsShared.h"
  44. using namespace llvm;
  45. using namespace hlsl;
  46. bool operator==(const ViewKey &a, const ViewKey &b) {
  47. return memcmp(&a, &b, sizeof(a)) == 0;
  48. }
  49. const size_t SizeofD3D12GpuVA = sizeof(uint64_t);
  50. const size_t SizeofD3D12GpuDescriptorHandle = sizeof(uint64_t);
  51. Function *CloneFunction(Function *Orig,
  52. const llvm::Twine &Name,
  53. llvm::Module *llvmModule) {
  54. Function *F = Function::Create(Orig->getFunctionType(),
  55. GlobalValue::LinkageTypes::ExternalLinkage,
  56. Name, llvmModule);
  57. SmallVector<ReturnInst *, 2> Returns;
  58. ValueToValueMapTy vmap;
  59. // Map params.
  60. auto entryParamIt = F->arg_begin();
  61. for (Argument &param : Orig->args()) {
  62. vmap[&param] = (entryParamIt++);
  63. }
  64. DxilModule &DM = llvmModule->GetOrCreateDxilModule();
  65. llvm::CloneFunctionInto(F, Orig, vmap, /*ModuleLevelChagnes*/ false, Returns);
  66. DM.GetTypeSystem().CopyFunctionAnnotation(F, Orig, DM.GetTypeSystem());
  67. if (DM.HasDxilFunctionProps(F)) {
  68. DM.CloneDxilEntryProps(Orig, F);
  69. }
  70. return F;
  71. }
  72. struct ShaderRecordEntry {
  73. DxilRootParameterType ParameterType;
  74. unsigned int RecordOffsetInBytes;
  75. unsigned int OffsetInDescriptors; // Only valid for descriptor tables
  76. static ShaderRecordEntry InvalidEntry() { return { (DxilRootParameterType)-1, (unsigned int)-1 }; }
  77. bool IsInvalid() { return (unsigned int)ParameterType == (unsigned int)-1; }
  78. };
  79. struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
  80. class DxilPatchShaderRecordBindings : public ModulePass {
  81. public:
  82. static char ID; // Pass identification, replacement for typeid
  83. explicit DxilPatchShaderRecordBindings() : ModulePass(ID) {}
  84. const char *getPassName() const override { return "DXIL Patch Shader Record Binding"; }
  85. void applyOptions(PassOptions O) override;
  86. bool runOnModule(Module &M) override;
  87. private:
  88. void ValidateParameters();
  89. void AddInputBinding(Module &M);
  90. void PatchShaderBindings(Module &M);
  91. void InitializeViewTable();
  92. unsigned int AddSRVRawBuffer(Module &M, unsigned int registerIndex, unsigned int registerSpace, const std::string &bufferName);
  93. unsigned int AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type = nullptr, unsigned int constantBufferSize = 0);
  94. unsigned int AddAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type);
  95. unsigned int AddCBufferAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, const std::string &bufferName);
  96. llvm::Value *CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes, llvm::Value *CbufferOffsetInBytes);
  97. llvm::Value *CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes, llvm::Type* type);
  98. llvm::Value *CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction);
  99. llvm::Value *CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle, unsigned int RowToLoad = 0);
  100. llvm::Value *LoadShaderRecordData(Module &M, IRBuilder<> &Builder,
  101. llvm::Value *offsetToShaderRecord,
  102. unsigned int dataOffsetInShaderRecord);
  103. void PatchCreateHandleToUseDescriptorIndex(
  104. _In_ Module &M,
  105. _In_ IRBuilder<> &Builder,
  106. _In_ DXIL::ResourceKind &resourceKind,
  107. _In_ DXIL::ResourceClass &resourceClass,
  108. _In_ llvm::Type *resourceType,
  109. _In_ llvm::Value *descriptorIndex,
  110. _Inout_ DxilInst_CreateHandleForLib &createHandleInstr);
  111. bool GetHandleInfo(
  112. Module &M,
  113. DxilInst_CreateHandleForLib &createHandleStructForLib,
  114. _Out_ unsigned int &shaderRegister,
  115. _Out_ unsigned int &registerSpace,
  116. _Out_ DXIL::ResourceKind &kind,
  117. _Out_ DXIL::ResourceClass &resClass,
  118. _Out_ llvm::Type *&resType);
  119. llvm::Value * GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind);
  120. unsigned int GetConstantBufferOffsetToShaderRecord();
  121. bool IsCBufferLoad(llvm::Instruction *instruction);
  122. // Unlike the LLVM version of this function, this does not requires the InstructionToReplace and the ValueToReplaceWith to be the same instruction type
  123. static void ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith);
  124. static ShaderRecordEntry FindRootSignatureDescriptor(const DxilVersionedRootSignatureDesc &rootSignatureDescriptor, unsigned int ShaderRecordIdentifierSizeInBytes, DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex, unsigned int registerSpace);
  125. // TODO: I would like to see these prefixed with m_
  126. llvm::Value *ShaderTableHandle = nullptr;
  127. llvm::Value *DispatchRaysConstantsHandle = nullptr;
  128. llvm::Value *BaseShaderRecordOffset = nullptr;
  129. static const unsigned int NumViewTypes = 4;
  130. struct ViewKeyHasher
  131. {
  132. public:
  133. std::size_t operator()(const ViewKey &x) const {
  134. return std::hash<unsigned int>()((unsigned int)x.ViewType) ^
  135. std::hash<unsigned int>()((unsigned int)x.StructuredStride);
  136. }
  137. };
  138. std::unordered_map<ViewKey, llvm::Value *, ViewKeyHasher>
  139. TypeToAliasedDescriptorHeap[NumViewTypes];
  140. llvm::Function *EntryPointFunction;
  141. ShaderInfo *pInputShaderInfo;
  142. DxilVersionedRootSignatureDesc *pRootSignatureDesc;
  143. DXIL::ShaderKind ShaderKind;
  144. };
  145. char DxilPatchShaderRecordBindings::ID = 0;
  146. // TODO: Find the right thing to do on failure
  147. void ThrowFailure() {
  148. throw std::exception();
  149. }
  150. // TODO: Stolen from Brandon's code, merge
  151. // Remove ELF mangling
  152. static inline std::string GetUnmangledName(StringRef name) {
  153. if (!name.startswith("\x1?"))
  154. return name;
  155. size_t pos = name.find("@@");
  156. if (pos == name.npos)
  157. return name;
  158. return name.substr(2, pos - 2);
  159. }
  160. static Function* getFunctionFromName(Module &M, const std::wstring& exportName) {
  161. for (auto F = M.begin(), E = M.end(); F != E; ++F) {
  162. std::wstring functionName = Unicode::UTF8ToUTF16StringOrThrow(GetUnmangledName(F->getName()).c_str());
  163. if (exportName == functionName) {
  164. return F;
  165. }
  166. }
  167. return nullptr;
  168. }
  169. ModulePass *llvm::createDxilPatchShaderRecordBindingsPass() {
  170. return new DxilPatchShaderRecordBindings();
  171. }
  172. INITIALIZE_PASS(DxilPatchShaderRecordBindings, "hlsl-dxil-patch-shader-record-bindings", "Patch shader record bindings to instead pull from the fallback provided bindings", false, false)
  173. void DxilPatchShaderRecordBindings::applyOptions(PassOptions O) {
  174. for (const auto & option : O) {
  175. if (0 == option.first.compare("root-signature")) {
  176. unsigned int cHexRadix = 16;
  177. pInputShaderInfo = (ShaderInfo*)strtoull(option.second.data(), nullptr, cHexRadix);
  178. pRootSignatureDesc = (DxilVersionedRootSignatureDesc*)pInputShaderInfo->pRootSignatureDesc;
  179. }
  180. }
  181. }
  182. void AddAnnoationsIfNeeded(DxilModule &DM, llvm::StructType *StructTy, const std::string &FieldName, unsigned int numFields = 1)
  183. {
  184. auto pAnnotation = DM.GetTypeSystem().GetStructAnnotation(StructTy);
  185. if (pAnnotation == nullptr)
  186. {
  187. pAnnotation = DM.GetTypeSystem().AddStructAnnotation(StructTy);
  188. pAnnotation->SetCBufferSize(sizeof(uint32_t) * numFields);
  189. for (unsigned int i = 0; i < numFields; i++)
  190. {
  191. pAnnotation->GetFieldAnnotation(i).SetCBufferOffset(sizeof(uint32_t) * i);
  192. pAnnotation->GetFieldAnnotation(i).SetCompType(hlsl::DXIL::ComponentType::I32);
  193. pAnnotation->GetFieldAnnotation(i).SetFieldName(FieldName + std::to_string(i));
  194. }
  195. }
  196. }
  197. unsigned int DxilPatchShaderRecordBindings::AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type, unsigned int constantBufferSize) {
  198. LLVMContext & Ctx = M.getContext();
  199. DxilModule &DM = M.GetOrCreateDxilModule();
  200. // Set up a SRV with byte address buffer
  201. unsigned int resourceHandle;
  202. std::unique_ptr<DxilResource> pHandle;
  203. std::unique_ptr<DxilCBuffer> pCBuf;
  204. std::unique_ptr<DxilSampler> pSampler;
  205. DxilResourceBase *pBaseHandle;
  206. switch (resClass) {
  207. case DXIL::ResourceClass::SRV:
  208. resourceHandle = static_cast<unsigned int>(DM.GetSRVs().size());
  209. pHandle = llvm::make_unique<DxilResource>();
  210. pHandle->SetRW(false);
  211. pBaseHandle = pHandle.get();
  212. break;
  213. case DXIL::ResourceClass::UAV:
  214. resourceHandle = static_cast<unsigned int>(DM.GetUAVs().size());
  215. pHandle = llvm::make_unique<DxilResource>();
  216. pHandle->SetRW(true);
  217. pBaseHandle = pHandle.get();
  218. break;
  219. case DXIL::ResourceClass::CBuffer:
  220. resourceHandle = static_cast<unsigned int>(DM.GetCBuffers().size());
  221. pCBuf = llvm::make_unique<DxilCBuffer>();
  222. pCBuf->SetSize(constantBufferSize);
  223. pBaseHandle = pCBuf.get();
  224. break;
  225. case DXIL::ResourceClass::Sampler:
  226. resourceHandle = static_cast<unsigned int>(DM.GetSamplers().size());
  227. pSampler = llvm::make_unique<DxilSampler>();
  228. // TODO: Is this okay? What if one of the samplers in the table is a comparison sampler?
  229. pSampler->SetSamplerKind(DxilSampler::SamplerKind::Default);
  230. pBaseHandle = pSampler.get();
  231. break;
  232. }
  233. if (!type) {
  234. SmallVector<llvm::Type*, 1> Elements{ Type::getInt32Ty(Ctx) };
  235. std::string ByteAddressBufferName = "struct.ByteAddressBuffer";
  236. type = M.getTypeByName(ByteAddressBufferName);
  237. if (!type)
  238. {
  239. StructType *StructTy;
  240. type = StructTy = StructType::create(Elements, ByteAddressBufferName);
  241. AddAnnoationsIfNeeded(DM, StructTy, ByteAddressBufferName);
  242. }
  243. }
  244. GlobalVariable *GV = M.getGlobalVariable(bufferName);
  245. if (!GV) {
  246. GV = cast<GlobalVariable>(M.getOrInsertGlobal(bufferName, type));
  247. }
  248. pBaseHandle->SetGlobalName(bufferName.c_str());
  249. pBaseHandle->SetGlobalSymbol(GV);
  250. pBaseHandle->SetID(resourceHandle);
  251. pBaseHandle->SetSpaceID(registerSpace);
  252. pBaseHandle->SetLowerBound(baseRegisterIndex);
  253. pBaseHandle->SetRangeSize(rangeSize);
  254. pBaseHandle->SetKind(resKind);
  255. if (pHandle) {
  256. pHandle->SetGloballyCoherent(false);
  257. pHandle->SetHasCounter(false);
  258. pHandle->SetCompType(CompType::getF32()); // TODO: Need to handle all types
  259. }
  260. unsigned int ID;
  261. switch (resClass) {
  262. case DXIL::ResourceClass::SRV:
  263. ID = DM.AddSRV(std::move(pHandle));
  264. break;
  265. case DXIL::ResourceClass::UAV:
  266. ID = DM.AddUAV(std::move(pHandle));
  267. break;
  268. case DXIL::ResourceClass::CBuffer:
  269. ID = DM.AddCBuffer(std::move(pCBuf));
  270. break;
  271. case DXIL::ResourceClass::Sampler:
  272. ID = DM.AddSampler(std::move(pSampler));
  273. break;
  274. }
  275. assert(ID == resourceHandle);
  276. return ID;
  277. }
  278. unsigned int DxilPatchShaderRecordBindings::GetConstantBufferOffsetToShaderRecord()
  279. {
  280. switch (ShaderKind)
  281. {
  282. case DXIL::ShaderKind::ClosestHit:
  283. case DXIL::ShaderKind::AnyHit:
  284. case DXIL::ShaderKind::Intersection:
  285. return offsetof(DispatchRaysConstants, HitGroupShaderRecordStride);
  286. case DXIL::ShaderKind::Miss:
  287. return offsetof(DispatchRaysConstants, MissShaderRecordStride);
  288. default:
  289. ThrowFailure();
  290. return -1;
  291. }
  292. }
  293. unsigned int DxilPatchShaderRecordBindings::AddSRVRawBuffer(Module &M, unsigned int registerIndex, unsigned int registerSpace, const std::string &bufferName) {
  294. return AddHandle(M, registerIndex, 1, registerSpace, DXIL::ResourceClass::SRV, DXIL::ResourceKind::RawBuffer, bufferName);
  295. }
  296. llvm::Constant *GetArraySymbol(Module &M, const std::string &bufferName) {
  297. LLVMContext & Ctx = M.getContext();
  298. SmallVector<llvm::Type*, 1> Elements{ Type::getInt32Ty(Ctx) };
  299. llvm::StructType *StructTy = llvm::StructType::create(Elements, bufferName);
  300. llvm::ArrayType *ArrayTy = ArrayType::get(StructTy, -1);
  301. return UndefValue::get(ArrayTy->getPointerTo());
  302. }
  303. unsigned int DxilPatchShaderRecordBindings::AddCBufferAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, const std::string &bufferName) {
  304. const unsigned int maxConstantBufferSize = 4096 * 16;
  305. return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace, DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer, bufferName, GetArraySymbol(M, bufferName)->getType(), maxConstantBufferSize);
  306. }
  307. unsigned int DxilPatchShaderRecordBindings::AddAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type) {
  308. return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace, resClass, resKind, bufferName, type);
  309. }
  310. // TODO: Stolen from Brandon's code
  311. DXIL::ShaderKind GetRayShaderKindCopy(Function* F)
  312. {
  313. if (F->hasFnAttribute("exp-shader"))
  314. return DXIL::ShaderKind::RayGeneration;
  315. DxilModule& DM = F->getParent()->GetDxilModule();
  316. if (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay())
  317. return DM.GetDxilFunctionProps(F).shaderKind;
  318. return DXIL::ShaderKind::Invalid;
  319. }
  320. static std::string ws2s(const std::wstring& wide)
  321. {
  322. return std::string(wide.begin(), wide.end());
  323. }
  324. bool DxilPatchShaderRecordBindings::runOnModule(Module &M) {
  325. DxilModule &DM = M.GetOrCreateDxilModule();
  326. EntryPointFunction = pInputShaderInfo->ExportName ? getFunctionFromName(M, pInputShaderInfo->ExportName) : DM.GetEntryFunction();
  327. ShaderKind = GetRayShaderKindCopy(EntryPointFunction);
  328. ValidateParameters();
  329. InitializeViewTable();
  330. PatchShaderBindings(M);
  331. DM.ReEmitDxilResources();
  332. return true;
  333. }
  334. void DxilPatchShaderRecordBindings::ValidateParameters() {
  335. if (!pInputShaderInfo || !pInputShaderInfo->pRootSignatureDesc) {
  336. throw std::exception();
  337. }
  338. }
  339. DxilResourceBase &GetResourceFromID(DxilModule &DM, DXIL::ResourceClass resClass, unsigned int id)
  340. {
  341. switch (resClass)
  342. {
  343. case DXIL::ResourceClass::CBuffer:
  344. return DM.GetCBuffer(id);
  345. break;
  346. case DXIL::ResourceClass::SRV:
  347. return DM.GetSRV(id);
  348. break;
  349. case DXIL::ResourceClass::UAV:
  350. return DM.GetUAV(id);
  351. break;
  352. case DXIL::ResourceClass::Sampler:
  353. return DM.GetSampler(id);
  354. break;
  355. default:
  356. ThrowFailure();
  357. return *(DxilResourceBase*)nullptr;
  358. }
  359. }
  360. unsigned int FindOrInsertViewIntoList(const ViewKey &key, ViewKey *pViewList, unsigned int &numViews, unsigned int maxViews)
  361. {
  362. unsigned int viewIndex = 0;
  363. for (; viewIndex < numViews; viewIndex++)
  364. {
  365. if (pViewList[viewIndex] == key)
  366. {
  367. break;
  368. }
  369. }
  370. if (viewIndex == numViews)
  371. {
  372. if (viewIndex >= maxViews) {
  373. ThrowFailure();
  374. }
  375. pViewList[viewIndex] = key;
  376. numViews++;
  377. }
  378. return viewIndex;
  379. }
  380. llvm::Value *DxilPatchShaderRecordBindings::GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *type, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind)
  381. {
  382. DxilModule &DM = M.GetOrCreateDxilModule();
  383. unsigned int resClassIndex = (unsigned int)resClass;
  384. ViewKey key = {};
  385. key.ViewType = (unsigned int)resKind;
  386. if (resKind == DXIL::ResourceKind::StructuredBuffer)
  387. {
  388. key.StructuredStride = type->getPrimitiveSizeInBits();
  389. } else if (resKind != DXIL::ResourceKind::RawBuffer)
  390. {
  391. auto containedType = type->getContainedType(0);
  392. // If it's a vector, get the type of just a single element
  393. if (containedType->getNumContainedTypes() > 0)
  394. {
  395. assert(containedType->getNumContainedTypes() <= 4);
  396. containedType = containedType->getContainedType(0);
  397. }
  398. key.SRVComponentType = (unsigned int)CompType::GetCompType(containedType).GetKind();
  399. }
  400. auto aliasedDescriptorHeapHandle = TypeToAliasedDescriptorHeap[resClassIndex].find(key);
  401. if (aliasedDescriptorHeapHandle == TypeToAliasedDescriptorHeap[resClassIndex].end())
  402. {
  403. unsigned int registerSpaceOffset = 0;
  404. std::string HandleName;
  405. if (resClass == DXIL::ResourceClass::SRV)
  406. {
  407. registerSpaceOffset = FindOrInsertViewIntoList(
  408. key,
  409. pInputShaderInfo->pSRVRegisterSpaceArray,
  410. *pInputShaderInfo->pNumSRVSpaces,
  411. FallbackLayerNumDescriptorHeapSpacesPerView);
  412. HandleName = std::string("SRVDescriptorHeapTable") +
  413. std::to_string(registerSpaceOffset);
  414. }
  415. else if (resClass == DXIL::ResourceClass::UAV)
  416. {
  417. registerSpaceOffset = FindOrInsertViewIntoList(
  418. key,
  419. pInputShaderInfo->pUAVRegisterSpaceArray,
  420. *pInputShaderInfo->pNumUAVSpaces,
  421. FallbackLayerNumDescriptorHeapSpacesPerView);
  422. if (registerSpaceOffset == 0)
  423. {
  424. // Using the descriptor heap declared by the fallback for handling emulated pointers,
  425. // make sure the name is an exact match
  426. assert(key.ViewType == (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer);
  427. HandleName = "\01?DescriptorHeapBufferTable@@3PAURWByteAddressBuffer@@A";
  428. }
  429. else
  430. {
  431. HandleName = std::string("UAVDescriptorHeapTable") +
  432. std::to_string(registerSpaceOffset);
  433. }
  434. }
  435. else if (resClass == DXIL::ResourceClass::CBuffer)
  436. {
  437. HandleName = std::string("CBVDescriptorHeapTable");
  438. } else {
  439. HandleName = std::string("SamplerDescriptorHeapTable");
  440. }
  441. llvm::ArrayType *descriptorHeapType = ArrayType::get(type, 0);
  442. static unsigned int i = 0;
  443. unsigned int id = AddAliasedHandle(M, FallbackLayerDescriptorHeapTable, FallbackLayerRegisterSpace + FallbackLayerDescriptorHeapSpaceOffset + registerSpaceOffset, resClass, resKind, HandleName, descriptorHeapType);
  444. TypeToAliasedDescriptorHeap[resClassIndex][key] = GetResourceFromID(DM, resClass, id).GetGlobalSymbol();
  445. }
  446. return TypeToAliasedDescriptorHeap[resClassIndex][key];
  447. }
  448. void DxilPatchShaderRecordBindings::AddInputBinding(Module &M) {
  449. DxilModule &DM = M.GetOrCreateDxilModule();
  450. auto & EntryBlock = EntryPointFunction->getEntryBlock();
  451. auto & Instructions = EntryBlock.getInstList();
  452. std::string bufferName;
  453. unsigned int bufferRegister;
  454. switch (ShaderKind) {
  455. case DXIL::ShaderKind::AnyHit:
  456. case DXIL::ShaderKind::ClosestHit:
  457. case DXIL::ShaderKind::Intersection:
  458. bufferRegister = FallbackLayerHitGroupRecordByteAddressBufferRegister;
  459. bufferName = "\01?HitGroupShaderTable@@3UByteAddressBuffer@@A";
  460. break;
  461. case DXIL::ShaderKind::Miss:
  462. bufferRegister = FallbackLayerMissShaderRecordByteAddressBufferRegister;
  463. bufferName = "\01?MissShaderTable@@3UByteAddressBuffer@@A";
  464. break;
  465. case DXIL::ShaderKind::RayGeneration:
  466. bufferRegister = FallbackLayerRayGenShaderRecordByteAddressBufferRegister;
  467. bufferName = "\01?RayGenShaderTable@@3UByteAddressBuffer@@A";
  468. break;
  469. case DXIL::ShaderKind::Callable:
  470. bufferRegister = FallbackLayerCallableShaderRecordByteAddressBufferRegister;
  471. bufferName = "\01?CallableShaderTable@@3UByteAddressBuffer@@A";
  472. break;
  473. }
  474. unsigned int ShaderRecordID = AddSRVRawBuffer(M, bufferRegister, FallbackLayerRegisterSpace, bufferName);
  475. auto It = Instructions.begin();
  476. OP *HlslOP = DM.GetOP();
  477. LLVMContext & Ctx = M.getContext();
  478. IRBuilder<> Builder(It);
  479. {
  480. auto ShaderTableName = "ShaderTableHandle";
  481. llvm::Value *Symbol = DM.GetSRV(ShaderRecordID).GetGlobalSymbol();
  482. llvm::Value *Load = Builder.CreateLoad(Symbol, "LoadShaderTableHandle");
  483. Function *CreateHandleForLib = HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
  484. Constant *CreateHandleOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
  485. ShaderTableHandle = Builder.CreateCall(CreateHandleForLib, { CreateHandleOpcodeArg, Load }, ShaderTableName);
  486. }
  487. {
  488. auto CbufferName = "Constants";
  489. const unsigned int sizeOfConstantsInBytes = sizeof(DispatchRaysConstants);
  490. llvm::StructType *StructTy= M.getTypeByName(CbufferName);
  491. if (!StructTy)
  492. {
  493. const unsigned int numUintsInConstants = sizeOfConstantsInBytes / sizeof(unsigned int);
  494. SmallVector<llvm::Type*, numUintsInConstants> Elements(numUintsInConstants);
  495. for (unsigned int i = 0; i < numUintsInConstants; i++)
  496. {
  497. Elements[i] = Type::getInt32Ty(Ctx);
  498. }
  499. StructTy = llvm::StructType::create(Elements, CbufferName);
  500. AddAnnoationsIfNeeded(DM, StructTy, std::string(CbufferName), numUintsInConstants);
  501. }
  502. unsigned int handle = AddHandle(M, FallbackLayerDispatchConstantsRegister, 1, FallbackLayerRegisterSpace, DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer, CbufferName, StructTy, sizeOfConstantsInBytes);
  503. llvm::Value *Symbol = DM.GetCBuffer(handle).GetGlobalSymbol();
  504. llvm::Value *Load = Builder.CreateLoad(Symbol, "DispatchRaysConstants");
  505. Function *CreateHandleForLib = HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
  506. Constant *CreateHandleOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
  507. DispatchRaysConstantsHandle = Builder.CreateCall(CreateHandleForLib, { CreateHandleOpcodeArg, Load }, CbufferName);
  508. }
  509. // Raygen always reads from the start so no offset calculations needed
  510. if (ShaderKind != DXIL::ShaderKind::RayGeneration)
  511. {
  512. std::string ShaderRecordOffsetFuncName = "\x1?Fallback_ShaderRecordOffset@@YAIXZ";
  513. Function *ShaderRecordOffsetFunc = M.getFunction(ShaderRecordOffsetFuncName);
  514. if (!ShaderRecordOffsetFunc)
  515. {
  516. FunctionType *ShaderRecordOffsetFuncType = FunctionType::get(llvm::Type::getInt32Ty(Ctx), {}, false);
  517. ShaderRecordOffsetFunc = Function::Create(ShaderRecordOffsetFuncType, GlobalValue::LinkageTypes::ExternalLinkage, ShaderRecordOffsetFuncName, &M);
  518. }
  519. BaseShaderRecordOffset = Builder.CreateCall(ShaderRecordOffsetFunc, {}, "shaderRecordOffset");
  520. }
  521. else
  522. {
  523. BaseShaderRecordOffset = HlslOP->GetU32Const(0);
  524. }
  525. }
  526. llvm::Value *DxilPatchShaderRecordBindings::CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes, llvm::Value *CbufferOffsetInBytes) {
  527. DxilModule &DM = M.GetOrCreateDxilModule();
  528. OP *HlslOP = DM.GetOP();
  529. // Create handle for the newly-added constant buffer (which is achieved via a function call)
  530. auto AdddName = "ShaderRecordOffsetInBytes";
  531. Constant *ShaderRecordOffsetInBytes = HlslOP->GetU32Const(RecordOffsetInBytes); // Offset of constants in shader record buffer
  532. return Builder.CreateAdd(CbufferOffsetInBytes, ShaderRecordOffsetInBytes, AdddName);
  533. }
  534. llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle, unsigned int RowToLoad) {
  535. DxilModule &DM = M.GetOrCreateDxilModule();
  536. OP *HlslOP = DM.GetOP();
  537. LLVMContext & Ctx = M.getContext();
  538. auto BufferLoadName = "ConstantBuffer";
  539. Function *BufferLoad = HlslOP->GetOpFunc(DXIL::OpCode::CBufferLoadLegacy, Type::getInt32Ty(Ctx));
  540. Constant *CBufferLoadOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CBufferLoadLegacy);
  541. Constant *RowToLoadConst = HlslOP->GetU32Const(RowToLoad);
  542. return Builder.CreateCall(BufferLoad, { CBufferLoadOpcodeArg, ResourceHandle, RowToLoadConst }, BufferLoadName);
  543. }
  544. llvm::Value *DxilPatchShaderRecordBindings::CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes, llvm::Type* type) {
  545. DxilModule &DM = M.GetOrCreateDxilModule();
  546. OP *HlslOP = DM.GetOP();
  547. LLVMContext & Ctx = M.getContext();
  548. // Create handle for the newly-added constant buffer (which is achieved via a function call)
  549. auto BufferLoadName = "ShaderRecordBuffer";
  550. if (type->getNumContainedTypes() > 1)
  551. {
  552. // TODO: Buffer loads aren't legal with container types, check if this is the right wait to handle this
  553. type = type->getContainedType(0);
  554. }
  555. // TODO Do I need to check the result? Hopefully not
  556. Function *BufferLoad = HlslOP->GetOpFunc(DXIL::OpCode::BufferLoad, type);
  557. Constant *BufferLoadOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::BufferLoad);
  558. Constant *Unused = UndefValue::get(llvm::Type::getInt32Ty(Ctx));
  559. return Builder.CreateCall(BufferLoad, { BufferLoadOpcodeArg, ShaderTableHandle, ShaderRecordOffsetInBytes, Unused }, BufferLoadName);
  560. }
  561. void DxilPatchShaderRecordBindings::ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith) {
  562. for (auto UserIter = InstructionToReplace->user_begin(); UserIter != InstructionToReplace->user_end();) {
  563. // Increment the iterator before the replace since the replace alters the uses list
  564. auto userInstr = UserIter++;
  565. userInstr->replaceUsesOfWith(InstructionToReplace, ValueToReplaceWith);
  566. }
  567. InstructionToReplace->eraseFromParent();
  568. }
  569. llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction) {
  570. DxilModule &DM = M.GetOrCreateDxilModule();
  571. OP *HlslOP = DM.GetOP();
  572. DxilInst_CBufferLoad cbufferLoad(instruction);
  573. DxilInst_CBufferLoadLegacy cbufferLoadLegacy(instruction);
  574. if (cbufferLoad) {
  575. return cbufferLoad.get_byteOffset();
  576. } else if (cbufferLoadLegacy) {
  577. Constant *LegacyMultiplier = HlslOP->GetU32Const(16);
  578. return Builder.CreateMul(cbufferLoadLegacy.get_regIndex(), LegacyMultiplier);
  579. } else {
  580. ThrowFailure();
  581. return nullptr;
  582. }
  583. }
  584. bool DxilPatchShaderRecordBindings::IsCBufferLoad(llvm::Instruction *instruction) {
  585. DxilInst_CBufferLoad cbufferLoad(instruction);
  586. DxilInst_CBufferLoadLegacy cbufferLoadLegacy(instruction);
  587. return cbufferLoad || cbufferLoadLegacy;
  588. }
  589. const unsigned int GetResolvedRangeID(DXIL::ResourceClass resClass, Value *rangeIdVal)
  590. {
  591. if (auto CI = dyn_cast<ConstantInt>(rangeIdVal))
  592. {
  593. return CI->getZExtValue();
  594. }
  595. else
  596. {
  597. assert(false);
  598. return 0;
  599. }
  600. }
  601. // TODO: This code is quite inefficient
  602. bool DxilPatchShaderRecordBindings::GetHandleInfo(
  603. Module &M,
  604. DxilInst_CreateHandleForLib &createHandleStructForLib,
  605. _Out_ unsigned int &shaderRegister,
  606. _Out_ unsigned int &registerSpace,
  607. _Out_ DXIL::ResourceKind &kind,
  608. _Out_ DXIL::ResourceClass &resClass,
  609. _Out_ llvm::Type *&resType)
  610. {
  611. DxilModule &DM = M.GetOrCreateDxilModule();
  612. LoadInst *loadRangeId = cast<LoadInst>(createHandleStructForLib.get_Resource());
  613. Value *ResourceSymbol = loadRangeId->getPointerOperand();
  614. DXIL::ResourceClass resourceClasses[] = {
  615. DXIL::ResourceClass::CBuffer,
  616. DXIL::ResourceClass::SRV,
  617. DXIL::ResourceClass::UAV,
  618. DXIL::ResourceClass::Sampler
  619. };
  620. hlsl::DxilResourceBase *Resource = nullptr;
  621. for (auto &resourceClass : resourceClasses) {
  622. switch (resourceClass)
  623. {
  624. case DXIL::ResourceClass::CBuffer:
  625. {
  626. auto &cbuffers = DM.GetCBuffers();
  627. for (auto &cbuffer : cbuffers)
  628. {
  629. if (cbuffer->GetGlobalSymbol() == ResourceSymbol)
  630. {
  631. Resource = cbuffer.get();
  632. break;
  633. }
  634. }
  635. break;
  636. }
  637. case DXIL::ResourceClass::SRV:
  638. case DXIL::ResourceClass::UAV:
  639. {
  640. auto &viewList = resourceClass == DXIL::ResourceClass::SRV ? DM.GetSRVs() : DM.GetUAVs();
  641. for (auto &view : viewList)
  642. {
  643. if (view->GetGlobalSymbol() == ResourceSymbol)
  644. {
  645. Resource = view.get();
  646. break;
  647. }
  648. }
  649. break;
  650. }
  651. case DXIL::ResourceClass::Sampler:
  652. {
  653. auto &samplers = DM.GetSamplers();
  654. for (auto &sampler : samplers)
  655. {
  656. if (sampler->GetGlobalSymbol() == ResourceSymbol)
  657. {
  658. Resource = sampler.get();
  659. break;
  660. }
  661. }
  662. break;
  663. }
  664. }
  665. }
  666. if (Resource)
  667. {
  668. registerSpace = Resource->GetSpaceID();
  669. shaderRegister = Resource->GetLowerBound();
  670. kind = Resource->GetKind();
  671. resClass = Resource->GetClass();
  672. resType = cast<GlobalVariable>(Resource->GetGlobalSymbol())->getType()->getPointerElementType();
  673. }
  674. return Resource != nullptr;
  675. }
  676. llvm::Value *DxilPatchShaderRecordBindings::LoadShaderRecordData(
  677. Module &M,
  678. IRBuilder<> &Builder,
  679. llvm::Value *offsetToShaderRecord,
  680. unsigned int dataOffsetInShaderRecord)
  681. {
  682. DxilModule &DM = M.GetOrCreateDxilModule();
  683. LLVMContext &Ctx = M.getContext();
  684. OP *HlslOP = DM.GetOP();
  685. Constant *dataOffset =
  686. HlslOP->GetU32Const(dataOffsetInShaderRecord);
  687. Value *shaderTableOffsetToData = Builder.CreateAdd(dataOffset, offsetToShaderRecord);
  688. return CreateShaderRecordBufferLoad(M, Builder, shaderTableOffsetToData,
  689. llvm::Type::getInt32Ty(Ctx));
  690. }
  691. void DxilPatchShaderRecordBindings::PatchCreateHandleToUseDescriptorIndex(
  692. _In_ Module &M,
  693. _In_ IRBuilder<> &Builder,
  694. _In_ DXIL::ResourceKind &resourceKind,
  695. _In_ DXIL::ResourceClass &resourceClass,
  696. _In_ llvm::Type *resourceType,
  697. _In_ llvm::Value *descriptorIndex,
  698. _Inout_ DxilInst_CreateHandleForLib &createHandleInstr)
  699. {
  700. DxilModule &DM = M.GetOrCreateDxilModule();
  701. OP *HlslOP = DM.GetOP();
  702. llvm::Value *descriptorHeapSymbol = GetAliasedDescriptorHeapHandle(M, resourceType, resourceClass, resourceKind);
  703. llvm::Value *viewSymbol = Builder.CreateGEP(descriptorHeapSymbol, { HlslOP->GetU32Const(0), descriptorIndex }, "IndexIntoDH");
  704. DxilMDHelper::MarkNonUniform(cast<Instruction>(viewSymbol));
  705. llvm::Value *handle = Builder.CreateLoad(viewSymbol);
  706. auto callInst = cast<CallInst>(createHandleInstr.Instr);
  707. callInst->setCalledFunction(HlslOP->GetOpFunc(
  708. DXIL::OpCode::CreateHandleForLib,
  709. handle->getType()));
  710. createHandleInstr.set_Resource(handle);
  711. }
  712. void DxilPatchShaderRecordBindings::InitializeViewTable() {
  713. // The Fallback Layer declares a bindless raw buffer that spans the entire descriptor heap,
  714. // manually add it to the list of UAV register spaces used
  715. if (*pInputShaderInfo->pNumUAVSpaces == 0)
  716. {
  717. ViewKey key = { (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer, 0 };
  718. unsigned int index = FindOrInsertViewIntoList(
  719. key,
  720. pInputShaderInfo->pUAVRegisterSpaceArray,
  721. *pInputShaderInfo->pNumUAVSpaces,
  722. FallbackLayerNumDescriptorHeapSpacesPerView);
  723. (void)index;
  724. assert(index == 0);
  725. }
  726. }
  727. void DxilPatchShaderRecordBindings::PatchShaderBindings(Module &M) {
  728. DxilModule &DM = M.GetOrCreateDxilModule();
  729. OP *HlslOP = DM.GetOP();
  730. // Don't erase instructions until the very end because it throws off the iterator
  731. std::vector<llvm::Instruction *> instructionsToRemove;
  732. for (BasicBlock &block : EntryPointFunction->getBasicBlockList()) {
  733. auto & Instructions = block.getInstList();
  734. auto It = Instructions.begin();
  735. for (auto &instr : Instructions) {
  736. DxilInst_CreateHandleForLib createHandleForLib(&instr);
  737. if (createHandleForLib) {
  738. DXIL::ResourceClass resourceClass;
  739. unsigned int registerSpace;
  740. unsigned int registerIndex;
  741. DXIL::ResourceKind kind;
  742. llvm::Type *resType;
  743. bool resourceIsResolved = true;
  744. resourceIsResolved = GetHandleInfo(M, createHandleForLib, registerIndex, registerSpace, kind, resourceClass, resType);
  745. if (!resourceIsResolved) continue; // TODO: This shouldn't actually be happening?
  746. ShaderRecordEntry shaderRecord = FindRootSignatureDescriptor(
  747. *pRootSignatureDesc,
  748. pInputShaderInfo->ShaderRecordIdentifierSizeInBytes,
  749. resourceClass,
  750. registerIndex,
  751. registerSpace);
  752. const bool IsBindingSpecifiedInLocalRootSignature = !shaderRecord.IsInvalid();
  753. if (IsBindingSpecifiedInLocalRootSignature) {
  754. if (!DispatchRaysConstantsHandle) {
  755. AddInputBinding(M);
  756. }
  757. switch (shaderRecord.ParameterType) {
  758. case DxilRootParameterType::Constants32Bit:
  759. {
  760. for (User *U : instr.users()) {
  761. llvm::Instruction *instruction = cast<CallInst>(U);
  762. if (IsCBufferLoad(instruction)) {
  763. llvm::Instruction *cbufferLoadInstr = instruction;
  764. IRBuilder<> Builder(cbufferLoadInstr);
  765. llvm::Value * cbufferOffsetInBytes = CreateCBufferLoadOffsetInBytes(M, Builder, cbufferLoadInstr);
  766. llvm::Value *LocalOffsetToRootConstant = CreateOffsetToShaderRecord(M, Builder, shaderRecord.RecordOffsetInBytes, cbufferOffsetInBytes);
  767. llvm::Value *GlobalOffsetToRootConstant = Builder.CreateAdd(LocalOffsetToRootConstant, BaseShaderRecordOffset);
  768. llvm::Value *srvBufferLoad = CreateShaderRecordBufferLoad(M, Builder, GlobalOffsetToRootConstant, cbufferLoadInstr->getType());
  769. ReplaceUsesOfWith(cbufferLoadInstr, srvBufferLoad);
  770. } else {
  771. ThrowFailure();
  772. }
  773. }
  774. instructionsToRemove.push_back(&instr);
  775. break;
  776. }
  777. case DxilRootParameterType::DescriptorTable:
  778. {
  779. IRBuilder<> Builder(&instr);
  780. llvm::Value *srvBufferLoad = LoadShaderRecordData(
  781. M,
  782. Builder,
  783. BaseShaderRecordOffset,
  784. shaderRecord.RecordOffsetInBytes);
  785. llvm::Value *DescriptorTableEntryLo = Builder.CreateExtractValue(srvBufferLoad, 0, "DescriptorTableHandleLo");
  786. unsigned int offsetToLoadInUints = offsetof(DispatchRaysConstants, SrvCbvUavDescriptorHeapStart) / sizeof(uint32_t);
  787. unsigned int uintsPerRow = 4;
  788. unsigned int rowToLoad = offsetToLoadInUints / uintsPerRow;
  789. unsigned int extractValueOffset = offsetToLoadInUints % uintsPerRow;
  790. llvm::Value *DescHeapConstants = CreateCBufferLoadLegacy(M, Builder, DispatchRaysConstantsHandle, rowToLoad);
  791. llvm::Value *DescriptorHeapStartAddressLo = Builder.CreateExtractValue(DescHeapConstants, extractValueOffset, "DescriptorHeapStartHandleLo");
  792. // TODO: The hi bits can only be ignored if the difference is guaranteed to be < 32 bytes. This is an unsafe assumption, particularly given
  793. // large descriptor sizes
  794. llvm::Value *DescriptorTableOffsetInBytes = Builder.CreateSub(DescriptorTableEntryLo, DescriptorHeapStartAddressLo, "TableOffsetInBytes");
  795. Constant *DescriptorSizeInBytes = HlslOP->GetU32Const(pInputShaderInfo->SrvCbvUavDescriptorSizeInBytes);
  796. llvm::Value * DescriptorTableStartIndex = Builder.CreateExactUDiv(DescriptorTableOffsetInBytes, DescriptorSizeInBytes, "TableStartIndex");
  797. Constant *RecordOffset = HlslOP->GetU32Const(shaderRecord.OffsetInDescriptors);
  798. llvm::Value * BaseDescriptorIndex = Builder.CreateAdd(DescriptorTableStartIndex, RecordOffset, "BaseDescriptorIndex");
  799. // TODO: Not supporting dynamic indexing yet, should be pulled from CreateHandleForLib
  800. // If dynamic indexing is being used, add the apps index on top of the calculated index
  801. llvm::Value * DynamicIndex = HlslOP->GetU32Const(0);
  802. llvm::Value * DescriptorIndex = Builder.CreateAdd(BaseDescriptorIndex, DynamicIndex, "DescriptorIndex");
  803. PatchCreateHandleToUseDescriptorIndex(
  804. M,
  805. Builder,
  806. kind,
  807. resourceClass,
  808. resType,
  809. DescriptorIndex,
  810. createHandleForLib);
  811. break;
  812. }
  813. case DxilRootParameterType::CBV:
  814. case DxilRootParameterType::SRV:
  815. case DxilRootParameterType::UAV: {
  816. IRBuilder<> Builder(&instr);
  817. llvm::Value *srvBufferLoad = LoadShaderRecordData(
  818. M,
  819. Builder,
  820. BaseShaderRecordOffset,
  821. shaderRecord.RecordOffsetInBytes);
  822. llvm::Value *DescriptorIndex = Builder.CreateExtractValue(
  823. srvBufferLoad, 1, "DescriptorHeapIndex");
  824. // TODO: Handle offset in bytes
  825. // llvm::Value *OffsetInBytes = Builder.CreateExtractValue(
  826. // srvBufferLoad, 0, "OffsetInBytes");
  827. PatchCreateHandleToUseDescriptorIndex(
  828. M,
  829. Builder,
  830. kind,
  831. resourceClass,
  832. resType,
  833. DescriptorIndex,
  834. createHandleForLib);
  835. break;
  836. }
  837. default:
  838. ThrowFailure();
  839. break;
  840. }
  841. }
  842. }
  843. }
  844. }
  845. for (auto instruction : instructionsToRemove) {
  846. instruction->eraseFromParent();
  847. }
  848. }
  849. bool IsParameterTypeCompatibleWithResourceClass(
  850. DXIL::ResourceClass resourceClass,
  851. DxilRootParameterType parameterType) {
  852. switch (parameterType) {
  853. case DxilRootParameterType::DescriptorTable:
  854. return true;
  855. case DxilRootParameterType::Constants32Bit:
  856. case DxilRootParameterType::CBV:
  857. return resourceClass == DXIL::ResourceClass::CBuffer;
  858. case DxilRootParameterType::SRV:
  859. return resourceClass == DXIL::ResourceClass::SRV;
  860. case DxilRootParameterType::UAV:
  861. return resourceClass == DXIL::ResourceClass::UAV;
  862. default:
  863. ThrowFailure();
  864. return false;
  865. }
  866. }
  867. DxilRootParameterType ConvertD3D12ParameterTypeToDxil(DxilRootParameterType parameter) {
  868. switch (parameter) {
  869. case DxilRootParameterType::Constants32Bit:
  870. return DxilRootParameterType::Constants32Bit;
  871. case DxilRootParameterType::DescriptorTable:
  872. return DxilRootParameterType::DescriptorTable;
  873. case DxilRootParameterType::CBV:
  874. return DxilRootParameterType::CBV;
  875. case DxilRootParameterType::SRV:
  876. return DxilRootParameterType::SRV;
  877. case DxilRootParameterType::UAV:
  878. return DxilRootParameterType::UAV;
  879. }
  880. assert(false);
  881. return (DxilRootParameterType)-1;
  882. }
  883. DXIL::ResourceClass ConvertD3D12RangeTypeToDxil(DxilDescriptorRangeType rangeType) {
  884. switch (rangeType) {
  885. case DxilDescriptorRangeType::SRV:
  886. return DXIL::ResourceClass::SRV;
  887. case DxilDescriptorRangeType::UAV:
  888. return DXIL::ResourceClass::UAV;
  889. case DxilDescriptorRangeType::CBV:
  890. return DXIL::ResourceClass::CBuffer;
  891. case DxilDescriptorRangeType::Sampler:
  892. return DXIL::ResourceClass::Sampler;
  893. }
  894. assert(false);
  895. return (DXIL::ResourceClass) - 1;
  896. }
  897. unsigned int GetParameterTypeAlignment(DxilRootParameterType parameterType) {
  898. switch (parameterType) {
  899. case DxilRootParameterType::DescriptorTable:
  900. return SizeofD3D12GpuDescriptorHandle;
  901. case DxilRootParameterType::Constants32Bit:
  902. return sizeof(uint32_t);
  903. case DxilRootParameterType::CBV: // fallthrough
  904. case DxilRootParameterType::SRV: // fallthrough
  905. case DxilRootParameterType::UAV:
  906. return SizeofD3D12GpuVA;
  907. default:
  908. return UINT_MAX;
  909. }
  910. }
  911. template <typename TD3D12_ROOT_SIGNATURE_DESC>
  912. ShaderRecordEntry FindRootSignatureDescriptorHelper(
  913. const TD3D12_ROOT_SIGNATURE_DESC &rootSignatureDescriptor,
  914. unsigned int ShaderRecordIdentifierSizeInBytes,
  915. DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex,
  916. unsigned int registerSpace) {
  917. // Automatically fail if it's looking for a fallback binding as these never
  918. // need to be patched
  919. if (registerSpace != FallbackLayerRegisterSpace) {
  920. unsigned int recordOffset = ShaderRecordIdentifierSizeInBytes;
  921. for (unsigned int rootParamIndex = 0;
  922. rootParamIndex < rootSignatureDescriptor.NumParameters;
  923. rootParamIndex++) {
  924. auto &rootParam = rootSignatureDescriptor.pParameters[rootParamIndex];
  925. auto dxilParamType =
  926. ConvertD3D12ParameterTypeToDxil(rootParam.ParameterType);
  927. #define ALIGN(alignment, num) (((num + alignment - 1) / alignment) * alignment)
  928. recordOffset = ALIGN(GetParameterTypeAlignment(rootParam.ParameterType),
  929. recordOffset);
  930. switch (rootParam.ParameterType) {
  931. case DxilRootParameterType::Constants32Bit:
  932. if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
  933. dxilParamType) &&
  934. baseRegisterIndex == rootParam.Constants.ShaderRegister &&
  935. registerSpace == rootParam.Constants.RegisterSpace) {
  936. return {dxilParamType, recordOffset};
  937. }
  938. recordOffset += rootParam.Constants.Num32BitValues * sizeof(uint32_t);
  939. break;
  940. case DxilRootParameterType::DescriptorTable: {
  941. auto &descriptorTable = rootParam.DescriptorTable;
  942. unsigned int rangeOffsetInDescriptors = 0;
  943. for (unsigned int rangeIndex = 0;
  944. rangeIndex < descriptorTable.NumDescriptorRanges; rangeIndex++) {
  945. auto &range = descriptorTable.pDescriptorRanges[rangeIndex];
  946. if (range.OffsetInDescriptorsFromTableStart != -1) {
  947. rangeOffsetInDescriptors = range.OffsetInDescriptorsFromTableStart;
  948. }
  949. if (ConvertD3D12RangeTypeToDxil(range.RangeType) == resourceClass &&
  950. range.RegisterSpace == registerSpace &&
  951. range.BaseShaderRegister <= baseRegisterIndex &&
  952. range.BaseShaderRegister + range.NumDescriptors >
  953. baseRegisterIndex) {
  954. rangeOffsetInDescriptors +=
  955. baseRegisterIndex - range.BaseShaderRegister;
  956. return {dxilParamType, recordOffset, rangeOffsetInDescriptors};
  957. }
  958. rangeOffsetInDescriptors += range.NumDescriptors;
  959. }
  960. recordOffset += SizeofD3D12GpuDescriptorHandle;
  961. break;
  962. }
  963. case DxilRootParameterType::CBV:
  964. case DxilRootParameterType::SRV:
  965. case DxilRootParameterType::UAV:
  966. if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
  967. dxilParamType) &&
  968. baseRegisterIndex == rootParam.Descriptor.ShaderRegister &&
  969. registerSpace == rootParam.Descriptor.RegisterSpace) {
  970. return {dxilParamType, recordOffset};
  971. }
  972. recordOffset += SizeofD3D12GpuVA;
  973. break;
  974. }
  975. }
  976. }
  977. return ShaderRecordEntry::InvalidEntry();
  978. }
  979. // TODO: Consider pre-calculating this into a map
  980. ShaderRecordEntry DxilPatchShaderRecordBindings::FindRootSignatureDescriptor(
  981. const DxilVersionedRootSignatureDesc &rootSignatureDescriptor,
  982. unsigned int ShaderRecordIdentifierSizeInBytes,
  983. DXIL::ResourceClass resourceClass,
  984. unsigned int baseRegisterIndex,
  985. unsigned int registerSpace) {
  986. switch (rootSignatureDescriptor.Version) {
  987. case DxilRootSignatureVersion::Version_1_0:
  988. return FindRootSignatureDescriptorHelper(rootSignatureDescriptor.Desc_1_0, ShaderRecordIdentifierSizeInBytes, resourceClass, baseRegisterIndex, registerSpace);
  989. case DxilRootSignatureVersion::Version_1_1:
  990. return FindRootSignatureDescriptorHelper(rootSignatureDescriptor.Desc_1_1, ShaderRecordIdentifierSizeInBytes, resourceClass, baseRegisterIndex, registerSpace);
  991. default:
  992. ThrowFailure();
  993. return ShaderRecordEntry::InvalidEntry();
  994. }
  995. }