DxilPatchShaderRecordBindings.cpp 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPatchShaderRecordBindings.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides a pass used by the RayTracing Fallback Layer to modify //
  9. // bindings to pull local root signature parameters from a global //
  10. // "shader table" buffer instead //
  11. // //
  12. ///////////////////////////////////////////////////////////////////////////////
  13. #include "dxc/HLSL/DxilGenerationPass.h"
  14. #include "dxc/HLSL/DxilFallbackLayerPass.h"
  15. #include "dxc/DXIL/DxilOperations.h"
  16. #include "dxc/DXIL/DxilSignatureElement.h"
  17. #include "dxc/DXIL/DxilFunctionProps.h"
  18. #include "dxc/DXIL/DxilModule.h"
  19. #include "dxc/Support/Global.h"
  20. #include "dxc/Support/Unicode.h"
  21. #include "dxc/DXIL/DxilTypeSystem.h"
  22. #include "dxc/DXIL/DxilConstants.h"
  23. #include "dxc/DXIL/DxilInstructions.h"
  24. #include "dxc/HLSL/DxilSpanAllocator.h"
  25. #include "dxc/DxilRootSignature/DxilRootSignature.h"
  26. #include "dxc/DXIL/DxilUtil.h"
  27. #include "llvm/Transforms/Utils/Cloning.h"
  28. #include "llvm/IR/Instructions.h"
  29. #include "llvm/IR/IntrinsicInst.h"
  30. #include "llvm/IR/InstIterator.h"
  31. #include "llvm/IR/Module.h"
  32. #include "llvm/IR/PassManager.h"
  33. #include "llvm/ADT/BitVector.h"
  34. #include "llvm/Pass.h"
  35. #include "llvm/Transforms/Utils/Local.h"
  36. #include "llvm/Transforms/Scalar.h"
  37. #include <memory>
  38. #include <unordered_set>
  39. #include <functional>
  40. #include <unordered_map>
  41. #include <array>
  42. struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
  43. #include "DxilPatchShaderRecordBindingsShared.h"
  44. using namespace llvm;
  45. using namespace hlsl;
  46. bool operator==(const ViewKey &a, const ViewKey &b) {
  47. return memcmp(&a, &b, sizeof(a)) == 0;
  48. }
  // Byte size used for a GPU virtual address stored in a shader record (one 64-bit value).
  const size_t SizeofD3D12GpuVA{sizeof(uint64_t)};
  // Byte size used for a GPU descriptor handle stored in a shader record (one 64-bit value).
  const size_t SizeofD3D12GpuDescriptorHandle{sizeof(uint64_t)};
  51. Function *CloneFunction(Function *Orig,
  52. const llvm::Twine &Name,
  53. llvm::Module *llvmModule) {
  54. Function *F = Function::Create(Orig->getFunctionType(),
  55. GlobalValue::LinkageTypes::ExternalLinkage,
  56. Name, llvmModule);
  57. SmallVector<ReturnInst *, 2> Returns;
  58. ValueToValueMapTy vmap;
  59. // Map params.
  60. auto entryParamIt = F->arg_begin();
  61. for (Argument &param : Orig->args()) {
  62. vmap[&param] = (entryParamIt++);
  63. }
  64. DxilModule &DM = llvmModule->GetOrCreateDxilModule();
  65. llvm::CloneFunctionInto(F, Orig, vmap, /*ModuleLevelChagnes*/ false, Returns);
  66. DM.GetTypeSystem().CopyFunctionAnnotation(F, Orig, DM.GetTypeSystem());
  67. if (DM.HasDxilFunctionProps(F)) {
  68. DM.CloneDxilEntryProps(Orig, F);
  69. }
  70. return F;
  71. }
  72. struct ShaderRecordEntry {
  73. DxilRootParameterType ParameterType;
  74. unsigned int RecordOffsetInBytes;
  75. unsigned int OffsetInDescriptors; // Only valid for descriptor tables
  76. static ShaderRecordEntry InvalidEntry() { return { (DxilRootParameterType)-1, (unsigned int)-1, 0 }; }
  77. bool IsInvalid() { return (unsigned int)ParameterType == (unsigned int)-1; }
  78. };
  79. struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
  80. class DxilPatchShaderRecordBindings : public ModulePass {
  81. public:
  82. static char ID; // Pass identification, replacement for typeid
  83. explicit DxilPatchShaderRecordBindings() : ModulePass(ID) {}
  84. const char *getPassName() const override { return "DXIL Patch Shader Record Binding"; }
  85. void applyOptions(PassOptions O) override;
  86. bool runOnModule(Module &M) override;
  87. private:
  88. void ValidateParameters();
  89. void AddInputBinding(Module &M);
  90. void PatchShaderBindings(Module &M);
  91. void InitializeViewTable();
  92. unsigned int AddSRVRawBuffer(Module &M, unsigned int registerIndex, unsigned int registerSpace, const std::string &bufferName);
  93. unsigned int AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type = nullptr, unsigned int constantBufferSize = 0);
  94. unsigned int AddAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type);
  95. unsigned int AddCBufferAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, const std::string &bufferName);
  96. llvm::Value *CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes, llvm::Value *CbufferOffsetInBytes);
  97. llvm::Value *CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes, llvm::Type* type);
  98. llvm::Value *CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction);
  99. llvm::Value *CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle, unsigned int RowToLoad = 0);
  100. llvm::Value *LoadShaderRecordData(Module &M, IRBuilder<> &Builder,
  101. llvm::Value *offsetToShaderRecord,
  102. unsigned int dataOffsetInShaderRecord);
  103. void PatchCreateHandleToUseDescriptorIndex(
  104. _In_ Module &M,
  105. _In_ IRBuilder<> &Builder,
  106. _In_ DXIL::ResourceKind &resourceKind,
  107. _In_ DXIL::ResourceClass &resourceClass,
  108. _In_ llvm::Type *resourceType,
  109. _In_ llvm::Value *descriptorIndex,
  110. _Inout_ DxilInst_CreateHandleForLib &createHandleInstr);
  111. bool GetHandleInfo(
  112. Module &M,
  113. DxilInst_CreateHandleForLib &createHandleStructForLib,
  114. _Out_ unsigned int &shaderRegister,
  115. _Out_ unsigned int &registerSpace,
  116. _Out_ DXIL::ResourceKind &kind,
  117. _Out_ DXIL::ResourceClass &resClass,
  118. _Out_ llvm::Type *&resType);
  119. llvm::Value * GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind);
  120. unsigned int GetConstantBufferOffsetToShaderRecord();
  121. bool IsCBufferLoad(llvm::Instruction *instruction);
  122. // Unlike the LLVM version of this function, this does not requires the InstructionToReplace and the ValueToReplaceWith to be the same instruction type
  123. static void ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith);
  124. static ShaderRecordEntry FindRootSignatureDescriptor(const DxilVersionedRootSignatureDesc &rootSignatureDescriptor, unsigned int ShaderRecordIdentifierSizeInBytes, DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex, unsigned int registerSpace);
  125. // TODO: I would like to see these prefixed with m_
  126. llvm::Value *ShaderTableHandle = nullptr;
  127. llvm::Value *DispatchRaysConstantsHandle = nullptr;
  128. llvm::Value *BaseShaderRecordOffset = nullptr;
  129. static const unsigned int NumViewTypes = 4;
  130. struct ViewKeyHasher
  131. {
  132. public:
  133. std::size_t operator()(const ViewKey &x) const {
  134. return std::hash<unsigned int>()((unsigned int)x.ViewType) ^
  135. std::hash<unsigned int>()((unsigned int)x.StructuredStride);
  136. }
  137. };
  138. std::unordered_map<ViewKey, llvm::Value *, ViewKeyHasher>
  139. TypeToAliasedDescriptorHeap[NumViewTypes];
  140. llvm::Function *EntryPointFunction;
  141. ShaderInfo *pInputShaderInfo;
  142. const DxilVersionedRootSignatureDesc *pRootSignatureDesc;
  143. DXIL::ShaderKind ShaderKind;
  144. };
  145. char DxilPatchShaderRecordBindings::ID = 0;
  // TODO: Find the right thing to do on failure
  // Single failure path for this file: always throws a plain std::exception.
  void ThrowFailure() {
    std::exception failure;
    throw failure;
  }
  150. // TODO: Stolen from Brandon's code, merge
  151. // Remove ELF mangling
  152. static inline std::string GetUnmangledName(StringRef name) {
  153. if (!name.startswith("\x1?"))
  154. return name;
  155. size_t pos = name.find("@@");
  156. if (pos == name.npos)
  157. return name;
  158. return name.substr(2, pos - 2);
  159. }
  160. static Function* getFunctionFromName(Module &M, const std::wstring& exportName) {
  161. for (auto F = M.begin(), E = M.end(); F != E; ++F) {
  162. std::wstring functionName = Unicode::UTF8ToUTF16StringOrThrow(GetUnmangledName(F->getName()).c_str());
  163. if (exportName == functionName) {
  164. return F;
  165. }
  166. }
  167. return nullptr;
  168. }
  169. ModulePass *llvm::createDxilPatchShaderRecordBindingsPass() {
  170. return new DxilPatchShaderRecordBindings();
  171. }
  172. INITIALIZE_PASS(DxilPatchShaderRecordBindings, "hlsl-dxil-patch-shader-record-bindings", "Patch shader record bindings to instead pull from the fallback provided bindings", false, false)
  173. void DxilPatchShaderRecordBindings::applyOptions(PassOptions O) {
  174. for (const auto & option : O) {
  175. if (0 == option.first.compare("root-signature")) {
  176. unsigned int cHexRadix = 16;
  177. pInputShaderInfo = (ShaderInfo*)strtoull(option.second.data(), nullptr, cHexRadix);
  178. pRootSignatureDesc = (const DxilVersionedRootSignatureDesc*)pInputShaderInfo->pRootSignatureDesc;
  179. }
  180. }
  181. }
  182. void AddAnnoationsIfNeeded(DxilModule &DM, llvm::StructType *StructTy, const std::string &FieldName, unsigned int numFields = 1)
  183. {
  184. auto pAnnotation = DM.GetTypeSystem().GetStructAnnotation(StructTy);
  185. if (pAnnotation == nullptr)
  186. {
  187. pAnnotation = DM.GetTypeSystem().AddStructAnnotation(StructTy);
  188. pAnnotation->SetCBufferSize(sizeof(uint32_t) * numFields);
  189. for (unsigned int i = 0; i < numFields; i++)
  190. {
  191. pAnnotation->GetFieldAnnotation(i).SetCBufferOffset(sizeof(uint32_t) * i);
  192. pAnnotation->GetFieldAnnotation(i).SetCompType(hlsl::DXIL::ComponentType::I32);
  193. pAnnotation->GetFieldAnnotation(i).SetFieldName(FieldName + std::to_string(i));
  194. }
  195. }
  196. }
  197. unsigned int DxilPatchShaderRecordBindings::AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type, unsigned int constantBufferSize) {
  198. LLVMContext & Ctx = M.getContext();
  199. DxilModule &DM = M.GetOrCreateDxilModule();
  200. // Set up a SRV with byte address buffer
  201. unsigned int resourceHandle;
  202. std::unique_ptr<DxilResource> pHandle;
  203. std::unique_ptr<DxilCBuffer> pCBuf;
  204. std::unique_ptr<DxilSampler> pSampler;
  205. DxilResourceBase *pBaseHandle;
  206. switch (resClass) {
  207. case DXIL::ResourceClass::SRV:
  208. resourceHandle = static_cast<unsigned int>(DM.GetSRVs().size());
  209. pHandle = llvm::make_unique<DxilResource>();
  210. pHandle->SetRW(false);
  211. pBaseHandle = pHandle.get();
  212. break;
  213. case DXIL::ResourceClass::UAV:
  214. resourceHandle = static_cast<unsigned int>(DM.GetUAVs().size());
  215. pHandle = llvm::make_unique<DxilResource>();
  216. pHandle->SetRW(true);
  217. pBaseHandle = pHandle.get();
  218. break;
  219. case DXIL::ResourceClass::CBuffer:
  220. resourceHandle = static_cast<unsigned int>(DM.GetCBuffers().size());
  221. pCBuf = llvm::make_unique<DxilCBuffer>();
  222. pCBuf->SetSize(constantBufferSize);
  223. pBaseHandle = pCBuf.get();
  224. break;
  225. case DXIL::ResourceClass::Sampler:
  226. resourceHandle = static_cast<unsigned int>(DM.GetSamplers().size());
  227. pSampler = llvm::make_unique<DxilSampler>();
  228. // TODO: Is this okay? What if one of the samplers in the table is a comparison sampler?
  229. pSampler->SetSamplerKind(DxilSampler::SamplerKind::Default);
  230. pBaseHandle = pSampler.get();
  231. break;
  232. }
  233. if (!type) {
  234. SmallVector<llvm::Type*, 1> Elements{ Type::getInt32Ty(Ctx) };
  235. std::string ByteAddressBufferName = "struct.ByteAddressBuffer";
  236. type = M.getTypeByName(ByteAddressBufferName);
  237. if (!type)
  238. {
  239. StructType *StructTy;
  240. type = StructTy = StructType::create(Elements, ByteAddressBufferName);
  241. AddAnnoationsIfNeeded(DM, StructTy, ByteAddressBufferName);
  242. }
  243. }
  244. GlobalVariable *GV = M.getGlobalVariable(bufferName);
  245. if (!GV) {
  246. GV = cast<GlobalVariable>(M.getOrInsertGlobal(bufferName, type));
  247. }
  248. pBaseHandle->SetGlobalName(bufferName.c_str());
  249. pBaseHandle->SetGlobalSymbol(GV);
  250. pBaseHandle->SetID(resourceHandle);
  251. pBaseHandle->SetSpaceID(registerSpace);
  252. pBaseHandle->SetLowerBound(baseRegisterIndex);
  253. pBaseHandle->SetRangeSize(rangeSize);
  254. pBaseHandle->SetKind(resKind);
  255. if (pHandle) {
  256. pHandle->SetGloballyCoherent(false);
  257. pHandle->SetHasCounter(false);
  258. pHandle->SetCompType(CompType::getF32()); // TODO: Need to handle all types
  259. }
  260. unsigned int ID;
  261. switch (resClass) {
  262. case DXIL::ResourceClass::SRV:
  263. ID = DM.AddSRV(std::move(pHandle));
  264. break;
  265. case DXIL::ResourceClass::UAV:
  266. ID = DM.AddUAV(std::move(pHandle));
  267. break;
  268. case DXIL::ResourceClass::CBuffer:
  269. ID = DM.AddCBuffer(std::move(pCBuf));
  270. break;
  271. case DXIL::ResourceClass::Sampler:
  272. ID = DM.AddSampler(std::move(pSampler));
  273. break;
  274. }
  275. assert(ID == resourceHandle);
  276. return ID;
  277. }
  278. unsigned int DxilPatchShaderRecordBindings::GetConstantBufferOffsetToShaderRecord()
  279. {
  280. switch (ShaderKind)
  281. {
  282. case DXIL::ShaderKind::ClosestHit:
  283. case DXIL::ShaderKind::AnyHit:
  284. case DXIL::ShaderKind::Intersection:
  285. return offsetof(DispatchRaysConstants, HitGroupShaderRecordStride);
  286. case DXIL::ShaderKind::Miss:
  287. return offsetof(DispatchRaysConstants, MissShaderRecordStride);
  288. default:
  289. ThrowFailure();
  290. return -1;
  291. }
  292. }
  293. unsigned int DxilPatchShaderRecordBindings::AddSRVRawBuffer(Module &M, unsigned int registerIndex, unsigned int registerSpace, const std::string &bufferName) {
  294. return AddHandle(M, registerIndex, 1, registerSpace, DXIL::ResourceClass::SRV, DXIL::ResourceKind::RawBuffer, bufferName);
  295. }
  296. llvm::Constant *GetArraySymbol(Module &M, const std::string &bufferName) {
  297. LLVMContext & Ctx = M.getContext();
  298. SmallVector<llvm::Type*, 1> Elements{ Type::getInt32Ty(Ctx) };
  299. llvm::StructType *StructTy = llvm::StructType::create(Elements, bufferName);
  300. llvm::ArrayType *ArrayTy = ArrayType::get(StructTy, -1);
  301. return UndefValue::get(ArrayTy->getPointerTo());
  302. }
  303. unsigned int DxilPatchShaderRecordBindings::AddCBufferAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, const std::string &bufferName) {
  304. const unsigned int maxConstantBufferSize = 4096 * 16;
  305. return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace, DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer, bufferName, GetArraySymbol(M, bufferName)->getType(), maxConstantBufferSize);
  306. }
  307. unsigned int DxilPatchShaderRecordBindings::AddAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type) {
  308. return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace, resClass, resKind, bufferName, type);
  309. }
  310. // TODO: Stolen from Brandon's code
  311. DXIL::ShaderKind GetRayShaderKindCopy(Function* F)
  312. {
  313. if (F->hasFnAttribute("exp-shader"))
  314. return DXIL::ShaderKind::RayGeneration;
  315. DxilModule& DM = F->getParent()->GetDxilModule();
  316. if (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay())
  317. return DM.GetDxilFunctionProps(F).shaderKind;
  318. return DXIL::ShaderKind::Invalid;
  319. }
  320. bool DxilPatchShaderRecordBindings::runOnModule(Module &M) {
  321. DxilModule &DM = M.GetOrCreateDxilModule();
  322. EntryPointFunction = pInputShaderInfo->ExportName ? getFunctionFromName(M, pInputShaderInfo->ExportName) : DM.GetEntryFunction();
  323. ShaderKind = GetRayShaderKindCopy(EntryPointFunction);
  324. ValidateParameters();
  325. InitializeViewTable();
  326. PatchShaderBindings(M);
  327. DM.ReEmitDxilResources();
  328. return true;
  329. }
  330. void DxilPatchShaderRecordBindings::ValidateParameters() {
  331. if (!pInputShaderInfo || !pInputShaderInfo->pRootSignatureDesc) {
  332. throw std::exception();
  333. }
  334. }
  335. DxilResourceBase &GetResourceFromID(DxilModule &DM, DXIL::ResourceClass resClass, unsigned int id)
  336. {
  337. switch (resClass)
  338. {
  339. case DXIL::ResourceClass::CBuffer:
  340. return DM.GetCBuffer(id);
  341. break;
  342. case DXIL::ResourceClass::SRV:
  343. return DM.GetSRV(id);
  344. break;
  345. case DXIL::ResourceClass::UAV:
  346. return DM.GetUAV(id);
  347. break;
  348. case DXIL::ResourceClass::Sampler:
  349. return DM.GetSampler(id);
  350. break;
  351. default:
  352. ThrowFailure();
  353. llvm_unreachable("invalid resource class");
  354. }
  355. }
  356. unsigned int FindOrInsertViewIntoList(const ViewKey &key, ViewKey *pViewList, unsigned int &numViews, unsigned int maxViews)
  357. {
  358. unsigned int viewIndex = 0;
  359. for (; viewIndex < numViews; viewIndex++)
  360. {
  361. if (pViewList[viewIndex] == key)
  362. {
  363. break;
  364. }
  365. }
  366. if (viewIndex == numViews)
  367. {
  368. if (viewIndex >= maxViews) {
  369. ThrowFailure();
  370. }
  371. pViewList[viewIndex] = key;
  372. numViews++;
  373. }
  374. return viewIndex;
  375. }
  376. llvm::Value *DxilPatchShaderRecordBindings::GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *type, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind)
  377. {
  378. DxilModule &DM = M.GetOrCreateDxilModule();
  379. unsigned int resClassIndex = (unsigned int)resClass;
  380. ViewKey key = {};
  381. key.ViewType = (unsigned int)resKind;
  382. if (DXIL::IsStructuredBuffer(resKind))
  383. {
  384. key.StructuredStride = type->getPrimitiveSizeInBits();
  385. } else if (resKind != DXIL::ResourceKind::RawBuffer)
  386. {
  387. auto containedType = type->getContainedType(0);
  388. // If it's a vector, get the type of just a single element
  389. if (containedType->getNumContainedTypes() > 0)
  390. {
  391. assert(containedType->getNumContainedTypes() <= 4);
  392. containedType = containedType->getContainedType(0);
  393. }
  394. key.SRVComponentType = (unsigned int)CompType::GetCompType(containedType).GetKind();
  395. }
  396. auto aliasedDescriptorHeapHandle = TypeToAliasedDescriptorHeap[resClassIndex].find(key);
  397. if (aliasedDescriptorHeapHandle == TypeToAliasedDescriptorHeap[resClassIndex].end())
  398. {
  399. unsigned int registerSpaceOffset = 0;
  400. std::string HandleName;
  401. if (resClass == DXIL::ResourceClass::SRV)
  402. {
  403. registerSpaceOffset = FindOrInsertViewIntoList(
  404. key,
  405. pInputShaderInfo->pSRVRegisterSpaceArray,
  406. *pInputShaderInfo->pNumSRVSpaces,
  407. FallbackLayerNumDescriptorHeapSpacesPerView);
  408. HandleName = std::string("SRVDescriptorHeapTable") +
  409. std::to_string(registerSpaceOffset);
  410. }
  411. else if (resClass == DXIL::ResourceClass::UAV)
  412. {
  413. registerSpaceOffset = FindOrInsertViewIntoList(
  414. key,
  415. pInputShaderInfo->pUAVRegisterSpaceArray,
  416. *pInputShaderInfo->pNumUAVSpaces,
  417. FallbackLayerNumDescriptorHeapSpacesPerView);
  418. if (registerSpaceOffset == 0)
  419. {
  420. // Using the descriptor heap declared by the fallback for handling emulated pointers,
  421. // make sure the name is an exact match
  422. assert(key.ViewType == (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer);
  423. HandleName = "\01?DescriptorHeapBufferTable@@3PAURWByteAddressBuffer@@A";
  424. }
  425. else
  426. {
  427. HandleName = std::string("UAVDescriptorHeapTable") +
  428. std::to_string(registerSpaceOffset);
  429. }
  430. }
  431. else if (resClass == DXIL::ResourceClass::CBuffer)
  432. {
  433. HandleName = std::string("CBVDescriptorHeapTable");
  434. } else {
  435. HandleName = std::string("SamplerDescriptorHeapTable");
  436. }
  437. llvm::ArrayType *descriptorHeapType = ArrayType::get(type, 0);
  438. unsigned int id = AddAliasedHandle(M, FallbackLayerDescriptorHeapTable, FallbackLayerRegisterSpace + FallbackLayerDescriptorHeapSpaceOffset + registerSpaceOffset, resClass, resKind, HandleName, descriptorHeapType);
  439. TypeToAliasedDescriptorHeap[resClassIndex][key] = GetResourceFromID(DM, resClass, id).GetGlobalSymbol();
  440. }
  441. return TypeToAliasedDescriptorHeap[resClassIndex][key];
  442. }
  443. void DxilPatchShaderRecordBindings::AddInputBinding(Module &M) {
  444. DxilModule &DM = M.GetOrCreateDxilModule();
  445. auto & EntryBlock = EntryPointFunction->getEntryBlock();
  446. auto & Instructions = EntryBlock.getInstList();
  447. std::string bufferName;
  448. unsigned int bufferRegister;
  449. switch (ShaderKind) {
  450. case DXIL::ShaderKind::AnyHit:
  451. case DXIL::ShaderKind::ClosestHit:
  452. case DXIL::ShaderKind::Intersection:
  453. bufferRegister = FallbackLayerHitGroupRecordByteAddressBufferRegister;
  454. bufferName = "\01?HitGroupShaderTable@@3UByteAddressBuffer@@A";
  455. break;
  456. case DXIL::ShaderKind::Miss:
  457. bufferRegister = FallbackLayerMissShaderRecordByteAddressBufferRegister;
  458. bufferName = "\01?MissShaderTable@@3UByteAddressBuffer@@A";
  459. break;
  460. case DXIL::ShaderKind::RayGeneration:
  461. bufferRegister = FallbackLayerRayGenShaderRecordByteAddressBufferRegister;
  462. bufferName = "\01?RayGenShaderTable@@3UByteAddressBuffer@@A";
  463. break;
  464. case DXIL::ShaderKind::Callable:
  465. bufferRegister = FallbackLayerCallableShaderRecordByteAddressBufferRegister;
  466. bufferName = "\01?CallableShaderTable@@3UByteAddressBuffer@@A";
  467. break;
  468. }
  469. unsigned int ShaderRecordID = AddSRVRawBuffer(M, bufferRegister, FallbackLayerRegisterSpace, bufferName);
  470. auto It = Instructions.begin();
  471. OP *HlslOP = DM.GetOP();
  472. LLVMContext & Ctx = M.getContext();
  473. IRBuilder<> Builder(It);
  474. {
  475. auto ShaderTableName = "ShaderTableHandle";
  476. llvm::Value *Symbol = DM.GetSRV(ShaderRecordID).GetGlobalSymbol();
  477. llvm::Value *Load = Builder.CreateLoad(Symbol, "LoadShaderTableHandle");
  478. Function *CreateHandleForLib = HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
  479. Constant *CreateHandleOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
  480. ShaderTableHandle = Builder.CreateCall(CreateHandleForLib, { CreateHandleOpcodeArg, Load }, ShaderTableName);
  481. }
  482. {
  483. auto CbufferName = "Constants";
  484. const unsigned int sizeOfConstantsInBytes = sizeof(DispatchRaysConstants);
  485. llvm::StructType *StructTy= M.getTypeByName(CbufferName);
  486. if (!StructTy)
  487. {
  488. const unsigned int numUintsInConstants = sizeOfConstantsInBytes / sizeof(unsigned int);
  489. SmallVector<llvm::Type*, numUintsInConstants> Elements(numUintsInConstants);
  490. for (unsigned int i = 0; i < numUintsInConstants; i++)
  491. {
  492. Elements[i] = Type::getInt32Ty(Ctx);
  493. }
  494. StructTy = llvm::StructType::create(Elements, CbufferName);
  495. AddAnnoationsIfNeeded(DM, StructTy, std::string(CbufferName), numUintsInConstants);
  496. }
  497. unsigned int handle = AddHandle(M, FallbackLayerDispatchConstantsRegister, 1, FallbackLayerRegisterSpace, DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer, CbufferName, StructTy, sizeOfConstantsInBytes);
  498. llvm::Value *Symbol = DM.GetCBuffer(handle).GetGlobalSymbol();
  499. llvm::Value *Load = Builder.CreateLoad(Symbol, "DispatchRaysConstants");
  500. Function *CreateHandleForLib = HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
  501. Constant *CreateHandleOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
  502. DispatchRaysConstantsHandle = Builder.CreateCall(CreateHandleForLib, { CreateHandleOpcodeArg, Load }, CbufferName);
  503. }
  504. // Raygen always reads from the start so no offset calculations needed
  505. if (ShaderKind != DXIL::ShaderKind::RayGeneration)
  506. {
  507. std::string ShaderRecordOffsetFuncName = "\x1?Fallback_ShaderRecordOffset@@YAIXZ";
  508. Function *ShaderRecordOffsetFunc = M.getFunction(ShaderRecordOffsetFuncName);
  509. if (!ShaderRecordOffsetFunc)
  510. {
  511. FunctionType *ShaderRecordOffsetFuncType = FunctionType::get(llvm::Type::getInt32Ty(Ctx), {}, false);
  512. ShaderRecordOffsetFunc = Function::Create(ShaderRecordOffsetFuncType, GlobalValue::LinkageTypes::ExternalLinkage, ShaderRecordOffsetFuncName, &M);
  513. }
  514. BaseShaderRecordOffset = Builder.CreateCall(ShaderRecordOffsetFunc, {}, "shaderRecordOffset");
  515. }
  516. else
  517. {
  518. BaseShaderRecordOffset = HlslOP->GetU32Const(0);
  519. }
  520. }
  521. llvm::Value *DxilPatchShaderRecordBindings::CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes, llvm::Value *CbufferOffsetInBytes) {
  522. DxilModule &DM = M.GetOrCreateDxilModule();
  523. OP *HlslOP = DM.GetOP();
  524. // Create handle for the newly-added constant buffer (which is achieved via a function call)
  525. auto AdddName = "ShaderRecordOffsetInBytes";
  526. Constant *ShaderRecordOffsetInBytes = HlslOP->GetU32Const(RecordOffsetInBytes); // Offset of constants in shader record buffer
  527. return Builder.CreateAdd(CbufferOffsetInBytes, ShaderRecordOffsetInBytes, AdddName);
  528. }
  529. llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle, unsigned int RowToLoad) {
  530. DxilModule &DM = M.GetOrCreateDxilModule();
  531. OP *HlslOP = DM.GetOP();
  532. LLVMContext & Ctx = M.getContext();
  533. auto BufferLoadName = "ConstantBuffer";
  534. Function *BufferLoad = HlslOP->GetOpFunc(DXIL::OpCode::CBufferLoadLegacy, Type::getInt32Ty(Ctx));
  535. Constant *CBufferLoadOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CBufferLoadLegacy);
  536. Constant *RowToLoadConst = HlslOP->GetU32Const(RowToLoad);
  537. return Builder.CreateCall(BufferLoad, { CBufferLoadOpcodeArg, ResourceHandle, RowToLoadConst }, BufferLoadName);
  538. }
  539. llvm::Value *DxilPatchShaderRecordBindings::CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes, llvm::Type* type) {
  540. DxilModule &DM = M.GetOrCreateDxilModule();
  541. OP *HlslOP = DM.GetOP();
  542. LLVMContext & Ctx = M.getContext();
  543. // Create handle for the newly-added constant buffer (which is achieved via a function call)
  544. auto BufferLoadName = "ShaderRecordBuffer";
  545. if (type->getNumContainedTypes() > 1)
  546. {
  547. // TODO: Buffer loads aren't legal with container types, check if this is the right wait to handle this
  548. type = type->getContainedType(0);
  549. }
  550. // TODO Do I need to check the result? Hopefully not
  551. Function *BufferLoad = HlslOP->GetOpFunc(DXIL::OpCode::BufferLoad, type);
  552. Constant *BufferLoadOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::BufferLoad);
  553. Constant *Unused = UndefValue::get(llvm::Type::getInt32Ty(Ctx));
  554. return Builder.CreateCall(BufferLoad, { BufferLoadOpcodeArg, ShaderTableHandle, ShaderRecordOffsetInBytes, Unused }, BufferLoadName);
  555. }
  556. void DxilPatchShaderRecordBindings::ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith) {
  557. for (auto UserIter = InstructionToReplace->user_begin(); UserIter != InstructionToReplace->user_end();) {
  558. // Increment the iterator before the replace since the replace alters the uses list
  559. auto userInstr = UserIter++;
  560. userInstr->replaceUsesOfWith(InstructionToReplace, ValueToReplaceWith);
  561. }
  562. InstructionToReplace->eraseFromParent();
  563. }
  564. llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction) {
  565. DxilModule &DM = M.GetOrCreateDxilModule();
  566. OP *HlslOP = DM.GetOP();
  567. DxilInst_CBufferLoad cbufferLoad(instruction);
  568. DxilInst_CBufferLoadLegacy cbufferLoadLegacy(instruction);
  569. if (cbufferLoad) {
  570. return cbufferLoad.get_byteOffset();
  571. } else if (cbufferLoadLegacy) {
  572. Constant *LegacyMultiplier = HlslOP->GetU32Const(16);
  573. return Builder.CreateMul(cbufferLoadLegacy.get_regIndex(), LegacyMultiplier);
  574. } else {
  575. ThrowFailure();
  576. return nullptr;
  577. }
  578. }
  579. bool DxilPatchShaderRecordBindings::IsCBufferLoad(llvm::Instruction *instruction) {
  580. DxilInst_CBufferLoad cbufferLoad(instruction);
  581. DxilInst_CBufferLoadLegacy cbufferLoadLegacy(instruction);
  582. return cbufferLoad || cbufferLoadLegacy;
  583. }
  584. unsigned int GetResolvedRangeID(DXIL::ResourceClass resClass, Value *rangeIdVal)
  585. {
  586. if (auto CI = dyn_cast<ConstantInt>(rangeIdVal))
  587. {
  588. return CI->getZExtValue();
  589. }
  590. else
  591. {
  592. assert(false);
  593. return 0;
  594. }
  595. }
  596. // TODO: This code is quite inefficient
  597. bool DxilPatchShaderRecordBindings::GetHandleInfo(
  598. Module &M,
  599. DxilInst_CreateHandleForLib &createHandleStructForLib,
  600. _Out_ unsigned int &shaderRegister,
  601. _Out_ unsigned int &registerSpace,
  602. _Out_ DXIL::ResourceKind &kind,
  603. _Out_ DXIL::ResourceClass &resClass,
  604. _Out_ llvm::Type *&resType)
  605. {
  606. DxilModule &DM = M.GetOrCreateDxilModule();
  607. LoadInst *loadRangeId = cast<LoadInst>(createHandleStructForLib.get_Resource());
  608. Value *ResourceSymbol = loadRangeId->getPointerOperand();
  609. DXIL::ResourceClass resourceClasses[] = {
  610. DXIL::ResourceClass::CBuffer,
  611. DXIL::ResourceClass::SRV,
  612. DXIL::ResourceClass::UAV,
  613. DXIL::ResourceClass::Sampler
  614. };
  615. hlsl::DxilResourceBase *Resource = nullptr;
  616. for (auto &resourceClass : resourceClasses) {
  617. switch (resourceClass)
  618. {
  619. case DXIL::ResourceClass::CBuffer:
  620. {
  621. auto &cbuffers = DM.GetCBuffers();
  622. for (auto &cbuffer : cbuffers)
  623. {
  624. if (cbuffer->GetGlobalSymbol() == ResourceSymbol)
  625. {
  626. Resource = cbuffer.get();
  627. break;
  628. }
  629. }
  630. break;
  631. }
  632. case DXIL::ResourceClass::SRV:
  633. case DXIL::ResourceClass::UAV:
  634. {
  635. auto &viewList = resourceClass == DXIL::ResourceClass::SRV ? DM.GetSRVs() : DM.GetUAVs();
  636. for (auto &view : viewList)
  637. {
  638. if (view->GetGlobalSymbol() == ResourceSymbol)
  639. {
  640. Resource = view.get();
  641. break;
  642. }
  643. }
  644. break;
  645. }
  646. case DXIL::ResourceClass::Sampler:
  647. {
  648. auto &samplers = DM.GetSamplers();
  649. for (auto &sampler : samplers)
  650. {
  651. if (sampler->GetGlobalSymbol() == ResourceSymbol)
  652. {
  653. Resource = sampler.get();
  654. break;
  655. }
  656. }
  657. break;
  658. }
  659. }
  660. }
  661. if (Resource)
  662. {
  663. registerSpace = Resource->GetSpaceID();
  664. shaderRegister = Resource->GetLowerBound();
  665. kind = Resource->GetKind();
  666. resClass = Resource->GetClass();
  667. resType = Resource->GetHLSLType()->getPointerElementType();
  668. }
  669. return Resource != nullptr;
  670. }
  671. llvm::Value *DxilPatchShaderRecordBindings::LoadShaderRecordData(
  672. Module &M,
  673. IRBuilder<> &Builder,
  674. llvm::Value *offsetToShaderRecord,
  675. unsigned int dataOffsetInShaderRecord)
  676. {
  677. DxilModule &DM = M.GetOrCreateDxilModule();
  678. LLVMContext &Ctx = M.getContext();
  679. OP *HlslOP = DM.GetOP();
  680. Constant *dataOffset =
  681. HlslOP->GetU32Const(dataOffsetInShaderRecord);
  682. Value *shaderTableOffsetToData = Builder.CreateAdd(dataOffset, offsetToShaderRecord);
  683. return CreateShaderRecordBufferLoad(M, Builder, shaderTableOffsetToData,
  684. llvm::Type::getInt32Ty(Ctx));
  685. }
  686. void DxilPatchShaderRecordBindings::PatchCreateHandleToUseDescriptorIndex(
  687. _In_ Module &M,
  688. _In_ IRBuilder<> &Builder,
  689. _In_ DXIL::ResourceKind &resourceKind,
  690. _In_ DXIL::ResourceClass &resourceClass,
  691. _In_ llvm::Type *resourceType,
  692. _In_ llvm::Value *descriptorIndex,
  693. _Inout_ DxilInst_CreateHandleForLib &createHandleInstr)
  694. {
  695. DxilModule &DM = M.GetOrCreateDxilModule();
  696. OP *HlslOP = DM.GetOP();
  697. llvm::Value *descriptorHeapSymbol = GetAliasedDescriptorHeapHandle(M, resourceType, resourceClass, resourceKind);
  698. llvm::Value *viewSymbol = Builder.CreateGEP(descriptorHeapSymbol, { HlslOP->GetU32Const(0), descriptorIndex }, "IndexIntoDH");
  699. DxilMDHelper::MarkNonUniform(cast<Instruction>(viewSymbol));
  700. llvm::Value *handle = Builder.CreateLoad(viewSymbol);
  701. auto callInst = cast<CallInst>(createHandleInstr.Instr);
  702. callInst->setCalledFunction(HlslOP->GetOpFunc(
  703. DXIL::OpCode::CreateHandleForLib,
  704. handle->getType()));
  705. createHandleInstr.set_Resource(handle);
  706. }
  707. void DxilPatchShaderRecordBindings::InitializeViewTable() {
  708. // The Fallback Layer declares a bindless raw buffer that spans the entire descriptor heap,
  709. // manually add it to the list of UAV register spaces used
  710. if (*pInputShaderInfo->pNumUAVSpaces == 0)
  711. {
  712. ViewKey key = { (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer, {0} };
  713. unsigned int index = FindOrInsertViewIntoList(
  714. key,
  715. pInputShaderInfo->pUAVRegisterSpaceArray,
  716. *pInputShaderInfo->pNumUAVSpaces,
  717. FallbackLayerNumDescriptorHeapSpacesPerView);
  718. (void)index;
  719. assert(index == 0);
  720. }
  721. }
  722. void DxilPatchShaderRecordBindings::PatchShaderBindings(Module &M) {
  723. DxilModule &DM = M.GetOrCreateDxilModule();
  724. OP *HlslOP = DM.GetOP();
  725. // Don't erase instructions until the very end because it throws off the iterator
  726. std::vector<llvm::Instruction *> instructionsToRemove;
  727. for (BasicBlock &block : EntryPointFunction->getBasicBlockList()) {
  728. auto & Instructions = block.getInstList();
  729. for (auto &instr : Instructions) {
  730. DxilInst_CreateHandleForLib createHandleForLib(&instr);
  731. if (createHandleForLib) {
  732. DXIL::ResourceClass resourceClass;
  733. unsigned int registerSpace;
  734. unsigned int registerIndex;
  735. DXIL::ResourceKind kind;
  736. llvm::Type *resType;
  737. bool resourceIsResolved = true;
  738. resourceIsResolved = GetHandleInfo(M, createHandleForLib, registerIndex, registerSpace, kind, resourceClass, resType);
  739. if (!resourceIsResolved) continue; // TODO: This shouldn't actually be happening?
  740. ShaderRecordEntry shaderRecord = FindRootSignatureDescriptor(
  741. *pRootSignatureDesc,
  742. pInputShaderInfo->ShaderRecordIdentifierSizeInBytes,
  743. resourceClass,
  744. registerIndex,
  745. registerSpace);
  746. const bool IsBindingSpecifiedInLocalRootSignature = !shaderRecord.IsInvalid();
  747. if (IsBindingSpecifiedInLocalRootSignature) {
  748. if (!DispatchRaysConstantsHandle) {
  749. AddInputBinding(M);
  750. }
  751. switch (shaderRecord.ParameterType) {
  752. case DxilRootParameterType::Constants32Bit:
  753. {
  754. for (User *U : instr.users()) {
  755. llvm::Instruction *instruction = cast<CallInst>(U);
  756. if (IsCBufferLoad(instruction)) {
  757. llvm::Instruction *cbufferLoadInstr = instruction;
  758. IRBuilder<> Builder(cbufferLoadInstr);
  759. llvm::Value * cbufferOffsetInBytes = CreateCBufferLoadOffsetInBytes(M, Builder, cbufferLoadInstr);
  760. llvm::Value *LocalOffsetToRootConstant = CreateOffsetToShaderRecord(M, Builder, shaderRecord.RecordOffsetInBytes, cbufferOffsetInBytes);
  761. llvm::Value *GlobalOffsetToRootConstant = Builder.CreateAdd(LocalOffsetToRootConstant, BaseShaderRecordOffset);
  762. llvm::Value *srvBufferLoad = CreateShaderRecordBufferLoad(M, Builder, GlobalOffsetToRootConstant, cbufferLoadInstr->getType());
  763. ReplaceUsesOfWith(cbufferLoadInstr, srvBufferLoad);
  764. } else {
  765. ThrowFailure();
  766. }
  767. }
  768. instructionsToRemove.push_back(&instr);
  769. break;
  770. }
  771. case DxilRootParameterType::DescriptorTable:
  772. {
  773. IRBuilder<> Builder(&instr);
  774. llvm::Value *srvBufferLoad = LoadShaderRecordData(
  775. M,
  776. Builder,
  777. BaseShaderRecordOffset,
  778. shaderRecord.RecordOffsetInBytes);
  779. llvm::Value *DescriptorTableEntryLo = Builder.CreateExtractValue(srvBufferLoad, 0, "DescriptorTableHandleLo");
  780. unsigned int offsetToLoadInUints = offsetof(DispatchRaysConstants, SrvCbvUavDescriptorHeapStart) / sizeof(uint32_t);
  781. unsigned int uintsPerRow = 4;
  782. unsigned int rowToLoad = offsetToLoadInUints / uintsPerRow;
  783. unsigned int extractValueOffset = offsetToLoadInUints % uintsPerRow;
  784. llvm::Value *DescHeapConstants = CreateCBufferLoadLegacy(M, Builder, DispatchRaysConstantsHandle, rowToLoad);
  785. llvm::Value *DescriptorHeapStartAddressLo = Builder.CreateExtractValue(DescHeapConstants, extractValueOffset, "DescriptorHeapStartHandleLo");
  786. // TODO: The hi bits can only be ignored if the difference is guaranteed to be < 32 bytes. This is an unsafe assumption, particularly given
  787. // large descriptor sizes
  788. llvm::Value *DescriptorTableOffsetInBytes = Builder.CreateSub(DescriptorTableEntryLo, DescriptorHeapStartAddressLo, "TableOffsetInBytes");
  789. Constant *DescriptorSizeInBytes = HlslOP->GetU32Const(pInputShaderInfo->SrvCbvUavDescriptorSizeInBytes);
  790. llvm::Value * DescriptorTableStartIndex = Builder.CreateExactUDiv(DescriptorTableOffsetInBytes, DescriptorSizeInBytes, "TableStartIndex");
  791. Constant *RecordOffset = HlslOP->GetU32Const(shaderRecord.OffsetInDescriptors);
  792. llvm::Value * BaseDescriptorIndex = Builder.CreateAdd(DescriptorTableStartIndex, RecordOffset, "BaseDescriptorIndex");
  793. // TODO: Not supporting dynamic indexing yet, should be pulled from CreateHandleForLib
  794. // If dynamic indexing is being used, add the apps index on top of the calculated index
  795. llvm::Value * DynamicIndex = HlslOP->GetU32Const(0);
  796. llvm::Value * DescriptorIndex = Builder.CreateAdd(BaseDescriptorIndex, DynamicIndex, "DescriptorIndex");
  797. PatchCreateHandleToUseDescriptorIndex(
  798. M,
  799. Builder,
  800. kind,
  801. resourceClass,
  802. resType,
  803. DescriptorIndex,
  804. createHandleForLib);
  805. break;
  806. }
  807. case DxilRootParameterType::CBV:
  808. case DxilRootParameterType::SRV:
  809. case DxilRootParameterType::UAV: {
  810. IRBuilder<> Builder(&instr);
  811. llvm::Value *srvBufferLoad = LoadShaderRecordData(
  812. M,
  813. Builder,
  814. BaseShaderRecordOffset,
  815. shaderRecord.RecordOffsetInBytes);
  816. llvm::Value *DescriptorIndex = Builder.CreateExtractValue(
  817. srvBufferLoad, 1, "DescriptorHeapIndex");
  818. // TODO: Handle offset in bytes
  819. // llvm::Value *OffsetInBytes = Builder.CreateExtractValue(
  820. // srvBufferLoad, 0, "OffsetInBytes");
  821. PatchCreateHandleToUseDescriptorIndex(
  822. M,
  823. Builder,
  824. kind,
  825. resourceClass,
  826. resType,
  827. DescriptorIndex,
  828. createHandleForLib);
  829. break;
  830. }
  831. default:
  832. ThrowFailure();
  833. break;
  834. }
  835. }
  836. }
  837. }
  838. }
  839. for (auto instruction : instructionsToRemove) {
  840. instruction->eraseFromParent();
  841. }
  842. }
  843. bool IsParameterTypeCompatibleWithResourceClass(
  844. DXIL::ResourceClass resourceClass,
  845. DxilRootParameterType parameterType) {
  846. switch (parameterType) {
  847. case DxilRootParameterType::DescriptorTable:
  848. return true;
  849. case DxilRootParameterType::Constants32Bit:
  850. case DxilRootParameterType::CBV:
  851. return resourceClass == DXIL::ResourceClass::CBuffer;
  852. case DxilRootParameterType::SRV:
  853. return resourceClass == DXIL::ResourceClass::SRV;
  854. case DxilRootParameterType::UAV:
  855. return resourceClass == DXIL::ResourceClass::UAV;
  856. default:
  857. ThrowFailure();
  858. return false;
  859. }
  860. }
  861. DxilRootParameterType ConvertD3D12ParameterTypeToDxil(DxilRootParameterType parameter) {
  862. switch (parameter) {
  863. case DxilRootParameterType::Constants32Bit:
  864. return DxilRootParameterType::Constants32Bit;
  865. case DxilRootParameterType::DescriptorTable:
  866. return DxilRootParameterType::DescriptorTable;
  867. case DxilRootParameterType::CBV:
  868. return DxilRootParameterType::CBV;
  869. case DxilRootParameterType::SRV:
  870. return DxilRootParameterType::SRV;
  871. case DxilRootParameterType::UAV:
  872. return DxilRootParameterType::UAV;
  873. }
  874. assert(false);
  875. return (DxilRootParameterType)-1;
  876. }
  877. DXIL::ResourceClass ConvertD3D12RangeTypeToDxil(DxilDescriptorRangeType rangeType) {
  878. switch (rangeType) {
  879. case DxilDescriptorRangeType::SRV:
  880. return DXIL::ResourceClass::SRV;
  881. case DxilDescriptorRangeType::UAV:
  882. return DXIL::ResourceClass::UAV;
  883. case DxilDescriptorRangeType::CBV:
  884. return DXIL::ResourceClass::CBuffer;
  885. case DxilDescriptorRangeType::Sampler:
  886. return DXIL::ResourceClass::Sampler;
  887. }
  888. assert(false);
  889. return (DXIL::ResourceClass) - 1;
  890. }
  891. unsigned int GetParameterTypeAlignment(DxilRootParameterType parameterType) {
  892. switch (parameterType) {
  893. case DxilRootParameterType::DescriptorTable:
  894. return SizeofD3D12GpuDescriptorHandle;
  895. case DxilRootParameterType::Constants32Bit:
  896. return sizeof(uint32_t);
  897. case DxilRootParameterType::CBV: // fallthrough
  898. case DxilRootParameterType::SRV: // fallthrough
  899. case DxilRootParameterType::UAV:
  900. return SizeofD3D12GpuVA;
  901. default:
  902. return UINT_MAX;
  903. }
  904. }
  905. template <typename TD3D12_ROOT_SIGNATURE_DESC>
  906. ShaderRecordEntry FindRootSignatureDescriptorHelper(
  907. const TD3D12_ROOT_SIGNATURE_DESC &rootSignatureDescriptor,
  908. unsigned int ShaderRecordIdentifierSizeInBytes,
  909. DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex,
  910. unsigned int registerSpace) {
  911. // Automatically fail if it's looking for a fallback binding as these never
  912. // need to be patched
  913. if (registerSpace != FallbackLayerRegisterSpace) {
  914. unsigned int recordOffset = ShaderRecordIdentifierSizeInBytes;
  915. for (unsigned int rootParamIndex = 0;
  916. rootParamIndex < rootSignatureDescriptor.NumParameters;
  917. rootParamIndex++) {
  918. auto &rootParam = rootSignatureDescriptor.pParameters[rootParamIndex];
  919. auto dxilParamType =
  920. ConvertD3D12ParameterTypeToDxil(rootParam.ParameterType);
  921. #define ALIGN(alignment, num) (((num + alignment - 1) / alignment) * alignment)
  922. recordOffset = ALIGN(GetParameterTypeAlignment(rootParam.ParameterType),
  923. recordOffset);
  924. switch (rootParam.ParameterType) {
  925. case DxilRootParameterType::Constants32Bit:
  926. if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
  927. dxilParamType) &&
  928. baseRegisterIndex == rootParam.Constants.ShaderRegister &&
  929. registerSpace == rootParam.Constants.RegisterSpace) {
  930. return {dxilParamType, recordOffset, 0};
  931. }
  932. recordOffset += rootParam.Constants.Num32BitValues * sizeof(uint32_t);
  933. break;
  934. case DxilRootParameterType::DescriptorTable: {
  935. auto &descriptorTable = rootParam.DescriptorTable;
  936. unsigned int rangeOffsetInDescriptors = 0;
  937. for (unsigned int rangeIndex = 0;
  938. rangeIndex < descriptorTable.NumDescriptorRanges; rangeIndex++) {
  939. auto &range = descriptorTable.pDescriptorRanges[rangeIndex];
  940. if (range.OffsetInDescriptorsFromTableStart != (unsigned)-1) {
  941. rangeOffsetInDescriptors = range.OffsetInDescriptorsFromTableStart;
  942. }
  943. if (ConvertD3D12RangeTypeToDxil(range.RangeType) == resourceClass &&
  944. range.RegisterSpace == registerSpace &&
  945. range.BaseShaderRegister <= baseRegisterIndex &&
  946. range.BaseShaderRegister + range.NumDescriptors >
  947. baseRegisterIndex) {
  948. rangeOffsetInDescriptors +=
  949. baseRegisterIndex - range.BaseShaderRegister;
  950. return {dxilParamType, recordOffset, rangeOffsetInDescriptors};
  951. }
  952. rangeOffsetInDescriptors += range.NumDescriptors;
  953. }
  954. recordOffset += SizeofD3D12GpuDescriptorHandle;
  955. break;
  956. }
  957. case DxilRootParameterType::CBV:
  958. case DxilRootParameterType::SRV:
  959. case DxilRootParameterType::UAV:
  960. if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
  961. dxilParamType) &&
  962. baseRegisterIndex == rootParam.Descriptor.ShaderRegister &&
  963. registerSpace == rootParam.Descriptor.RegisterSpace) {
  964. return {dxilParamType, recordOffset, 0};
  965. }
  966. recordOffset += SizeofD3D12GpuVA;
  967. break;
  968. }
  969. }
  970. }
  971. return ShaderRecordEntry::InvalidEntry();
  972. }
  973. // TODO: Consider pre-calculating this into a map
  974. ShaderRecordEntry DxilPatchShaderRecordBindings::FindRootSignatureDescriptor(
  975. const DxilVersionedRootSignatureDesc &rootSignatureDescriptor,
  976. unsigned int ShaderRecordIdentifierSizeInBytes,
  977. DXIL::ResourceClass resourceClass,
  978. unsigned int baseRegisterIndex,
  979. unsigned int registerSpace) {
  980. switch (rootSignatureDescriptor.Version) {
  981. case DxilRootSignatureVersion::Version_1_0:
  982. return FindRootSignatureDescriptorHelper(rootSignatureDescriptor.Desc_1_0, ShaderRecordIdentifierSizeInBytes, resourceClass, baseRegisterIndex, registerSpace);
  983. case DxilRootSignatureVersion::Version_1_1:
  984. return FindRootSignatureDescriptorHelper(rootSignatureDescriptor.Desc_1_1, ShaderRecordIdentifierSizeInBytes, resourceClass, baseRegisterIndex, registerSpace);
  985. default:
  986. ThrowFailure();
  987. return ShaderRecordEntry::InvalidEntry();
  988. }
  989. }