DxilRuntimeReflection.inl 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilRuntimeReflection.inl //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Defines shader reflection for runtime usage. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/DxilContainer/DxilRuntimeReflection.h"
  12. #include <unordered_map>
  13. #include <vector>
  14. #include <memory>
  15. #include <cwchar>
  16. namespace hlsl {
  17. namespace RDAT {
  18. struct ResourceKey {
  19. uint32_t Class, ID;
  20. ResourceKey(uint32_t Class, uint32_t ID) : Class(Class), ID(ID) {}
  21. bool operator==(const ResourceKey& other) const {
  22. return other.Class == Class && other.ID == ID;
  23. }
  24. };
  25. // Size-checked reader
  26. // on overrun: throw buffer_overrun{};
  27. // on overlap: throw buffer_overlap{};
  28. class CheckedReader {
  29. const char *Ptr;
  30. size_t Size;
  31. size_t Offset;
  32. public:
  33. class exception : public std::exception {};
  34. class buffer_overrun : public exception {
  35. public:
  36. buffer_overrun() noexcept {}
  37. virtual const char * what() const noexcept override {
  38. return ("buffer_overrun");
  39. }
  40. };
  41. class buffer_overlap : public exception {
  42. public:
  43. buffer_overlap() noexcept {}
  44. virtual const char * what() const noexcept override {
  45. return ("buffer_overlap");
  46. }
  47. };
  48. CheckedReader(const void *ptr, size_t size) :
  49. Ptr(reinterpret_cast<const char*>(ptr)), Size(size), Offset(0) {}
  50. void Reset(size_t offset = 0) {
  51. if (offset >= Size) throw buffer_overrun{};
  52. Offset = offset;
  53. }
  54. // offset is absolute, ensure offset is >= current offset
  55. void Advance(size_t offset = 0) {
  56. if (offset < Offset) throw buffer_overlap{};
  57. if (offset >= Size) throw buffer_overrun{};
  58. Offset = offset;
  59. }
  60. void CheckBounds(size_t size) const {
  61. assert(Offset <= Size && "otherwise, offset larger than size");
  62. if (size > Size - Offset)
  63. throw buffer_overrun{};
  64. }
  65. template <typename T>
  66. const T *Cast(size_t size = 0) {
  67. if (0 == size) size = sizeof(T);
  68. CheckBounds(size);
  69. return reinterpret_cast<const T*>(Ptr + Offset);
  70. }
  71. template <typename T>
  72. const T &Read() {
  73. const size_t size = sizeof(T);
  74. const T* p = Cast<T>(size);
  75. Offset += size;
  76. return *p;
  77. }
  78. template <typename T>
  79. const T *ReadArray(size_t count = 1) {
  80. const size_t size = sizeof(T) * count;
  81. const T* p = Cast<T>(size);
  82. Offset += size;
  83. return p;
  84. }
  85. };
  86. DxilRuntimeData::DxilRuntimeData() : DxilRuntimeData(nullptr, 0) {}
  87. DxilRuntimeData::DxilRuntimeData(const void *ptr, size_t size)
  88. : m_StringReader(), m_IndexTableReader(), m_RawBytesReader(),
  89. m_ResourceTableReader(), m_FunctionTableReader(),
  90. m_SubobjectTableReader(), m_Context() {
  91. m_Context = {&m_StringReader, &m_IndexTableReader, &m_RawBytesReader,
  92. &m_ResourceTableReader, &m_FunctionTableReader,
  93. &m_SubobjectTableReader};
  94. m_ResourceTableReader.SetContext(&m_Context);
  95. m_FunctionTableReader.SetContext(&m_Context);
  96. m_SubobjectTableReader.SetContext(&m_Context);
  97. InitFromRDAT(ptr, size);
  98. }
  99. // initializing reader from RDAT. return true if no error has occured.
  100. bool DxilRuntimeData::InitFromRDAT(const void *pRDAT, size_t size) {
  101. if (pRDAT) {
  102. try {
  103. CheckedReader Reader(pRDAT, size);
  104. RuntimeDataHeader RDATHeader = Reader.Read<RuntimeDataHeader>();
  105. if (RDATHeader.Version < RDAT_Version_10) {
  106. return false;
  107. }
  108. const uint32_t *offsets = Reader.ReadArray<uint32_t>(RDATHeader.PartCount);
  109. for (uint32_t i = 0; i < RDATHeader.PartCount; ++i) {
  110. Reader.Advance(offsets[i]);
  111. RuntimeDataPartHeader part = Reader.Read<RuntimeDataPartHeader>();
  112. CheckedReader PR(Reader.ReadArray<char>(part.Size), part.Size);
  113. switch (part.Type) {
  114. case RuntimeDataPartType::StringBuffer: {
  115. m_StringReader = StringTableReader(
  116. PR.ReadArray<char>(part.Size), part.Size);
  117. break;
  118. }
  119. case RuntimeDataPartType::IndexArrays: {
  120. uint32_t count = part.Size / sizeof(uint32_t);
  121. m_IndexTableReader = IndexTableReader(
  122. PR.ReadArray<uint32_t>(count), count);
  123. break;
  124. }
  125. case RuntimeDataPartType::RawBytes: {
  126. m_RawBytesReader = RawBytesReader(
  127. PR.ReadArray<char>(part.Size), part.Size);
  128. break;
  129. }
  130. case RuntimeDataPartType::ResourceTable: {
  131. RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
  132. size_t tableSize = table.RecordCount * table.RecordStride;
  133. m_ResourceTableReader.SetResourceInfo(PR.ReadArray<char>(tableSize),
  134. table.RecordCount, table.RecordStride);
  135. break;
  136. }
  137. case RuntimeDataPartType::FunctionTable: {
  138. RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
  139. size_t tableSize = table.RecordCount * table.RecordStride;
  140. m_FunctionTableReader.SetFunctionInfo(PR.ReadArray<char>(tableSize),
  141. table.RecordCount, table.RecordStride);
  142. break;
  143. }
  144. case RuntimeDataPartType::SubobjectTable: {
  145. RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
  146. size_t tableSize = table.RecordCount * table.RecordStride;
  147. m_SubobjectTableReader.SetSubobjectInfo(PR.ReadArray<char>(tableSize),
  148. table.RecordCount, table.RecordStride);
  149. break;
  150. }
  151. default:
  152. continue; // Skip unrecognized parts
  153. }
  154. }
  155. return true;
  156. } catch(CheckedReader::exception e) {
  157. // TODO: error handling
  158. //throw hlsl::Exception(DXC_E_MALFORMED_CONTAINER, e.what());
  159. return false;
  160. }
  161. }
  162. return false;
  163. }
  164. FunctionTableReader *DxilRuntimeData::GetFunctionTableReader() {
  165. return &m_FunctionTableReader;
  166. }
  167. ResourceTableReader *DxilRuntimeData::GetResourceTableReader() {
  168. return &m_ResourceTableReader;
  169. }
  170. SubobjectTableReader *DxilRuntimeData::GetSubobjectTableReader() {
  171. return &m_SubobjectTableReader;
  172. }
  173. }} // hlsl::RDAT
  174. using namespace hlsl;
  175. using namespace RDAT;
  176. namespace std {
  177. template <> struct hash<ResourceKey> {
  178. size_t operator()(const ResourceKey &key) const throw() {
  179. return (hash<uint32_t>()(key.Class) * (size_t)16777619U) ^
  180. hash<uint32_t>()(key.ID);
  181. }
  182. };
  183. } // namespace std
  184. namespace {
  185. class DxilRuntimeReflection_impl : public DxilRuntimeReflection {
  186. private:
  187. typedef std::unordered_map<const char *, std::unique_ptr<wchar_t[]>> StringMap;
  188. typedef std::unordered_map<const void *, std::unique_ptr<char[]>> BytesMap;
  189. typedef std::vector<const wchar_t *> WStringList;
  190. typedef std::vector<DxilResourceDesc> ResourceList;
  191. typedef std::vector<DxilResourceDesc *> ResourceRefList;
  192. typedef std::vector<DxilFunctionDesc> FunctionList;
  193. typedef std::vector<DxilSubobjectDesc> SubobjectList;
  194. DxilRuntimeData m_RuntimeData;
  195. StringMap m_StringMap;
  196. BytesMap m_BytesMap;
  197. ResourceList m_Resources;
  198. FunctionList m_Functions;
  199. SubobjectList m_Subobjects;
  200. std::unordered_map<ResourceKey, DxilResourceDesc *> m_ResourceMap;
  201. std::unordered_map<DxilFunctionDesc *, ResourceRefList> m_FuncToResMap;
  202. std::unordered_map<DxilFunctionDesc *, WStringList> m_FuncToDependenciesMap;
  203. std::unordered_map<DxilSubobjectDesc *, WStringList> m_SubobjectToExportsMap;
  204. bool m_initialized;
  205. const wchar_t *GetWideString(const char *ptr);
  206. void AddString(const char *ptr);
  207. const void *GetBytes(const void *ptr, size_t size);
  208. void InitializeReflection();
  209. const DxilResourceDesc * const*GetResourcesForFunction(DxilFunctionDesc &function,
  210. const FunctionReader &functionReader);
  211. const wchar_t **GetDependenciesForFunction(DxilFunctionDesc &function,
  212. const FunctionReader &functionReader);
  213. const wchar_t **GetExportsForAssociation(DxilSubobjectDesc &subobject,
  214. const SubobjectReader &subobjectReader);
  215. DxilResourceDesc *AddResource(const ResourceReader &resourceReader);
  216. DxilFunctionDesc *AddFunction(const FunctionReader &functionReader);
  217. DxilSubobjectDesc *AddSubobject(const SubobjectReader &subobjectReader);
  218. public:
  219. // TODO: Implement pipeline state validation with runtime data
  220. // TODO: Update BlobContainer.h to recognize 'RDAT' blob
  221. DxilRuntimeReflection_impl()
  222. : m_RuntimeData(), m_StringMap(), m_BytesMap(), m_Resources(), m_Functions(),
  223. m_FuncToResMap(), m_FuncToDependenciesMap(), m_SubobjectToExportsMap(),
  224. m_initialized(false) {}
  225. virtual ~DxilRuntimeReflection_impl() {}
  226. // This call will allocate memory for GetLibraryReflection call
  227. bool InitFromRDAT(const void *pRDAT, size_t size) override;
  228. const DxilLibraryDesc GetLibraryReflection() override;
  229. };
  230. void DxilRuntimeReflection_impl::AddString(const char *ptr) {
  231. if (m_StringMap.find(ptr) == m_StringMap.end()) {
  232. auto state = std::mbstate_t();
  233. size_t size = std::mbsrtowcs(nullptr, &ptr, 0, &state);
  234. if (size != static_cast<size_t>(-1)) {
  235. std::unique_ptr<wchar_t[]> pNew(new wchar_t[size + 1]);
  236. auto pOldPtr = ptr;
  237. std::mbsrtowcs(pNew.get(), &ptr, size + 1, &state);
  238. m_StringMap[pOldPtr] = std::move(pNew);
  239. }
  240. }
  241. }
  242. const wchar_t *DxilRuntimeReflection_impl::GetWideString(const char *ptr) {
  243. if (m_StringMap.find(ptr) == m_StringMap.end()) {
  244. AddString(ptr);
  245. }
  246. return m_StringMap.at(ptr).get();
  247. }
  248. const void *DxilRuntimeReflection_impl::GetBytes(const void *ptr, size_t size) {
  249. auto it = m_BytesMap.find(ptr);
  250. if (it != m_BytesMap.end())
  251. return it->second.get();
  252. auto inserted = m_BytesMap.insert(std::make_pair(ptr, std::unique_ptr<char[]>(new char[size])));
  253. void *newPtr = inserted.first->second.get();
  254. memcpy(newPtr, ptr, size);
  255. return newPtr;
  256. }
  257. bool DxilRuntimeReflection_impl::InitFromRDAT(const void *pRDAT, size_t size) {
  258. assert(!m_initialized && "may only initialize once");
  259. m_initialized = m_RuntimeData.InitFromRDAT(pRDAT, size);
  260. if (m_initialized)
  261. InitializeReflection();
  262. return m_initialized;
  263. }
  264. const DxilLibraryDesc DxilRuntimeReflection_impl::GetLibraryReflection() {
  265. DxilLibraryDesc reflection = {};
  266. if (m_initialized) {
  267. reflection.NumResources =
  268. m_RuntimeData.GetResourceTableReader()->GetNumResources();
  269. reflection.pResource = m_Resources.data();
  270. reflection.NumFunctions =
  271. m_RuntimeData.GetFunctionTableReader()->GetNumFunctions();
  272. reflection.pFunction = m_Functions.data();
  273. reflection.NumSubobjects =
  274. m_RuntimeData.GetSubobjectTableReader()->GetCount();
  275. reflection.pSubobjects = m_Subobjects.data();
  276. }
  277. return reflection;
  278. }
  279. void DxilRuntimeReflection_impl::InitializeReflection() {
  280. // First need to reserve spaces for resources because functions will need to
  281. // reference them via pointers.
  282. const ResourceTableReader *resourceTableReader = m_RuntimeData.GetResourceTableReader();
  283. m_Resources.reserve(resourceTableReader->GetNumResources());
  284. for (uint32_t i = 0; i < resourceTableReader->GetNumResources(); ++i) {
  285. ResourceReader resourceReader = resourceTableReader->GetItem(i);
  286. AddString(resourceReader.GetName());
  287. DxilResourceDesc *pResource = AddResource(resourceReader);
  288. if (pResource) {
  289. ResourceKey key(pResource->Class, pResource->ID);
  290. m_ResourceMap[key] = pResource;
  291. }
  292. }
  293. const FunctionTableReader *functionTableReader = m_RuntimeData.GetFunctionTableReader();
  294. m_Functions.reserve(functionTableReader->GetNumFunctions());
  295. for (uint32_t i = 0; i < functionTableReader->GetNumFunctions(); ++i) {
  296. FunctionReader functionReader = functionTableReader->GetItem(i);
  297. AddString(functionReader.GetName());
  298. AddFunction(functionReader);
  299. }
  300. const SubobjectTableReader *subobjectTableReader = m_RuntimeData.GetSubobjectTableReader();
  301. m_Subobjects.reserve(subobjectTableReader->GetCount());
  302. for (uint32_t i = 0; i < subobjectTableReader->GetCount(); ++i) {
  303. SubobjectReader subobjectReader = subobjectTableReader->GetItem(i);
  304. AddString(subobjectReader.GetName());
  305. AddSubobject(subobjectReader);
  306. }
  307. }
  308. DxilResourceDesc *
  309. DxilRuntimeReflection_impl::AddResource(const ResourceReader &resourceReader) {
  310. assert(m_Resources.size() < m_Resources.capacity() && "Otherwise, number of resources was incorrect");
  311. if (!(m_Resources.size() < m_Resources.capacity()))
  312. return nullptr;
  313. m_Resources.emplace_back(DxilResourceDesc({}));
  314. DxilResourceDesc &resource = m_Resources.back();
  315. resource.Class = (uint32_t)resourceReader.GetResourceClass();
  316. resource.Kind = (uint32_t)resourceReader.GetResourceKind();
  317. resource.Space = resourceReader.GetSpace();
  318. resource.LowerBound = resourceReader.GetLowerBound();
  319. resource.UpperBound = resourceReader.GetUpperBound();
  320. resource.ID = resourceReader.GetID();
  321. resource.Flags = resourceReader.GetFlags();
  322. resource.Name = GetWideString(resourceReader.GetName());
  323. return &resource;
  324. }
  325. const DxilResourceDesc * const*DxilRuntimeReflection_impl::GetResourcesForFunction(
  326. DxilFunctionDesc &function, const FunctionReader &functionReader) {
  327. if (!functionReader.GetNumResources())
  328. return nullptr;
  329. auto it = m_FuncToResMap.insert(std::make_pair(&function, ResourceRefList()));
  330. assert(it.second && "otherwise, collision");
  331. ResourceRefList &resourceList = it.first->second;
  332. resourceList.reserve(functionReader.GetNumResources());
  333. for (uint32_t i = 0; i < functionReader.GetNumResources(); ++i) {
  334. const ResourceReader resourceReader = functionReader.GetResource(i);
  335. ResourceKey key((uint32_t)resourceReader.GetResourceClass(),
  336. resourceReader.GetID());
  337. auto it = m_ResourceMap.find(key);
  338. assert(it != m_ResourceMap.end() && it->second && "Otherwise, resource was not in map, or was null");
  339. resourceList.emplace_back(it->second);
  340. }
  341. return resourceList.data();
  342. }
  343. const wchar_t **DxilRuntimeReflection_impl::GetDependenciesForFunction(
  344. DxilFunctionDesc &function, const FunctionReader &functionReader) {
  345. auto it = m_FuncToDependenciesMap.insert(std::make_pair(&function, WStringList()));
  346. assert(it.second && "otherwise, collision");
  347. WStringList &wStringList = it.first->second;
  348. for (uint32_t i = 0; i < functionReader.GetNumDependencies(); ++i) {
  349. wStringList.emplace_back(GetWideString(functionReader.GetDependency(i)));
  350. }
  351. return wStringList.empty() ? nullptr : wStringList.data();
  352. }
  353. DxilFunctionDesc *
  354. DxilRuntimeReflection_impl::AddFunction(const FunctionReader &functionReader) {
  355. assert(m_Functions.size() < m_Functions.capacity() && "Otherwise, number of functions was incorrect");
  356. if (!(m_Functions.size() < m_Functions.capacity()))
  357. return nullptr;
  358. m_Functions.emplace_back(DxilFunctionDesc({}));
  359. DxilFunctionDesc &function = m_Functions.back();
  360. function.Name = GetWideString(functionReader.GetName());
  361. function.UnmangledName = GetWideString(functionReader.GetUnmangledName());
  362. function.NumResources = functionReader.GetNumResources();
  363. function.Resources = GetResourcesForFunction(function, functionReader);
  364. function.NumFunctionDependencies = functionReader.GetNumDependencies();
  365. function.FunctionDependencies =
  366. GetDependenciesForFunction(function, functionReader);
  367. function.ShaderKind = (uint32_t)functionReader.GetShaderKind();
  368. function.PayloadSizeInBytes = functionReader.GetPayloadSizeInBytes();
  369. function.AttributeSizeInBytes = functionReader.GetAttributeSizeInBytes();
  370. function.FeatureInfo1 = functionReader.GetFeatureInfo1();
  371. function.FeatureInfo2 = functionReader.GetFeatureInfo2();
  372. function.ShaderStageFlag = functionReader.GetShaderStageFlag();
  373. function.MinShaderTarget = functionReader.GetMinShaderTarget();
  374. return &function;
  375. }
  376. const wchar_t **DxilRuntimeReflection_impl::GetExportsForAssociation(
  377. DxilSubobjectDesc &subobject, const SubobjectReader &subobjectReader) {
  378. auto it = m_SubobjectToExportsMap.insert(std::make_pair(&subobject, WStringList()));
  379. assert(it.second && "otherwise, collision");
  380. WStringList &wStringList = it.first->second;
  381. for (uint32_t i = 0; i < subobjectReader.GetSubobjectToExportsAssociation_NumExports(); ++i) {
  382. wStringList.emplace_back(GetWideString(subobjectReader.GetSubobjectToExportsAssociation_Export(i)));
  383. }
  384. return wStringList.empty() ? nullptr : wStringList.data();
  385. }
  386. DxilSubobjectDesc *DxilRuntimeReflection_impl::AddSubobject(const SubobjectReader &subobjectReader) {
  387. assert(m_Subobjects.size() < m_Subobjects.capacity() && "Otherwise, number of subobjects was incorrect");
  388. if (!(m_Subobjects.size() < m_Subobjects.capacity()))
  389. return nullptr;
  390. m_Subobjects.emplace_back(DxilSubobjectDesc({}));
  391. DxilSubobjectDesc &subobject = m_Subobjects.back();
  392. subobject.Name = GetWideString(subobjectReader.GetName());
  393. subobject.Kind = (uint32_t)subobjectReader.GetKind();
  394. switch (subobjectReader.GetKind()) {
  395. case DXIL::SubobjectKind::StateObjectConfig:
  396. subobject.StateObjectConfig.Flags = subobjectReader.GetStateObjectConfig_Flags();
  397. break;
  398. case DXIL::SubobjectKind::GlobalRootSignature:
  399. case DXIL::SubobjectKind::LocalRootSignature:
  400. if (!subobjectReader.GetRootSignature(&subobject.RootSignature.pSerializedSignature, &subobject.RootSignature.SizeInBytes))
  401. return nullptr;
  402. subobject.RootSignature.pSerializedSignature = GetBytes(subobject.RootSignature.pSerializedSignature, subobject.RootSignature.SizeInBytes);
  403. break;
  404. case DXIL::SubobjectKind::SubobjectToExportsAssociation:
  405. subobject.SubobjectToExportsAssociation.Subobject =
  406. GetWideString(subobjectReader.GetSubobjectToExportsAssociation_Subobject());
  407. subobject.SubobjectToExportsAssociation.NumExports = subobjectReader.GetSubobjectToExportsAssociation_NumExports();
  408. subobject.SubobjectToExportsAssociation.Exports = GetExportsForAssociation(subobject, subobjectReader);
  409. break;
  410. case DXIL::SubobjectKind::RaytracingShaderConfig:
  411. subobject.RaytracingShaderConfig.MaxPayloadSizeInBytes = subobjectReader.GetRaytracingShaderConfig_MaxPayloadSizeInBytes();
  412. subobject.RaytracingShaderConfig.MaxAttributeSizeInBytes = subobjectReader.GetRaytracingShaderConfig_MaxAttributeSizeInBytes();
  413. break;
  414. case DXIL::SubobjectKind::RaytracingPipelineConfig:
  415. subobject.RaytracingPipelineConfig.MaxTraceRecursionDepth = subobjectReader.GetRaytracingPipelineConfig_MaxTraceRecursionDepth();
  416. break;
  417. case DXIL::SubobjectKind::HitGroup:
  418. subobject.HitGroup.Type = (uint32_t)subobjectReader.GetHitGroup_Type();
  419. subobject.HitGroup.Intersection = GetWideString(subobjectReader.GetHitGroup_Intersection());
  420. subobject.HitGroup.AnyHit = GetWideString(subobjectReader.GetHitGroup_AnyHit());
  421. subobject.HitGroup.ClosestHit = GetWideString(subobjectReader.GetHitGroup_ClosestHit());
  422. break;
  423. case DXIL::SubobjectKind::RaytracingPipelineConfig1:
  424. subobject.RaytracingPipelineConfig1.MaxTraceRecursionDepth = subobjectReader.GetRaytracingPipelineConfig1_MaxTraceRecursionDepth();
  425. subobject.RaytracingPipelineConfig1.Flags = subobjectReader.GetRaytracingPipelineConfig1_Flags();
  426. break;
  427. default:
  428. // Ignore contents of unrecognized subobject type (forward-compat)
  429. break;
  430. }
  431. return &subobject;
  432. }
  433. } // namespace anon
  434. DxilRuntimeReflection *hlsl::RDAT::CreateDxilRuntimeReflection() {
  435. return new DxilRuntimeReflection_impl();
  436. }