DxilRuntimeReflection.inl 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilRuntimeReflection.inl //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Defines shader reflection for runtime usage. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/hlsl/DxilRuntimeReflection.h"
  12. #include <windows.h>
  13. #include <unordered_map>
  14. #include <vector>
  15. #include <memory>
  16. namespace hlsl {
  17. namespace RDAT {
  18. struct ResourceKey {
  19. uint32_t Class, ID;
  20. ResourceKey(uint32_t Class, uint32_t ID) : Class(Class), ID(ID) {}
  21. bool operator==(const ResourceKey& other) const {
  22. return other.Class == Class && other.ID == ID;
  23. }
  24. };
  25. // Size-checked reader
  26. // on overrun: throw buffer_overrun{};
  27. // on overlap: throw buffer_overlap{};
  28. class CheckedReader {
  29. const char *Ptr;
  30. size_t Size;
  31. size_t Offset;
  32. public:
  33. class exception : public std::exception {};
  34. class buffer_overrun : public exception {
  35. public:
  36. buffer_overrun() noexcept {}
  37. virtual const char * what() const noexcept override {
  38. return ("buffer_overrun");
  39. }
  40. };
  41. class buffer_overlap : public exception {
  42. public:
  43. buffer_overlap() noexcept {}
  44. virtual const char * what() const noexcept override {
  45. return ("buffer_overlap");
  46. }
  47. };
  48. CheckedReader(const void *ptr, size_t size) :
  49. Ptr(reinterpret_cast<const char*>(ptr)), Size(size), Offset(0) {}
  50. void Reset(size_t offset = 0) {
  51. if (offset >= Size) throw buffer_overrun{};
  52. Offset = offset;
  53. }
  54. // offset is absolute, ensure offset is >= current offset
  55. void Advance(size_t offset = 0) {
  56. if (offset < Offset) throw buffer_overlap{};
  57. if (offset >= Size) throw buffer_overrun{};
  58. Offset = offset;
  59. }
  60. void CheckBounds(size_t size) const {
  61. assert(Offset <= Size && "otherwise, offset larger than size");
  62. if (size > Size - Offset)
  63. throw buffer_overrun{};
  64. }
  65. template <typename T>
  66. const T *Cast(size_t size = 0) {
  67. if (0 == size) size = sizeof(T);
  68. CheckBounds(size);
  69. return reinterpret_cast<const T*>(Ptr + Offset);
  70. }
  71. template <typename T>
  72. const T &Read() {
  73. const size_t size = sizeof(T);
  74. const T* p = Cast<T>(size);
  75. Offset += size;
  76. return *p;
  77. }
  78. template <typename T>
  79. const T *ReadArray(size_t count = 1) {
  80. const size_t size = sizeof(T) * count;
  81. const T* p = Cast<T>(size);
  82. Offset += size;
  83. return p;
  84. }
  85. };
  86. DxilRuntimeData::DxilRuntimeData() : DxilRuntimeData(nullptr, 0) {}
  87. DxilRuntimeData::DxilRuntimeData(const char *ptr, size_t size)
  88. : m_TableCount(0), m_StringReader(), m_ResourceTableReader(),
  89. m_FunctionTableReader(), m_IndexTableReader(), m_Context() {
  90. m_Context = {&m_StringReader, &m_IndexTableReader, &m_ResourceTableReader,
  91. &m_FunctionTableReader};
  92. m_ResourceTableReader.SetContext(&m_Context);
  93. m_FunctionTableReader.SetContext(&m_Context);
  94. InitFromRDAT(ptr, size);
  95. }
  96. // initializing reader from RDAT. return true if no error has occured.
  97. bool DxilRuntimeData::InitFromRDAT(const void *pRDAT, size_t size) {
  98. if (pRDAT) {
  99. try {
  100. CheckedReader Reader(pRDAT, size);
  101. RuntimeDataHeader RDATHeader = Reader.Read<RuntimeDataHeader>();
  102. if (RDATHeader.Version < RDAT_Version_0) {
  103. // Prerelease version, fallback to that Init
  104. return InitFromRDAT_Prerelease(pRDAT, size);
  105. }
  106. const uint32_t *offsets = Reader.ReadArray<uint32_t>(RDATHeader.PartCount);
  107. for (uint32_t i = 0; i < RDATHeader.PartCount; ++i) {
  108. Reader.Advance(offsets[i]);
  109. RuntimeDataPartHeader part = Reader.Read<RuntimeDataPartHeader>();
  110. CheckedReader PR(Reader.ReadArray<char>(part.Size), part.Size);
  111. switch (part.Type) {
  112. case RuntimeDataPartType::StringBuffer: {
  113. m_StringReader = StringTableReader(
  114. PR.ReadArray<char>(part.Size), part.Size);
  115. break;
  116. }
  117. case RuntimeDataPartType::IndexArrays: {
  118. uint32_t count = part.Size / sizeof(uint32_t);
  119. m_IndexTableReader = IndexTableReader(
  120. PR.ReadArray<uint32_t>(count), count);
  121. break;
  122. }
  123. case RuntimeDataPartType::ResourceTable: {
  124. RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
  125. size_t tableSize = table.RecordCount * table.RecordStride;
  126. m_ResourceTableReader.SetResourceInfo(PR.ReadArray<char>(tableSize),
  127. table.RecordCount, table.RecordStride);
  128. break;
  129. }
  130. case RuntimeDataPartType::FunctionTable: {
  131. RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
  132. size_t tableSize = table.RecordCount * table.RecordStride;
  133. m_FunctionTableReader.SetFunctionInfo(PR.ReadArray<char>(tableSize),
  134. table.RecordCount, table.RecordStride);
  135. break;
  136. }
  137. default:
  138. continue; // Skip unrecognized parts
  139. }
  140. }
  141. return true;
  142. } catch(CheckedReader::exception e) {
  143. // TODO: error handling
  144. //throw hlsl::Exception(DXC_E_MALFORMED_CONTAINER, e.what());
  145. return false;
  146. }
  147. }
  148. return false;
  149. }
  150. bool DxilRuntimeData::InitFromRDAT_Prerelease(const void *pRDAT, size_t size) {
  151. enum class RuntimeDataPartType_Prerelease : uint32_t {
  152. Invalid = 0,
  153. String,
  154. Function,
  155. Resource,
  156. Index
  157. };
  158. struct RuntimeDataTableHeader_Prerelease {
  159. uint32_t tableType; // RuntimeDataPartType
  160. uint32_t size;
  161. uint32_t offset;
  162. };
  163. if (pRDAT) {
  164. try {
  165. CheckedReader Reader(pRDAT, size);
  166. uint32_t partCount = Reader.Read<uint32_t>();
  167. const RuntimeDataTableHeader_Prerelease *tableHeaders =
  168. Reader.ReadArray<RuntimeDataTableHeader_Prerelease>(partCount);
  169. for (uint32_t i = 0; i < partCount; ++i) {
  170. uint32_t partSize = tableHeaders[i].size;
  171. Reader.Advance(tableHeaders[i].offset);
  172. CheckedReader PR(Reader.ReadArray<char>(partSize), partSize);
  173. switch ((RuntimeDataPartType_Prerelease)(tableHeaders[i].tableType)) {
  174. case RuntimeDataPartType_Prerelease::String: {
  175. m_StringReader = StringTableReader(
  176. PR.ReadArray<char>(partSize), partSize);
  177. break;
  178. }
  179. case RuntimeDataPartType_Prerelease::Index: {
  180. uint32_t count = partSize / sizeof(uint32_t);
  181. m_IndexTableReader = IndexTableReader(
  182. PR.ReadArray<uint32_t>(count), count);
  183. break;
  184. }
  185. case RuntimeDataPartType_Prerelease::Resource: {
  186. uint32_t count = partSize / sizeof(RuntimeDataResourceInfo);
  187. m_ResourceTableReader.SetResourceInfo(PR.ReadArray<char>(partSize),
  188. count, sizeof(RuntimeDataResourceInfo));
  189. break;
  190. }
  191. case RuntimeDataPartType_Prerelease::Function: {
  192. uint32_t count = partSize / sizeof(RuntimeDataFunctionInfo);
  193. m_FunctionTableReader.SetFunctionInfo(PR.ReadArray<char>(partSize),
  194. count, sizeof(RuntimeDataFunctionInfo));
  195. break;
  196. }
  197. default:
  198. return false; // There should be no unrecognized parts
  199. }
  200. }
  201. return true;
  202. } catch(CheckedReader::exception e) {
  203. // TODO: error handling
  204. //throw hlsl::Exception(DXC_E_MALFORMED_CONTAINER, e.what());
  205. return false;
  206. }
  207. }
  208. return false;
  209. }
  210. FunctionTableReader *DxilRuntimeData::GetFunctionTableReader() {
  211. return &m_FunctionTableReader;
  212. }
  213. ResourceTableReader *DxilRuntimeData::GetResourceTableReader() {
  214. return &m_ResourceTableReader;
  215. }
  216. }} // hlsl::RDAT
  217. using namespace hlsl;
  218. using namespace RDAT;
  219. template<>
  220. struct std::hash<ResourceKey> {
  221. public:
  222. size_t operator()(const ResourceKey& key) const throw() {
  223. return (std::hash<uint32_t>()(key.Class) * (size_t)16777619U)
  224. ^ std::hash<uint32_t>()(key.ID);
  225. }
  226. };
  227. namespace {
  228. class DxilRuntimeReflection_impl : public DxilRuntimeReflection {
  229. private:
  230. typedef std::unordered_map<const char *, std::unique_ptr<wchar_t[]>> StringMap;
  231. typedef std::vector<DxilResourceDesc> ResourceList;
  232. typedef std::vector<DxilResourceDesc *> ResourceRefList;
  233. typedef std::vector<DxilFunctionDesc> FunctionList;
  234. typedef std::vector<const wchar_t *> WStringList;
  235. DxilRuntimeData m_RuntimeData;
  236. StringMap m_StringMap;
  237. ResourceList m_Resources;
  238. FunctionList m_Functions;
  239. std::unordered_map<ResourceKey, DxilResourceDesc *> m_ResourceMap;
  240. std::unordered_map<DxilFunctionDesc *, ResourceRefList> m_FuncToResMap;
  241. std::unordered_map<DxilFunctionDesc *, WStringList> m_FuncToStringMap;
  242. bool m_initialized;
  243. const wchar_t *GetWideString(const char *ptr);
  244. void AddString(const char *ptr);
  245. void InitializeReflection();
  246. const DxilResourceDesc * const*GetResourcesForFunction(DxilFunctionDesc &function,
  247. const FunctionReader &functionReader);
  248. const wchar_t **GetDependenciesForFunction(DxilFunctionDesc &function,
  249. const FunctionReader &functionReader);
  250. DxilResourceDesc *AddResource(const ResourceReader &resourceReader);
  251. DxilFunctionDesc *AddFunction(const FunctionReader &functionReader);
  252. public:
  253. // TODO: Implement pipeline state validation with runtime data
  254. // TODO: Update BlobContainer.h to recognize 'RDAT' blob
  255. DxilRuntimeReflection_impl()
  256. : m_RuntimeData(), m_StringMap(), m_Resources(), m_Functions(),
  257. m_FuncToResMap(), m_FuncToStringMap(), m_initialized(false) {}
  258. virtual ~DxilRuntimeReflection_impl() {}
  259. // This call will allocate memory for GetLibraryReflection call
  260. bool InitFromRDAT(const void *pRDAT, size_t size) override;
  261. const DxilLibraryDesc GetLibraryReflection() override;
  262. };
  263. void DxilRuntimeReflection_impl::AddString(const char *ptr) {
  264. if (m_StringMap.find(ptr) == m_StringMap.end()) {
  265. int size = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, ptr, -1,
  266. nullptr, 0);
  267. if (size != 0) {
  268. auto pNew = std::make_unique<wchar_t[]>(size);
  269. ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, ptr, -1,
  270. pNew.get(), size);
  271. m_StringMap[ptr] = std::move(pNew);
  272. }
  273. }
  274. }
  275. const wchar_t *DxilRuntimeReflection_impl::GetWideString(const char *ptr) {
  276. if (m_StringMap.find(ptr) == m_StringMap.end()) {
  277. AddString(ptr);
  278. }
  279. return m_StringMap.at(ptr).get();
  280. }
  281. bool DxilRuntimeReflection_impl::InitFromRDAT(const void *pRDAT, size_t size) {
  282. m_initialized = m_RuntimeData.InitFromRDAT(pRDAT, size);
  283. if (m_initialized)
  284. InitializeReflection();
  285. return m_initialized;
  286. }
  287. const DxilLibraryDesc DxilRuntimeReflection_impl::GetLibraryReflection() {
  288. DxilLibraryDesc reflection = {};
  289. if (m_initialized) {
  290. reflection.NumResources =
  291. m_RuntimeData.GetResourceTableReader()->GetNumResources();
  292. reflection.pResource = m_Resources.data();
  293. reflection.NumFunctions =
  294. m_RuntimeData.GetFunctionTableReader()->GetNumFunctions();
  295. reflection.pFunction = m_Functions.data();
  296. }
  297. return reflection;
  298. }
  299. void DxilRuntimeReflection_impl::InitializeReflection() {
  300. // First need to reserve spaces for resources because functions will need to
  301. // reference them via pointers.
  302. const ResourceTableReader *resourceTableReader = m_RuntimeData.GetResourceTableReader();
  303. m_Resources.reserve(resourceTableReader->GetNumResources());
  304. for (uint32_t i = 0; i < resourceTableReader->GetNumResources(); ++i) {
  305. ResourceReader resourceReader = resourceTableReader->GetItem(i);
  306. AddString(resourceReader.GetName());
  307. DxilResourceDesc *pResource = AddResource(resourceReader);
  308. if (pResource) {
  309. ResourceKey key(pResource->Class, pResource->ID);
  310. m_ResourceMap[key] = pResource;
  311. }
  312. }
  313. const FunctionTableReader *functionTableReader = m_RuntimeData.GetFunctionTableReader();
  314. m_Functions.reserve(functionTableReader->GetNumFunctions());
  315. for (uint32_t i = 0; i < functionTableReader->GetNumFunctions(); ++i) {
  316. FunctionReader functionReader = functionTableReader->GetItem(i);
  317. AddString(functionReader.GetName());
  318. AddFunction(functionReader);
  319. }
  320. }
  321. DxilResourceDesc *
  322. DxilRuntimeReflection_impl::AddResource(const ResourceReader &resourceReader) {
  323. assert(m_Resources.size() < m_Resources.capacity() && "Otherwise, number of resources was incorrect");
  324. if (!(m_Resources.size() < m_Resources.capacity()))
  325. return nullptr;
  326. m_Resources.emplace_back(DxilResourceDesc({0}));
  327. DxilResourceDesc &resource = m_Resources.back();
  328. resource.Class = (uint32_t)resourceReader.GetResourceClass();
  329. resource.Kind = (uint32_t)resourceReader.GetResourceKind();
  330. resource.Space = resourceReader.GetSpace();
  331. resource.LowerBound = resourceReader.GetLowerBound();
  332. resource.UpperBound = resourceReader.GetUpperBound();
  333. resource.ID = resourceReader.GetID();
  334. resource.Flags = resourceReader.GetFlags();
  335. resource.Name = GetWideString(resourceReader.GetName());
  336. return &resource;
  337. }
  338. const DxilResourceDesc * const*DxilRuntimeReflection_impl::GetResourcesForFunction(
  339. DxilFunctionDesc &function, const FunctionReader &functionReader) {
  340. if (m_FuncToResMap.find(&function) == m_FuncToResMap.end())
  341. m_FuncToResMap.insert(std::pair<DxilFunctionDesc *, ResourceRefList>(
  342. &function, ResourceRefList()));
  343. ResourceRefList &resourceList = m_FuncToResMap.at(&function);
  344. if (resourceList.empty()) {
  345. resourceList.reserve(functionReader.GetNumResources());
  346. for (uint32_t i = 0; i < functionReader.GetNumResources(); ++i) {
  347. const ResourceReader resourceReader = functionReader.GetResource(i);
  348. ResourceKey key((uint32_t)resourceReader.GetResourceClass(),
  349. resourceReader.GetID());
  350. auto it = m_ResourceMap.find(key);
  351. assert(it != m_ResourceMap.end() && it->second && "Otherwise, resource was not in map, or was null");
  352. resourceList.emplace_back(it->second);
  353. }
  354. }
  355. return resourceList.empty() ? nullptr : resourceList.data();
  356. }
  357. const wchar_t **DxilRuntimeReflection_impl::GetDependenciesForFunction(
  358. DxilFunctionDesc &function, const FunctionReader &functionReader) {
  359. if (m_FuncToStringMap.find(&function) == m_FuncToStringMap.end())
  360. m_FuncToStringMap.insert(
  361. std::pair<DxilFunctionDesc *, WStringList>(&function, WStringList()));
  362. WStringList &wStringList = m_FuncToStringMap.at(&function);
  363. for (uint32_t i = 0; i < functionReader.GetNumDependencies(); ++i) {
  364. wStringList.emplace_back(GetWideString(functionReader.GetDependency(i)));
  365. }
  366. return wStringList.empty() ? nullptr : wStringList.data();
  367. }
  368. DxilFunctionDesc *
  369. DxilRuntimeReflection_impl::AddFunction(const FunctionReader &functionReader) {
  370. assert(m_Functions.size() < m_Functions.capacity() && "Otherwise, number of functions was incorrect");
  371. if (!(m_Functions.size() < m_Functions.capacity()))
  372. return nullptr;
  373. m_Functions.emplace_back(DxilFunctionDesc({0}));
  374. DxilFunctionDesc &function = m_Functions.back();
  375. function.Name = GetWideString(functionReader.GetName());
  376. function.UnmangledName = GetWideString(functionReader.GetUnmangledName());
  377. function.NumResources = functionReader.GetNumResources();
  378. function.Resources = GetResourcesForFunction(function, functionReader);
  379. function.NumFunctionDependencies = functionReader.GetNumDependencies();
  380. function.FunctionDependencies =
  381. GetDependenciesForFunction(function, functionReader);
  382. function.ShaderKind = (uint32_t)functionReader.GetShaderKind();
  383. function.PayloadSizeInBytes = functionReader.GetPayloadSizeInBytes();
  384. function.AttributeSizeInBytes = functionReader.GetAttributeSizeInBytes();
  385. function.FeatureInfo1 = functionReader.GetFeatureInfo1();
  386. function.FeatureInfo2 = functionReader.GetFeatureInfo2();
  387. function.ShaderStageFlag = functionReader.GetShaderStageFlag();
  388. function.MinShaderTarget = functionReader.GetMinShaderTarget();
  389. return &function;
  390. }
  391. } // namespace anon
  392. DxilRuntimeReflection *hlsl::RDAT::CreateDxilRuntimeReflection() {
  393. return new DxilRuntimeReflection_impl();
  394. }