12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148 |
- ///////////////////////////////////////////////////////////////////////////////
- // //
- // DxilPatchShaderRecordBindings.cpp //
- // Copyright (C) Microsoft Corporation. All rights reserved. //
- // This file is distributed under the University of Illinois Open Source //
- // License. See LICENSE.TXT for details. //
- // //
- // Provides a pass used by the Raytracing Fallback Layer to modify //
- // bindings so that local root signature parameters are pulled from a //
- // global "shader table" buffer instead //
- // //
- ///////////////////////////////////////////////////////////////////////////////
- #include "dxc/HLSL/DxilGenerationPass.h"
- #include "dxc/HLSL/DxilFallbackLayerPass.h"
- #include "dxc/DXIL/DxilOperations.h"
- #include "dxc/DXIL/DxilSignatureElement.h"
- #include "dxc/DXIL/DxilFunctionProps.h"
- #include "dxc/DXIL/DxilModule.h"
- #include "dxc/Support/Global.h"
- #include "dxc/Support/Unicode.h"
- #include "dxc/DXIL/DxilTypeSystem.h"
- #include "dxc/DXIL/DxilConstants.h"
- #include "dxc/DXIL/DxilInstructions.h"
- #include "dxc/HLSL/DxilSpanAllocator.h"
- #include "dxc/DxilRootSignature/DxilRootSignature.h"
- #include "dxc/DXIL/DxilUtil.h"
- #include "llvm/Transforms/Utils/Cloning.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/IntrinsicInst.h"
- #include "llvm/IR/InstIterator.h"
- #include "llvm/IR/Module.h"
- #include "llvm/IR/PassManager.h"
- #include "llvm/ADT/BitVector.h"
- #include "llvm/Pass.h"
- #include "llvm/Transforms/Utils/Local.h"
- #include "llvm/Transforms/Scalar.h"
- #include <memory>
- #include <unordered_set>
- #include <functional>
- #include <unordered_map>
- #include <array>
- struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
- #include "DxilPatchShaderRecordBindingsShared.h"
- using namespace llvm;
- using namespace hlsl;
- bool operator==(const ViewKey &a, const ViewKey &b) {
- return memcmp(&a, &b, sizeof(a)) == 0;
- }
// Byte sizes of a D3D12 GPU virtual address and of a GPU descriptor handle;
// both are stored as 64-bit values.
constexpr size_t SizeofD3D12GpuVA = sizeof(uint64_t);
constexpr size_t SizeofD3D12GpuDescriptorHandle = sizeof(uint64_t);
- Function *CloneFunction(Function *Orig,
- const llvm::Twine &Name,
- llvm::Module *llvmModule) {
- Function *F = Function::Create(Orig->getFunctionType(),
- GlobalValue::LinkageTypes::ExternalLinkage,
- Name, llvmModule);
- SmallVector<ReturnInst *, 2> Returns;
- ValueToValueMapTy vmap;
- // Map params.
- auto entryParamIt = F->arg_begin();
- for (Argument ¶m : Orig->args()) {
- vmap[¶m] = (entryParamIt++);
- }
- DxilModule &DM = llvmModule->GetOrCreateDxilModule();
- llvm::CloneFunctionInto(F, Orig, vmap, /*ModuleLevelChagnes*/ false, Returns);
- DM.GetTypeSystem().CopyFunctionAnnotation(F, Orig, DM.GetTypeSystem());
- if (DM.HasDxilFunctionProps(F)) {
- DM.CloneDxilEntryProps(Orig, F);
- }
- return F;
- }
- struct ShaderRecordEntry {
- DxilRootParameterType ParameterType;
- unsigned int RecordOffsetInBytes;
- unsigned int OffsetInDescriptors; // Only valid for descriptor tables
- static ShaderRecordEntry InvalidEntry() { return { (DxilRootParameterType)-1, (unsigned int)-1, 0 }; }
- bool IsInvalid() { return (unsigned int)ParameterType == (unsigned int)-1; }
- };
- struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
- class DxilPatchShaderRecordBindings : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit DxilPatchShaderRecordBindings() : ModulePass(ID) {}
- const char *getPassName() const override { return "DXIL Patch Shader Record Binding"; }
- void applyOptions(PassOptions O) override;
- bool runOnModule(Module &M) override;
- private:
- void ValidateParameters();
- void AddInputBinding(Module &M);
- void PatchShaderBindings(Module &M);
- void InitializeViewTable();
- unsigned int AddSRVRawBuffer(Module &M, unsigned int registerIndex, unsigned int registerSpace, const std::string &bufferName);
- unsigned int AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type = nullptr, unsigned int constantBufferSize = 0);
- unsigned int AddAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type);
- unsigned int AddCBufferAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, const std::string &bufferName);
- llvm::Value *CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes, llvm::Value *CbufferOffsetInBytes);
- llvm::Value *CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes, llvm::Type* type);
- llvm::Value *CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction);
- llvm::Value *CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle, unsigned int RowToLoad = 0);
- llvm::Value *LoadShaderRecordData(Module &M, IRBuilder<> &Builder,
- llvm::Value *offsetToShaderRecord,
- unsigned int dataOffsetInShaderRecord);
- void PatchCreateHandleToUseDescriptorIndex(
- _In_ Module &M,
- _In_ IRBuilder<> &Builder,
- _In_ DXIL::ResourceKind &resourceKind,
- _In_ DXIL::ResourceClass &resourceClass,
- _In_ llvm::Type *resourceType,
- _In_ llvm::Value *descriptorIndex,
- _Inout_ DxilInst_CreateHandleForLib &createHandleInstr);
- bool GetHandleInfo(
- Module &M,
- DxilInst_CreateHandleForLib &createHandleStructForLib,
- _Out_ unsigned int &shaderRegister,
- _Out_ unsigned int ®isterSpace,
- _Out_ DXIL::ResourceKind &kind,
- _Out_ DXIL::ResourceClass &resClass,
- _Out_ llvm::Type *&resType);
- llvm::Value * GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind);
- unsigned int GetConstantBufferOffsetToShaderRecord();
- bool IsCBufferLoad(llvm::Instruction *instruction);
- // Unlike the LLVM version of this function, this does not requires the InstructionToReplace and the ValueToReplaceWith to be the same instruction type
- static void ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith);
- static ShaderRecordEntry FindRootSignatureDescriptor(const DxilVersionedRootSignatureDesc &rootSignatureDescriptor, unsigned int ShaderRecordIdentifierSizeInBytes, DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex, unsigned int registerSpace);
- // TODO: I would like to see these prefixed with m_
- llvm::Value *ShaderTableHandle = nullptr;
- llvm::Value *DispatchRaysConstantsHandle = nullptr;
- llvm::Value *BaseShaderRecordOffset = nullptr;
- static const unsigned int NumViewTypes = 4;
- struct ViewKeyHasher
- {
- public:
- std::size_t operator()(const ViewKey &x) const {
- return std::hash<unsigned int>()((unsigned int)x.ViewType) ^
- std::hash<unsigned int>()((unsigned int)x.StructuredStride);
- }
- };
- std::unordered_map<ViewKey, llvm::Value *, ViewKeyHasher>
- TypeToAliasedDescriptorHeap[NumViewTypes];
- llvm::Function *EntryPointFunction;
- ShaderInfo *pInputShaderInfo;
- const DxilVersionedRootSignatureDesc *pRootSignatureDesc;
- DXIL::ShaderKind ShaderKind;
- };
- char DxilPatchShaderRecordBindings::ID = 0;
// Reports an unrecoverable failure in this pass by throwing std::exception.
// Marked [[noreturn]] so the compiler knows control never continues past a
// call, making the placeholder returns that follow calls unnecessary.
// TODO: Find the right thing to do on failure
[[noreturn]] void ThrowFailure() {
  throw std::exception();
}
- // TODO: Stolen from Brandon's code, merge
- // Remove ELF mangling
- static inline std::string GetUnmangledName(StringRef name) {
- if (!name.startswith("\x1?"))
- return name;
- size_t pos = name.find("@@");
- if (pos == name.npos)
- return name;
- return name.substr(2, pos - 2);
- }
- static Function* getFunctionFromName(Module &M, const std::wstring& exportName) {
- for (auto F = M.begin(), E = M.end(); F != E; ++F) {
- std::wstring functionName = Unicode::UTF8ToUTF16StringOrThrow(GetUnmangledName(F->getName()).c_str());
- if (exportName == functionName) {
- return F;
- }
- }
- return nullptr;
- }
- ModulePass *llvm::createDxilPatchShaderRecordBindingsPass() {
- return new DxilPatchShaderRecordBindings();
- }
- INITIALIZE_PASS(DxilPatchShaderRecordBindings, "hlsl-dxil-patch-shader-record-bindings", "Patch shader record bindings to instead pull from the fallback provided bindings", false, false)
- void DxilPatchShaderRecordBindings::applyOptions(PassOptions O) {
- for (const auto & option : O) {
- if (0 == option.first.compare("root-signature")) {
- unsigned int cHexRadix = 16;
- pInputShaderInfo = (ShaderInfo*)strtoull(option.second.data(), nullptr, cHexRadix);
- pRootSignatureDesc = (const DxilVersionedRootSignatureDesc*)pInputShaderInfo->pRootSignatureDesc;
- }
- }
- }
- void AddAnnoationsIfNeeded(DxilModule &DM, llvm::StructType *StructTy, const std::string &FieldName, unsigned int numFields = 1)
- {
- auto pAnnotation = DM.GetTypeSystem().GetStructAnnotation(StructTy);
- if (pAnnotation == nullptr)
- {
- pAnnotation = DM.GetTypeSystem().AddStructAnnotation(StructTy);
- pAnnotation->SetCBufferSize(sizeof(uint32_t) * numFields);
- for (unsigned int i = 0; i < numFields; i++)
- {
- pAnnotation->GetFieldAnnotation(i).SetCBufferOffset(sizeof(uint32_t) * i);
- pAnnotation->GetFieldAnnotation(i).SetCompType(hlsl::DXIL::ComponentType::I32);
- pAnnotation->GetFieldAnnotation(i).SetFieldName(FieldName + std::to_string(i));
- }
- }
- }
- unsigned int DxilPatchShaderRecordBindings::AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type, unsigned int constantBufferSize) {
- LLVMContext & Ctx = M.getContext();
- DxilModule &DM = M.GetOrCreateDxilModule();
- // Set up a SRV with byte address buffer
- unsigned int resourceHandle;
- std::unique_ptr<DxilResource> pHandle;
- std::unique_ptr<DxilCBuffer> pCBuf;
- std::unique_ptr<DxilSampler> pSampler;
- DxilResourceBase *pBaseHandle;
- switch (resClass) {
- case DXIL::ResourceClass::SRV:
- resourceHandle = static_cast<unsigned int>(DM.GetSRVs().size());
- pHandle = llvm::make_unique<DxilResource>();
- pHandle->SetRW(false);
- pBaseHandle = pHandle.get();
- break;
- case DXIL::ResourceClass::UAV:
- resourceHandle = static_cast<unsigned int>(DM.GetUAVs().size());
- pHandle = llvm::make_unique<DxilResource>();
- pHandle->SetRW(true);
- pBaseHandle = pHandle.get();
- break;
- case DXIL::ResourceClass::CBuffer:
- resourceHandle = static_cast<unsigned int>(DM.GetCBuffers().size());
- pCBuf = llvm::make_unique<DxilCBuffer>();
- pCBuf->SetSize(constantBufferSize);
- pBaseHandle = pCBuf.get();
- break;
- case DXIL::ResourceClass::Sampler:
- resourceHandle = static_cast<unsigned int>(DM.GetSamplers().size());
- pSampler = llvm::make_unique<DxilSampler>();
- // TODO: Is this okay? What if one of the samplers in the table is a comparison sampler?
- pSampler->SetSamplerKind(DxilSampler::SamplerKind::Default);
- pBaseHandle = pSampler.get();
- break;
- }
- if (!type) {
- SmallVector<llvm::Type*, 1> Elements{ Type::getInt32Ty(Ctx) };
- std::string ByteAddressBufferName = "struct.ByteAddressBuffer";
- type = M.getTypeByName(ByteAddressBufferName);
- if (!type)
- {
- StructType *StructTy;
- type = StructTy = StructType::create(Elements, ByteAddressBufferName);
-
- AddAnnoationsIfNeeded(DM, StructTy, ByteAddressBufferName);
- }
- }
- GlobalVariable *GV = M.getGlobalVariable(bufferName);
- if (!GV) {
- GV = cast<GlobalVariable>(M.getOrInsertGlobal(bufferName, type));
- }
- pBaseHandle->SetGlobalName(bufferName.c_str());
- pBaseHandle->SetGlobalSymbol(GV);
- pBaseHandle->SetID(resourceHandle);
- pBaseHandle->SetSpaceID(registerSpace);
- pBaseHandle->SetLowerBound(baseRegisterIndex);
- pBaseHandle->SetRangeSize(rangeSize);
- pBaseHandle->SetKind(resKind);
- if (pHandle) {
- pHandle->SetGloballyCoherent(false);
- pHandle->SetHasCounter(false);
- pHandle->SetCompType(CompType::getF32()); // TODO: Need to handle all types
- }
- unsigned int ID;
- switch (resClass) {
- case DXIL::ResourceClass::SRV:
- ID = DM.AddSRV(std::move(pHandle));
- break;
- case DXIL::ResourceClass::UAV:
- ID = DM.AddUAV(std::move(pHandle));
- break;
- case DXIL::ResourceClass::CBuffer:
- ID = DM.AddCBuffer(std::move(pCBuf));
- break;
- case DXIL::ResourceClass::Sampler:
- ID = DM.AddSampler(std::move(pSampler));
- break;
- }
- assert(ID == resourceHandle);
- return ID;
- }
- unsigned int DxilPatchShaderRecordBindings::GetConstantBufferOffsetToShaderRecord()
- {
- switch (ShaderKind)
- {
- case DXIL::ShaderKind::ClosestHit:
- case DXIL::ShaderKind::AnyHit:
- case DXIL::ShaderKind::Intersection:
- return offsetof(DispatchRaysConstants, HitGroupShaderRecordStride);
- case DXIL::ShaderKind::Miss:
- return offsetof(DispatchRaysConstants, MissShaderRecordStride);
- default:
- ThrowFailure();
- return -1;
- }
- }
- unsigned int DxilPatchShaderRecordBindings::AddSRVRawBuffer(Module &M, unsigned int registerIndex, unsigned int registerSpace, const std::string &bufferName) {
- return AddHandle(M, registerIndex, 1, registerSpace, DXIL::ResourceClass::SRV, DXIL::ResourceKind::RawBuffer, bufferName);
- }
- llvm::Constant *GetArraySymbol(Module &M, const std::string &bufferName) {
- LLVMContext & Ctx = M.getContext();
- SmallVector<llvm::Type*, 1> Elements{ Type::getInt32Ty(Ctx) };
- llvm::StructType *StructTy = llvm::StructType::create(Elements, bufferName);
- llvm::ArrayType *ArrayTy = ArrayType::get(StructTy, -1);
- return UndefValue::get(ArrayTy->getPointerTo());
- }
- unsigned int DxilPatchShaderRecordBindings::AddCBufferAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, const std::string &bufferName) {
- const unsigned int maxConstantBufferSize = 4096 * 16;
- return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace, DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer, bufferName, GetArraySymbol(M, bufferName)->getType(), maxConstantBufferSize);
- }
- unsigned int DxilPatchShaderRecordBindings::AddAliasedHandle(Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type) {
- return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace, resClass, resKind, bufferName, type);
- }
- // TODO: Stolen from Brandon's code
- DXIL::ShaderKind GetRayShaderKindCopy(Function* F)
- {
- if (F->hasFnAttribute("exp-shader"))
- return DXIL::ShaderKind::RayGeneration;
- DxilModule& DM = F->getParent()->GetDxilModule();
- if (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay())
- return DM.GetDxilFunctionProps(F).shaderKind;
- return DXIL::ShaderKind::Invalid;
- }
- bool DxilPatchShaderRecordBindings::runOnModule(Module &M) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- EntryPointFunction = pInputShaderInfo->ExportName ? getFunctionFromName(M, pInputShaderInfo->ExportName) : DM.GetEntryFunction();
- ShaderKind = GetRayShaderKindCopy(EntryPointFunction);
- ValidateParameters();
- InitializeViewTable();
- PatchShaderBindings(M);
- DM.ReEmitDxilResources();
- return true;
- }
- void DxilPatchShaderRecordBindings::ValidateParameters() {
- if (!pInputShaderInfo || !pInputShaderInfo->pRootSignatureDesc) {
- throw std::exception();
- }
- }
- DxilResourceBase &GetResourceFromID(DxilModule &DM, DXIL::ResourceClass resClass, unsigned int id)
- {
- switch (resClass)
- {
- case DXIL::ResourceClass::CBuffer:
- return DM.GetCBuffer(id);
- break;
- case DXIL::ResourceClass::SRV:
- return DM.GetSRV(id);
- break;
- case DXIL::ResourceClass::UAV:
- return DM.GetUAV(id);
- break;
- case DXIL::ResourceClass::Sampler:
- return DM.GetSampler(id);
- break;
- default:
- ThrowFailure();
- llvm_unreachable("invalid resource class");
- }
- }
- unsigned int FindOrInsertViewIntoList(const ViewKey &key, ViewKey *pViewList, unsigned int &numViews, unsigned int maxViews)
- {
- unsigned int viewIndex = 0;
- for (; viewIndex < numViews; viewIndex++)
- {
- if (pViewList[viewIndex] == key)
- {
- break;
- }
- }
- if (viewIndex == numViews)
- {
- if (viewIndex >= maxViews) {
- ThrowFailure();
- }
- pViewList[viewIndex] = key;
- numViews++;
- }
- return viewIndex;
- }
- llvm::Value *DxilPatchShaderRecordBindings::GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *type, DXIL::ResourceClass resClass, DXIL::ResourceKind resKind)
- {
- DxilModule &DM = M.GetOrCreateDxilModule();
- unsigned int resClassIndex = (unsigned int)resClass;
-
- ViewKey key = {};
- key.ViewType = (unsigned int)resKind;
- if (DXIL::IsStructuredBuffer(resKind))
- {
- key.StructuredStride = type->getPrimitiveSizeInBits();
- } else if (resKind != DXIL::ResourceKind::RawBuffer)
- {
- auto containedType = type->getContainedType(0);
- // If it's a vector, get the type of just a single element
- if (containedType->getNumContainedTypes() > 0)
- {
- assert(containedType->getNumContainedTypes() <= 4);
- containedType = containedType->getContainedType(0);
- }
- key.SRVComponentType = (unsigned int)CompType::GetCompType(containedType).GetKind();
- }
- auto aliasedDescriptorHeapHandle = TypeToAliasedDescriptorHeap[resClassIndex].find(key);
- if (aliasedDescriptorHeapHandle == TypeToAliasedDescriptorHeap[resClassIndex].end())
- {
- unsigned int registerSpaceOffset = 0;
- std::string HandleName;
- if (resClass == DXIL::ResourceClass::SRV)
- {
- registerSpaceOffset = FindOrInsertViewIntoList(
- key,
- pInputShaderInfo->pSRVRegisterSpaceArray,
- *pInputShaderInfo->pNumSRVSpaces,
- FallbackLayerNumDescriptorHeapSpacesPerView);
- HandleName = std::string("SRVDescriptorHeapTable") +
- std::to_string(registerSpaceOffset);
- }
- else if (resClass == DXIL::ResourceClass::UAV)
- {
- registerSpaceOffset = FindOrInsertViewIntoList(
- key,
- pInputShaderInfo->pUAVRegisterSpaceArray,
- *pInputShaderInfo->pNumUAVSpaces,
- FallbackLayerNumDescriptorHeapSpacesPerView);
- if (registerSpaceOffset == 0)
- {
- // Using the descriptor heap declared by the fallback for handling emulated pointers,
- // make sure the name is an exact match
- assert(key.ViewType == (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer);
- HandleName = "\01?DescriptorHeapBufferTable@@3PAURWByteAddressBuffer@@A";
- }
- else
- {
- HandleName = std::string("UAVDescriptorHeapTable") +
- std::to_string(registerSpaceOffset);
- }
- }
- else if (resClass == DXIL::ResourceClass::CBuffer)
- {
- HandleName = std::string("CBVDescriptorHeapTable");
- } else {
- HandleName = std::string("SamplerDescriptorHeapTable");
- }
- llvm::ArrayType *descriptorHeapType = ArrayType::get(type, 0);
- unsigned int id = AddAliasedHandle(M, FallbackLayerDescriptorHeapTable, FallbackLayerRegisterSpace + FallbackLayerDescriptorHeapSpaceOffset + registerSpaceOffset, resClass, resKind, HandleName, descriptorHeapType);
-
- TypeToAliasedDescriptorHeap[resClassIndex][key] = GetResourceFromID(DM, resClass, id).GetGlobalSymbol();
- }
- return TypeToAliasedDescriptorHeap[resClassIndex][key];
- }
- void DxilPatchShaderRecordBindings::AddInputBinding(Module &M) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- auto & EntryBlock = EntryPointFunction->getEntryBlock();
- auto & Instructions = EntryBlock.getInstList();
- std::string bufferName;
- unsigned int bufferRegister;
- switch (ShaderKind) {
- case DXIL::ShaderKind::AnyHit:
- case DXIL::ShaderKind::ClosestHit:
- case DXIL::ShaderKind::Intersection:
- bufferRegister = FallbackLayerHitGroupRecordByteAddressBufferRegister;
- bufferName = "\01?HitGroupShaderTable@@3UByteAddressBuffer@@A";
- break;
- case DXIL::ShaderKind::Miss:
- bufferRegister = FallbackLayerMissShaderRecordByteAddressBufferRegister;
- bufferName = "\01?MissShaderTable@@3UByteAddressBuffer@@A";
- break;
- case DXIL::ShaderKind::RayGeneration:
- bufferRegister = FallbackLayerRayGenShaderRecordByteAddressBufferRegister;
- bufferName = "\01?RayGenShaderTable@@3UByteAddressBuffer@@A";
- break;
- case DXIL::ShaderKind::Callable:
- bufferRegister = FallbackLayerCallableShaderRecordByteAddressBufferRegister;
- bufferName = "\01?CallableShaderTable@@3UByteAddressBuffer@@A";
- break;
- }
- unsigned int ShaderRecordID = AddSRVRawBuffer(M, bufferRegister, FallbackLayerRegisterSpace, bufferName);
- auto It = Instructions.begin();
- OP *HlslOP = DM.GetOP();
- LLVMContext & Ctx = M.getContext();
- IRBuilder<> Builder(It);
- {
- auto ShaderTableName = "ShaderTableHandle";
- llvm::Value *Symbol = DM.GetSRV(ShaderRecordID).GetGlobalSymbol();
- llvm::Value *Load = Builder.CreateLoad(Symbol, "LoadShaderTableHandle");
- Function *CreateHandleForLib = HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
- Constant *CreateHandleOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
- ShaderTableHandle = Builder.CreateCall(CreateHandleForLib, { CreateHandleOpcodeArg, Load }, ShaderTableName);
- }
- {
- auto CbufferName = "Constants";
- const unsigned int sizeOfConstantsInBytes = sizeof(DispatchRaysConstants);
- llvm::StructType *StructTy= M.getTypeByName(CbufferName);
- if (!StructTy)
- {
- const unsigned int numUintsInConstants = sizeOfConstantsInBytes / sizeof(unsigned int);
- SmallVector<llvm::Type*, numUintsInConstants> Elements(numUintsInConstants);
- for (unsigned int i = 0; i < numUintsInConstants; i++)
- {
- Elements[i] = Type::getInt32Ty(Ctx);
- }
- StructTy = llvm::StructType::create(Elements, CbufferName);
- AddAnnoationsIfNeeded(DM, StructTy, std::string(CbufferName), numUintsInConstants);
- }
- unsigned int handle = AddHandle(M, FallbackLayerDispatchConstantsRegister, 1, FallbackLayerRegisterSpace, DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer, CbufferName, StructTy, sizeOfConstantsInBytes);
- llvm::Value *Symbol = DM.GetCBuffer(handle).GetGlobalSymbol();
- llvm::Value *Load = Builder.CreateLoad(Symbol, "DispatchRaysConstants");
- Function *CreateHandleForLib = HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
- Constant *CreateHandleOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
- DispatchRaysConstantsHandle = Builder.CreateCall(CreateHandleForLib, { CreateHandleOpcodeArg, Load }, CbufferName);
- }
-
- // Raygen always reads from the start so no offset calculations needed
- if (ShaderKind != DXIL::ShaderKind::RayGeneration)
- {
- std::string ShaderRecordOffsetFuncName = "\x1?Fallback_ShaderRecordOffset@@YAIXZ";
- Function *ShaderRecordOffsetFunc = M.getFunction(ShaderRecordOffsetFuncName);
- if (!ShaderRecordOffsetFunc)
- {
- FunctionType *ShaderRecordOffsetFuncType = FunctionType::get(llvm::Type::getInt32Ty(Ctx), {}, false);
- ShaderRecordOffsetFunc = Function::Create(ShaderRecordOffsetFuncType, GlobalValue::LinkageTypes::ExternalLinkage, ShaderRecordOffsetFuncName, &M);
- }
- BaseShaderRecordOffset = Builder.CreateCall(ShaderRecordOffsetFunc, {}, "shaderRecordOffset");
- }
- else
- {
- BaseShaderRecordOffset = HlslOP->GetU32Const(0);
- }
- }
- llvm::Value *DxilPatchShaderRecordBindings::CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes, llvm::Value *CbufferOffsetInBytes) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- OP *HlslOP = DM.GetOP();
- // Create handle for the newly-added constant buffer (which is achieved via a function call)
- auto AdddName = "ShaderRecordOffsetInBytes";
- Constant *ShaderRecordOffsetInBytes = HlslOP->GetU32Const(RecordOffsetInBytes); // Offset of constants in shader record buffer
- return Builder.CreateAdd(CbufferOffsetInBytes, ShaderRecordOffsetInBytes, AdddName);
- }
- llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle, unsigned int RowToLoad) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- OP *HlslOP = DM.GetOP();
- LLVMContext & Ctx = M.getContext();
- auto BufferLoadName = "ConstantBuffer";
- Function *BufferLoad = HlslOP->GetOpFunc(DXIL::OpCode::CBufferLoadLegacy, Type::getInt32Ty(Ctx));
- Constant *CBufferLoadOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::CBufferLoadLegacy);
- Constant *RowToLoadConst = HlslOP->GetU32Const(RowToLoad);
- return Builder.CreateCall(BufferLoad, { CBufferLoadOpcodeArg, ResourceHandle, RowToLoadConst }, BufferLoadName);
- }
- llvm::Value *DxilPatchShaderRecordBindings::CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes, llvm::Type* type) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- OP *HlslOP = DM.GetOP();
- LLVMContext & Ctx = M.getContext();
- // Create handle for the newly-added constant buffer (which is achieved via a function call)
- auto BufferLoadName = "ShaderRecordBuffer";
- if (type->getNumContainedTypes() > 1)
- {
- // TODO: Buffer loads aren't legal with container types, check if this is the right wait to handle this
- type = type->getContainedType(0);
- }
- // TODO Do I need to check the result? Hopefully not
- Function *BufferLoad = HlslOP->GetOpFunc(DXIL::OpCode::BufferLoad, type);
- Constant *BufferLoadOpcodeArg = HlslOP->GetU32Const((unsigned)DXIL::OpCode::BufferLoad);
- Constant *Unused = UndefValue::get(llvm::Type::getInt32Ty(Ctx));
- return Builder.CreateCall(BufferLoad, { BufferLoadOpcodeArg, ShaderTableHandle, ShaderRecordOffsetInBytes, Unused }, BufferLoadName);
- }
- void DxilPatchShaderRecordBindings::ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith) {
- for (auto UserIter = InstructionToReplace->user_begin(); UserIter != InstructionToReplace->user_end();) {
- // Increment the iterator before the replace since the replace alters the uses list
- auto userInstr = UserIter++;
- userInstr->replaceUsesOfWith(InstructionToReplace, ValueToReplaceWith);
- }
- InstructionToReplace->eraseFromParent();
- }
- llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- OP *HlslOP = DM.GetOP();
- DxilInst_CBufferLoad cbufferLoad(instruction);
- DxilInst_CBufferLoadLegacy cbufferLoadLegacy(instruction);
- if (cbufferLoad) {
- return cbufferLoad.get_byteOffset();
- } else if (cbufferLoadLegacy) {
- Constant *LegacyMultiplier = HlslOP->GetU32Const(16);
- return Builder.CreateMul(cbufferLoadLegacy.get_regIndex(), LegacyMultiplier);
- } else {
- ThrowFailure();
- return nullptr;
- }
- }
- bool DxilPatchShaderRecordBindings::IsCBufferLoad(llvm::Instruction *instruction) {
- DxilInst_CBufferLoad cbufferLoad(instruction);
- DxilInst_CBufferLoadLegacy cbufferLoadLegacy(instruction);
- return cbufferLoad || cbufferLoadLegacy;
- }
- unsigned int GetResolvedRangeID(DXIL::ResourceClass resClass, Value *rangeIdVal)
- {
- if (auto CI = dyn_cast<ConstantInt>(rangeIdVal))
- {
- return CI->getZExtValue();
- }
- else
- {
- assert(false);
- return 0;
- }
- }
- // TODO: This code is quite inefficient
- bool DxilPatchShaderRecordBindings::GetHandleInfo(
- Module &M,
- DxilInst_CreateHandleForLib &createHandleStructForLib,
- _Out_ unsigned int &shaderRegister,
- _Out_ unsigned int ®isterSpace,
- _Out_ DXIL::ResourceKind &kind,
- _Out_ DXIL::ResourceClass &resClass,
- _Out_ llvm::Type *&resType)
- {
- DxilModule &DM = M.GetOrCreateDxilModule();
- LoadInst *loadRangeId = cast<LoadInst>(createHandleStructForLib.get_Resource());
- Value *ResourceSymbol = loadRangeId->getPointerOperand();
- DXIL::ResourceClass resourceClasses[] = {
- DXIL::ResourceClass::CBuffer,
- DXIL::ResourceClass::SRV,
- DXIL::ResourceClass::UAV,
- DXIL::ResourceClass::Sampler
- };
- hlsl::DxilResourceBase *Resource = nullptr;
- for (auto &resourceClass : resourceClasses) {
-
- switch (resourceClass)
- {
- case DXIL::ResourceClass::CBuffer:
- {
- auto &cbuffers = DM.GetCBuffers();
- for (auto &cbuffer : cbuffers)
- {
- if (cbuffer->GetGlobalSymbol() == ResourceSymbol)
- {
- Resource = cbuffer.get();
- break;
- }
- }
- break;
- }
- case DXIL::ResourceClass::SRV:
- case DXIL::ResourceClass::UAV:
- {
- auto &viewList = resourceClass == DXIL::ResourceClass::SRV ? DM.GetSRVs() : DM.GetUAVs();
- for (auto &view : viewList)
- {
- if (view->GetGlobalSymbol() == ResourceSymbol)
- {
- Resource = view.get();
- break;
- }
- }
- break;
- }
- case DXIL::ResourceClass::Sampler:
- {
- auto &samplers = DM.GetSamplers();
- for (auto &sampler : samplers)
- {
- if (sampler->GetGlobalSymbol() == ResourceSymbol)
- {
- Resource = sampler.get();
- break;
- }
- }
- break;
- }
- }
- }
- if (Resource)
- {
- registerSpace = Resource->GetSpaceID();
- shaderRegister = Resource->GetLowerBound();
- kind = Resource->GetKind();
- resClass = Resource->GetClass();
- resType = Resource->GetHLSLType()->getPointerElementType();
- }
- return Resource != nullptr;
- }
- llvm::Value *DxilPatchShaderRecordBindings::LoadShaderRecordData(
- Module &M,
- IRBuilder<> &Builder,
- llvm::Value *offsetToShaderRecord,
- unsigned int dataOffsetInShaderRecord)
- {
- DxilModule &DM = M.GetOrCreateDxilModule();
- LLVMContext &Ctx = M.getContext();
- OP *HlslOP = DM.GetOP();
- Constant *dataOffset =
- HlslOP->GetU32Const(dataOffsetInShaderRecord);
- Value *shaderTableOffsetToData = Builder.CreateAdd(dataOffset, offsetToShaderRecord);
- return CreateShaderRecordBufferLoad(M, Builder, shaderTableOffsetToData,
- llvm::Type::getInt32Ty(Ctx));
- }
- void DxilPatchShaderRecordBindings::PatchCreateHandleToUseDescriptorIndex(
- _In_ Module &M,
- _In_ IRBuilder<> &Builder,
- _In_ DXIL::ResourceKind &resourceKind,
- _In_ DXIL::ResourceClass &resourceClass,
- _In_ llvm::Type *resourceType,
- _In_ llvm::Value *descriptorIndex,
- _Inout_ DxilInst_CreateHandleForLib &createHandleInstr)
- {
- DxilModule &DM = M.GetOrCreateDxilModule();
- OP *HlslOP = DM.GetOP();
- llvm::Value *descriptorHeapSymbol = GetAliasedDescriptorHeapHandle(M, resourceType, resourceClass, resourceKind);
- llvm::Value *viewSymbol = Builder.CreateGEP(descriptorHeapSymbol, { HlslOP->GetU32Const(0), descriptorIndex }, "IndexIntoDH");
- DxilMDHelper::MarkNonUniform(cast<Instruction>(viewSymbol));
- llvm::Value *handle = Builder.CreateLoad(viewSymbol);
- auto callInst = cast<CallInst>(createHandleInstr.Instr);
- callInst->setCalledFunction(HlslOP->GetOpFunc(
- DXIL::OpCode::CreateHandleForLib,
- handle->getType()));
- createHandleInstr.set_Resource(handle);
- }
- void DxilPatchShaderRecordBindings::InitializeViewTable() {
- // The Fallback Layer declares a bindless raw buffer that spans the entire descriptor heap,
- // manually add it to the list of UAV register spaces used
- if (*pInputShaderInfo->pNumUAVSpaces == 0)
- {
- ViewKey key = { (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer, {0} };
- unsigned int index = FindOrInsertViewIntoList(
- key,
- pInputShaderInfo->pUAVRegisterSpaceArray,
- *pInputShaderInfo->pNumUAVSpaces,
- FallbackLayerNumDescriptorHeapSpacesPerView);
- (void)index;
- assert(index == 0);
- }
- }
- void DxilPatchShaderRecordBindings::PatchShaderBindings(Module &M) {
- DxilModule &DM = M.GetOrCreateDxilModule();
- OP *HlslOP = DM.GetOP();
- // Don't erase instructions until the very end because it throws off the iterator
- std::vector<llvm::Instruction *> instructionsToRemove;
- for (BasicBlock &block : EntryPointFunction->getBasicBlockList()) {
- auto & Instructions = block.getInstList();
- for (auto &instr : Instructions) {
- DxilInst_CreateHandleForLib createHandleForLib(&instr);
- if (createHandleForLib) {
- DXIL::ResourceClass resourceClass;
- unsigned int registerSpace;
- unsigned int registerIndex;
- DXIL::ResourceKind kind;
- llvm::Type *resType;
- bool resourceIsResolved = true;
- resourceIsResolved = GetHandleInfo(M, createHandleForLib, registerIndex, registerSpace, kind, resourceClass, resType);
- if (!resourceIsResolved) continue; // TODO: This shouldn't actually be happening?
- ShaderRecordEntry shaderRecord = FindRootSignatureDescriptor(
- *pRootSignatureDesc,
- pInputShaderInfo->ShaderRecordIdentifierSizeInBytes,
- resourceClass,
- registerIndex,
- registerSpace);
- const bool IsBindingSpecifiedInLocalRootSignature = !shaderRecord.IsInvalid();
- if (IsBindingSpecifiedInLocalRootSignature) {
- if (!DispatchRaysConstantsHandle) {
- AddInputBinding(M);
- }
- switch (shaderRecord.ParameterType) {
- case DxilRootParameterType::Constants32Bit:
- {
- for (User *U : instr.users()) {
- llvm::Instruction *instruction = cast<CallInst>(U);
- if (IsCBufferLoad(instruction)) {
- llvm::Instruction *cbufferLoadInstr = instruction;
- IRBuilder<> Builder(cbufferLoadInstr);
- llvm::Value * cbufferOffsetInBytes = CreateCBufferLoadOffsetInBytes(M, Builder, cbufferLoadInstr);
- llvm::Value *LocalOffsetToRootConstant = CreateOffsetToShaderRecord(M, Builder, shaderRecord.RecordOffsetInBytes, cbufferOffsetInBytes);
- llvm::Value *GlobalOffsetToRootConstant = Builder.CreateAdd(LocalOffsetToRootConstant, BaseShaderRecordOffset);
- llvm::Value *srvBufferLoad = CreateShaderRecordBufferLoad(M, Builder, GlobalOffsetToRootConstant, cbufferLoadInstr->getType());
- ReplaceUsesOfWith(cbufferLoadInstr, srvBufferLoad);
- } else {
- ThrowFailure();
- }
- }
- instructionsToRemove.push_back(&instr);
- break;
- }
- case DxilRootParameterType::DescriptorTable:
- {
- IRBuilder<> Builder(&instr);
- llvm::Value *srvBufferLoad = LoadShaderRecordData(
- M,
- Builder,
- BaseShaderRecordOffset,
- shaderRecord.RecordOffsetInBytes);
- llvm::Value *DescriptorTableEntryLo = Builder.CreateExtractValue(srvBufferLoad, 0, "DescriptorTableHandleLo");
- unsigned int offsetToLoadInUints = offsetof(DispatchRaysConstants, SrvCbvUavDescriptorHeapStart) / sizeof(uint32_t);
- unsigned int uintsPerRow = 4;
- unsigned int rowToLoad = offsetToLoadInUints / uintsPerRow;
- unsigned int extractValueOffset = offsetToLoadInUints % uintsPerRow;
- llvm::Value *DescHeapConstants = CreateCBufferLoadLegacy(M, Builder, DispatchRaysConstantsHandle, rowToLoad);
- llvm::Value *DescriptorHeapStartAddressLo = Builder.CreateExtractValue(DescHeapConstants, extractValueOffset, "DescriptorHeapStartHandleLo");
- // TODO: The hi bits can only be ignored if the difference is guaranteed to be < 32 bytes. This is an unsafe assumption, particularly given
- // large descriptor sizes
- llvm::Value *DescriptorTableOffsetInBytes = Builder.CreateSub(DescriptorTableEntryLo, DescriptorHeapStartAddressLo, "TableOffsetInBytes");
- Constant *DescriptorSizeInBytes = HlslOP->GetU32Const(pInputShaderInfo->SrvCbvUavDescriptorSizeInBytes);
- llvm::Value * DescriptorTableStartIndex = Builder.CreateExactUDiv(DescriptorTableOffsetInBytes, DescriptorSizeInBytes, "TableStartIndex");
- Constant *RecordOffset = HlslOP->GetU32Const(shaderRecord.OffsetInDescriptors);
- llvm::Value * BaseDescriptorIndex = Builder.CreateAdd(DescriptorTableStartIndex, RecordOffset, "BaseDescriptorIndex");
- // TODO: Not supporting dynamic indexing yet, should be pulled from CreateHandleForLib
- // If dynamic indexing is being used, add the apps index on top of the calculated index
- llvm::Value * DynamicIndex = HlslOP->GetU32Const(0);
- llvm::Value * DescriptorIndex = Builder.CreateAdd(BaseDescriptorIndex, DynamicIndex, "DescriptorIndex");
- PatchCreateHandleToUseDescriptorIndex(
- M,
- Builder,
- kind,
- resourceClass,
- resType,
- DescriptorIndex,
- createHandleForLib);
- break;
- }
- case DxilRootParameterType::CBV:
- case DxilRootParameterType::SRV:
- case DxilRootParameterType::UAV: {
- IRBuilder<> Builder(&instr);
- llvm::Value *srvBufferLoad = LoadShaderRecordData(
- M,
- Builder,
- BaseShaderRecordOffset,
- shaderRecord.RecordOffsetInBytes);
- llvm::Value *DescriptorIndex = Builder.CreateExtractValue(
- srvBufferLoad, 1, "DescriptorHeapIndex");
- // TODO: Handle offset in bytes
- // llvm::Value *OffsetInBytes = Builder.CreateExtractValue(
- // srvBufferLoad, 0, "OffsetInBytes");
- PatchCreateHandleToUseDescriptorIndex(
- M,
- Builder,
- kind,
- resourceClass,
- resType,
- DescriptorIndex,
- createHandleForLib);
- break;
- }
- default:
- ThrowFailure();
- break;
- }
- }
- }
- }
- }
- for (auto instruction : instructionsToRemove) {
- instruction->eraseFromParent();
- }
- }
- bool IsParameterTypeCompatibleWithResourceClass(
- DXIL::ResourceClass resourceClass,
- DxilRootParameterType parameterType) {
- switch (parameterType) {
- case DxilRootParameterType::DescriptorTable:
- return true;
- case DxilRootParameterType::Constants32Bit:
- case DxilRootParameterType::CBV:
- return resourceClass == DXIL::ResourceClass::CBuffer;
- case DxilRootParameterType::SRV:
- return resourceClass == DXIL::ResourceClass::SRV;
- case DxilRootParameterType::UAV:
- return resourceClass == DXIL::ResourceClass::UAV;
- default:
- ThrowFailure();
- return false;
- }
- }
- DxilRootParameterType ConvertD3D12ParameterTypeToDxil(DxilRootParameterType parameter) {
- switch (parameter) {
- case DxilRootParameterType::Constants32Bit:
- return DxilRootParameterType::Constants32Bit;
- case DxilRootParameterType::DescriptorTable:
- return DxilRootParameterType::DescriptorTable;
- case DxilRootParameterType::CBV:
- return DxilRootParameterType::CBV;
- case DxilRootParameterType::SRV:
- return DxilRootParameterType::SRV;
- case DxilRootParameterType::UAV:
- return DxilRootParameterType::UAV;
- }
- assert(false);
- return (DxilRootParameterType)-1;
- }
- DXIL::ResourceClass ConvertD3D12RangeTypeToDxil(DxilDescriptorRangeType rangeType) {
- switch (rangeType) {
- case DxilDescriptorRangeType::SRV:
- return DXIL::ResourceClass::SRV;
- case DxilDescriptorRangeType::UAV:
- return DXIL::ResourceClass::UAV;
- case DxilDescriptorRangeType::CBV:
- return DXIL::ResourceClass::CBuffer;
- case DxilDescriptorRangeType::Sampler:
- return DXIL::ResourceClass::Sampler;
- }
- assert(false);
- return (DXIL::ResourceClass) - 1;
- }
- unsigned int GetParameterTypeAlignment(DxilRootParameterType parameterType) {
- switch (parameterType) {
- case DxilRootParameterType::DescriptorTable:
- return SizeofD3D12GpuDescriptorHandle;
- case DxilRootParameterType::Constants32Bit:
- return sizeof(uint32_t);
- case DxilRootParameterType::CBV: // fallthrough
- case DxilRootParameterType::SRV: // fallthrough
- case DxilRootParameterType::UAV:
- return SizeofD3D12GpuVA;
- default:
- return UINT_MAX;
- }
- }
- template <typename TD3D12_ROOT_SIGNATURE_DESC>
- ShaderRecordEntry FindRootSignatureDescriptorHelper(
- const TD3D12_ROOT_SIGNATURE_DESC &rootSignatureDescriptor,
- unsigned int ShaderRecordIdentifierSizeInBytes,
- DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex,
- unsigned int registerSpace) {
- // Automatically fail if it's looking for a fallback binding as these never
- // need to be patched
- if (registerSpace != FallbackLayerRegisterSpace) {
- unsigned int recordOffset = ShaderRecordIdentifierSizeInBytes;
- for (unsigned int rootParamIndex = 0;
- rootParamIndex < rootSignatureDescriptor.NumParameters;
- rootParamIndex++) {
- auto &rootParam = rootSignatureDescriptor.pParameters[rootParamIndex];
- auto dxilParamType =
- ConvertD3D12ParameterTypeToDxil(rootParam.ParameterType);
- #define ALIGN(alignment, num) (((num + alignment - 1) / alignment) * alignment)
- recordOffset = ALIGN(GetParameterTypeAlignment(rootParam.ParameterType),
- recordOffset);
- switch (rootParam.ParameterType) {
- case DxilRootParameterType::Constants32Bit:
- if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
- dxilParamType) &&
- baseRegisterIndex == rootParam.Constants.ShaderRegister &&
- registerSpace == rootParam.Constants.RegisterSpace) {
- return {dxilParamType, recordOffset, 0};
- }
- recordOffset += rootParam.Constants.Num32BitValues * sizeof(uint32_t);
- break;
- case DxilRootParameterType::DescriptorTable: {
- auto &descriptorTable = rootParam.DescriptorTable;
- unsigned int rangeOffsetInDescriptors = 0;
- for (unsigned int rangeIndex = 0;
- rangeIndex < descriptorTable.NumDescriptorRanges; rangeIndex++) {
- auto &range = descriptorTable.pDescriptorRanges[rangeIndex];
- if (range.OffsetInDescriptorsFromTableStart != (unsigned)-1) {
- rangeOffsetInDescriptors = range.OffsetInDescriptorsFromTableStart;
- }
- if (ConvertD3D12RangeTypeToDxil(range.RangeType) == resourceClass &&
- range.RegisterSpace == registerSpace &&
- range.BaseShaderRegister <= baseRegisterIndex &&
- range.BaseShaderRegister + range.NumDescriptors >
- baseRegisterIndex) {
- rangeOffsetInDescriptors +=
- baseRegisterIndex - range.BaseShaderRegister;
- return {dxilParamType, recordOffset, rangeOffsetInDescriptors};
- }
- rangeOffsetInDescriptors += range.NumDescriptors;
- }
- recordOffset += SizeofD3D12GpuDescriptorHandle;
- break;
- }
- case DxilRootParameterType::CBV:
- case DxilRootParameterType::SRV:
- case DxilRootParameterType::UAV:
- if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
- dxilParamType) &&
- baseRegisterIndex == rootParam.Descriptor.ShaderRegister &&
- registerSpace == rootParam.Descriptor.RegisterSpace) {
- return {dxilParamType, recordOffset, 0};
- }
- recordOffset += SizeofD3D12GpuVA;
- break;
- }
- }
- }
- return ShaderRecordEntry::InvalidEntry();
- }
- // TODO: Consider pre-calculating this into a map
- ShaderRecordEntry DxilPatchShaderRecordBindings::FindRootSignatureDescriptor(
- const DxilVersionedRootSignatureDesc &rootSignatureDescriptor,
- unsigned int ShaderRecordIdentifierSizeInBytes,
- DXIL::ResourceClass resourceClass,
- unsigned int baseRegisterIndex,
- unsigned int registerSpace) {
- switch (rootSignatureDescriptor.Version) {
- case DxilRootSignatureVersion::Version_1_0:
- return FindRootSignatureDescriptorHelper(rootSignatureDescriptor.Desc_1_0, ShaderRecordIdentifierSizeInBytes, resourceClass, baseRegisterIndex, registerSpace);
- case DxilRootSignatureVersion::Version_1_1:
- return FindRootSignatureDescriptorHelper(rootSignatureDescriptor.Desc_1_1, ShaderRecordIdentifierSizeInBytes, resourceClass, baseRegisterIndex, registerSpace);
- default:
- ThrowFailure();
- return ShaderRecordEntry::InvalidEntry();
- }
- }
|