DxilSubobject.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilSubobject.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/Support/Global.h"
  10. #include "dxc/Support/Unicode.h"
  11. #include "dxc/Support/WinIncludes.h"
  12. #include "dxc/DXIL/DxilSubobject.h"
  13. #include "dxc/DxilContainer/DxilRuntimeReflection.h"
  14. #include "llvm/ADT/STLExtras.h"
  15. namespace hlsl {
  16. //------------------------------------------------------------------------------
  17. //
  18. // Subobject methods.
  19. //
  20. DxilSubobject::DxilSubobject(DxilSubobject &&other)
  21. : m_Owner(other.m_Owner),
  22. m_Kind(other.m_Kind),
  23. m_Name(m_Owner.InternString(other.m_Name)),
  24. m_Exports(std::move(other.m_Exports))
  25. {
  26. DXASSERT_NOMSG(DXIL::IsValidSubobjectKind(m_Kind));
  27. CopyUnionedContents(other);
  28. }
  29. DxilSubobject::DxilSubobject(DxilSubobjects &owner, Kind kind, llvm::StringRef name)
  30. : m_Owner(owner),
  31. m_Kind(kind),
  32. m_Name(m_Owner.InternString(name)),
  33. m_Exports()
  34. {
  35. DXASSERT_NOMSG(DXIL::IsValidSubobjectKind(m_Kind));
  36. }
  37. DxilSubobject::DxilSubobject(DxilSubobjects &owner, const DxilSubobject &other, llvm::StringRef name)
  38. : m_Owner(owner),
  39. m_Kind(other.m_Kind),
  40. m_Name(name),
  41. m_Exports(other.m_Exports.begin(), other.m_Exports.end())
  42. {
  43. DXASSERT_NOMSG(DXIL::IsValidSubobjectKind(m_Kind));
  44. CopyUnionedContents(other);
  45. if (&m_Owner != &other.m_Owner)
  46. InternStrings();
  47. }
  48. void DxilSubobject::CopyUnionedContents(const DxilSubobject &other) {
  49. switch (m_Kind) {
  50. case Kind::StateObjectConfig:
  51. StateObjectConfig.Flags = other.StateObjectConfig.Flags;
  52. break;
  53. case Kind::GlobalRootSignature:
  54. case Kind::LocalRootSignature:
  55. RootSignature.Size = other.RootSignature.Size;
  56. RootSignature.Data = other.RootSignature.Data;
  57. break;
  58. case Kind::SubobjectToExportsAssociation:
  59. SubobjectToExportsAssociation.Subobject = other.SubobjectToExportsAssociation.Subobject;
  60. break;
  61. case Kind::RaytracingShaderConfig:
  62. RaytracingShaderConfig.MaxPayloadSizeInBytes = other.RaytracingShaderConfig.MaxPayloadSizeInBytes;
  63. RaytracingShaderConfig.MaxAttributeSizeInBytes = other.RaytracingShaderConfig.MaxAttributeSizeInBytes;
  64. break;
  65. case Kind::RaytracingPipelineConfig:
  66. RaytracingPipelineConfig.MaxTraceRecursionDepth = other.RaytracingPipelineConfig.MaxTraceRecursionDepth;
  67. break;
  68. case Kind::HitGroup:
  69. HitGroup.Type = other.HitGroup.Type;
  70. HitGroup.AnyHit = other.HitGroup.AnyHit;
  71. HitGroup.ClosestHit = other.HitGroup.ClosestHit;
  72. HitGroup.Intersection = other.HitGroup.Intersection;
  73. break;
  74. case Kind::RaytracingPipelineConfig1:
  75. RaytracingPipelineConfig1.MaxTraceRecursionDepth = other.RaytracingPipelineConfig1.MaxTraceRecursionDepth;
  76. RaytracingPipelineConfig1.Flags = other.RaytracingPipelineConfig1.Flags;
  77. break;
  78. default:
  79. DXASSERT(0, "invalid kind");
  80. break;
  81. }
  82. }
  83. void DxilSubobject::InternStrings() {
  84. // Transfer strings if necessary
  85. m_Name = m_Owner.InternString(m_Name).data();
  86. switch (m_Kind) {
  87. case Kind::SubobjectToExportsAssociation:
  88. SubobjectToExportsAssociation.Subobject = m_Owner.InternString(SubobjectToExportsAssociation.Subobject).data();
  89. for (auto &ptr : m_Exports)
  90. ptr = m_Owner.InternString(ptr).data();
  91. break;
  92. case Kind::HitGroup:
  93. HitGroup.AnyHit = m_Owner.InternString(HitGroup.AnyHit).data();
  94. HitGroup.ClosestHit = m_Owner.InternString(HitGroup.ClosestHit).data();
  95. HitGroup.Intersection = m_Owner.InternString(HitGroup.Intersection).data();
  96. break;
  97. default:
  98. break;
  99. }
  100. }
  101. DxilSubobject::~DxilSubobject() {
  102. }
  103. // StateObjectConfig
  104. bool DxilSubobject::GetStateObjectConfig(uint32_t &Flags) const {
  105. if (m_Kind == Kind::StateObjectConfig) {
  106. Flags = StateObjectConfig.Flags;
  107. return true;
  108. }
  109. return false;
  110. }
  111. // Local/Global RootSignature
  112. bool DxilSubobject::GetRootSignature(
  113. bool local, const void * &Data, uint32_t &Size, const char **pText) const {
  114. Kind expected = local ? Kind::LocalRootSignature : Kind::GlobalRootSignature;
  115. if (m_Kind == expected) {
  116. Data = RootSignature.Data;
  117. Size = RootSignature.Size;
  118. if (pText)
  119. *pText = RootSignature.Text;
  120. return true;
  121. }
  122. return false;
  123. }
  124. // SubobjectToExportsAssociation
  125. bool DxilSubobject::GetSubobjectToExportsAssociation(
  126. llvm::StringRef &Subobject,
  127. const char * const * &Exports,
  128. uint32_t &NumExports) const {
  129. if (m_Kind == Kind::SubobjectToExportsAssociation) {
  130. Subobject = SubobjectToExportsAssociation.Subobject;
  131. Exports = m_Exports.data();
  132. NumExports = (uint32_t)m_Exports.size();
  133. return true;
  134. }
  135. return false;
  136. }
  137. // RaytracingShaderConfig
  138. bool DxilSubobject::GetRaytracingShaderConfig(uint32_t &MaxPayloadSizeInBytes,
  139. uint32_t &MaxAttributeSizeInBytes) const {
  140. if (m_Kind == Kind::RaytracingShaderConfig) {
  141. MaxPayloadSizeInBytes = RaytracingShaderConfig.MaxPayloadSizeInBytes;
  142. MaxAttributeSizeInBytes = RaytracingShaderConfig.MaxAttributeSizeInBytes;
  143. return true;
  144. }
  145. return false;
  146. }
  147. // RaytracingPipelineConfig
  148. bool DxilSubobject::GetRaytracingPipelineConfig(
  149. uint32_t &MaxTraceRecursionDepth) const {
  150. if (m_Kind == Kind::RaytracingPipelineConfig) {
  151. MaxTraceRecursionDepth = RaytracingPipelineConfig.MaxTraceRecursionDepth;
  152. return true;
  153. }
  154. return false;
  155. }
  156. // HitGroup
  157. bool DxilSubobject::GetHitGroup(DXIL::HitGroupType &hitGroupType,
  158. llvm::StringRef &AnyHit,
  159. llvm::StringRef &ClosestHit,
  160. llvm::StringRef &Intersection) const {
  161. if (m_Kind == Kind::HitGroup) {
  162. hitGroupType = HitGroup.Type;
  163. AnyHit = HitGroup.AnyHit;
  164. ClosestHit = HitGroup.ClosestHit;
  165. Intersection = HitGroup.Intersection;
  166. return true;
  167. }
  168. return false;
  169. }
  170. // RaytracingPipelineConfig1
  171. bool DxilSubobject::GetRaytracingPipelineConfig1(
  172. uint32_t &MaxTraceRecursionDepth, uint32_t &Flags) const {
  173. if (m_Kind == Kind::RaytracingPipelineConfig1) {
  174. MaxTraceRecursionDepth = RaytracingPipelineConfig1.MaxTraceRecursionDepth;
  175. Flags = RaytracingPipelineConfig1.Flags;
  176. return true;
  177. }
  178. return false;
  179. }
  180. DxilSubobjects::DxilSubobjects()
  181. : m_BytesStorage()
  182. , m_Subobjects()
  183. {}
  184. DxilSubobjects::DxilSubobjects(DxilSubobjects &&other)
  185. : m_BytesStorage(std::move(other.m_BytesStorage))
  186. , m_Subobjects(std::move(other.m_Subobjects))
  187. {}
  188. DxilSubobjects::~DxilSubobjects() {}
  189. llvm::StringRef DxilSubobjects::InternString(llvm::StringRef value) {
  190. auto it = m_BytesStorage.find(value);
  191. if (it != m_BytesStorage.end())
  192. return it->first;
  193. size_t size = value.size();
  194. StoredBytes stored(std::make_pair(std::unique_ptr<char[]>(new char[size + 1]), size + 1));
  195. memcpy(stored.first.get(), value.data(), size);
  196. stored.first[size] = 0;
  197. llvm::StringRef key(stored.first.get(), size);
  198. m_BytesStorage[key] = std::move(stored);
  199. return key;
  200. }
  201. const void *DxilSubobjects::InternRawBytes(const void *ptr, size_t size) {
  202. auto it = m_BytesStorage.find(llvm::StringRef((const char *)ptr, size));
  203. if (it != m_BytesStorage.end())
  204. return it->first.data();
  205. StoredBytes stored(std::make_pair(std::unique_ptr<char[]>(new char[size]), size));
  206. memcpy(stored.first.get(), ptr, size);
  207. llvm::StringRef key(stored.first.get(), size);
  208. m_BytesStorage[key] = std::move(stored);
  209. return key.data();
  210. }
  211. DxilSubobject *DxilSubobjects::FindSubobject(llvm::StringRef name) {
  212. auto it = m_Subobjects.find(name);
  213. if (it != m_Subobjects.end())
  214. return it->second.get();
  215. return nullptr;
  216. }
  217. void DxilSubobjects::RemoveSubobject(llvm::StringRef name) {
  218. auto it = m_Subobjects.find(name);
  219. if (it != m_Subobjects.end())
  220. m_Subobjects.erase(it);
  221. }
  222. DxilSubobject &DxilSubobjects::CloneSubobject(
  223. const DxilSubobject &Subobject, llvm::StringRef Name) {
  224. Name = InternString(Name);
  225. DXASSERT(FindSubobject(Name) == nullptr,
  226. "otherwise, name collision between subobjects");
  227. std::unique_ptr<DxilSubobject> ptr(new DxilSubobject(*this, Subobject, Name));
  228. DxilSubobject &ref = *ptr;
  229. m_Subobjects[Name] = std::move(ptr);
  230. return ref;
  231. }
  232. // Create DxilSubobjects
  233. DxilSubobject &DxilSubobjects::CreateStateObjectConfig(
  234. llvm::StringRef Name, uint32_t Flags) {
  235. DXASSERT_NOMSG(0 == ((~(uint32_t)DXIL::StateObjectFlags::ValidMask) & Flags));
  236. auto &obj = CreateSubobject(Kind::StateObjectConfig, Name);
  237. obj.StateObjectConfig.Flags = Flags;
  238. return obj;
  239. }
  240. DxilSubobject &DxilSubobjects::CreateRootSignature(
  241. llvm::StringRef Name, bool local, const void *Data, uint32_t Size, llvm::StringRef *pText /*= nullptr*/) {
  242. auto &obj = CreateSubobject(local ? Kind::LocalRootSignature : Kind::GlobalRootSignature, Name);
  243. obj.RootSignature.Data = InternRawBytes(Data, Size);
  244. obj.RootSignature.Size = Size;
  245. obj.RootSignature.Text = (pText ? InternString(*pText).data() : nullptr);
  246. return obj;
  247. }
  248. DxilSubobject &DxilSubobjects::CreateSubobjectToExportsAssociation(
  249. llvm::StringRef Name,
  250. llvm::StringRef Subobject,
  251. llvm::StringRef *Exports,
  252. uint32_t NumExports) {
  253. auto &obj = CreateSubobject(Kind::SubobjectToExportsAssociation, Name);
  254. Subobject = InternString(Subobject);
  255. obj.SubobjectToExportsAssociation.Subobject = Subobject.data();
  256. for (unsigned i = 0; i < NumExports; i++) {
  257. obj.m_Exports.emplace_back(InternString(Exports[i]).data());
  258. }
  259. return obj;
  260. }
  261. DxilSubobject &DxilSubobjects::CreateRaytracingShaderConfig(
  262. llvm::StringRef Name,
  263. uint32_t MaxPayloadSizeInBytes,
  264. uint32_t MaxAttributeSizeInBytes) {
  265. auto &obj = CreateSubobject(Kind::RaytracingShaderConfig, Name);
  266. obj.RaytracingShaderConfig.MaxPayloadSizeInBytes = MaxPayloadSizeInBytes;
  267. obj.RaytracingShaderConfig.MaxAttributeSizeInBytes = MaxAttributeSizeInBytes;
  268. return obj;
  269. }
  270. DxilSubobject &DxilSubobjects::CreateRaytracingPipelineConfig(
  271. llvm::StringRef Name,
  272. uint32_t MaxTraceRecursionDepth) {
  273. auto &obj = CreateSubobject(Kind::RaytracingPipelineConfig, Name);
  274. obj.RaytracingPipelineConfig.MaxTraceRecursionDepth = MaxTraceRecursionDepth;
  275. return obj;
  276. }
  277. DxilSubobject &DxilSubobjects::CreateHitGroup(llvm::StringRef Name,
  278. DXIL::HitGroupType hitGroupType,
  279. llvm::StringRef AnyHit,
  280. llvm::StringRef ClosestHit,
  281. llvm::StringRef Intersection) {
  282. auto &obj = CreateSubobject(Kind::HitGroup, Name);
  283. AnyHit = InternString(AnyHit);
  284. ClosestHit = InternString(ClosestHit);
  285. Intersection = InternString(Intersection);
  286. obj.HitGroup.Type = hitGroupType;
  287. obj.HitGroup.AnyHit = AnyHit.data();
  288. obj.HitGroup.ClosestHit = ClosestHit.data();
  289. obj.HitGroup.Intersection = Intersection.data();
  290. return obj;
  291. }
  292. DxilSubobject &DxilSubobjects::CreateRaytracingPipelineConfig1(
  293. llvm::StringRef Name, uint32_t MaxTraceRecursionDepth, uint32_t Flags) {
  294. auto &obj = CreateSubobject(Kind::RaytracingPipelineConfig1, Name);
  295. obj.RaytracingPipelineConfig1.MaxTraceRecursionDepth = MaxTraceRecursionDepth;
  296. DXASSERT_NOMSG(
  297. 0 == ((~(uint32_t)DXIL::RaytracingPipelineFlags::ValidMask) & Flags));
  298. obj.RaytracingPipelineConfig1.Flags = Flags;
  299. return obj;
  300. }
  301. DxilSubobject &DxilSubobjects::CreateSubobject(Kind kind, llvm::StringRef Name) {
  302. Name = InternString(Name);
  303. IFTBOOLMSG(FindSubobject(Name) == nullptr, DXC_E_GENERAL_INTERNAL_ERROR, "Subobject name collision");
  304. IFTBOOLMSG(!Name.empty(), DXC_E_GENERAL_INTERNAL_ERROR, "Empty Subobject name");
  305. std::unique_ptr<DxilSubobject> ptr(new DxilSubobject(*this, kind, Name));
  306. DxilSubobject &ref = *ptr;
  307. m_Subobjects[Name] = std::move(ptr);
  308. return ref;
  309. }
  310. bool LoadSubobjectsFromRDAT(DxilSubobjects &subobjects, RDAT::SubobjectTableReader *pSubobjectTableReader) {
  311. if (!pSubobjectTableReader)
  312. return false;
  313. bool result = true;
  314. for (unsigned i = 0; i < pSubobjectTableReader->GetCount(); ++i) {
  315. try {
  316. auto reader = pSubobjectTableReader->GetItem(i);
  317. DXIL::SubobjectKind kind = reader.GetKind();
  318. bool bLocalRS = false;
  319. switch (kind) {
  320. case DXIL::SubobjectKind::StateObjectConfig:
  321. subobjects.CreateStateObjectConfig(reader.GetName(),
  322. reader.GetStateObjectConfig_Flags());
  323. break;
  324. case DXIL::SubobjectKind::LocalRootSignature:
  325. bLocalRS = true;
  326. case DXIL::SubobjectKind::GlobalRootSignature: {
  327. const void *pOutBytes;
  328. uint32_t OutSizeInBytes;
  329. if (!reader.GetRootSignature(&pOutBytes, &OutSizeInBytes)) {
  330. result = false;
  331. continue;
  332. }
  333. subobjects.CreateRootSignature(reader.GetName(), bLocalRS, pOutBytes, OutSizeInBytes);
  334. break;
  335. }
  336. case DXIL::SubobjectKind::SubobjectToExportsAssociation: {
  337. uint32_t NumExports = reader.GetSubobjectToExportsAssociation_NumExports();
  338. std::vector<llvm::StringRef> Exports;
  339. Exports.resize(NumExports);
  340. for (unsigned i = 0; i < NumExports; ++i) {
  341. Exports[i] = reader.GetSubobjectToExportsAssociation_Export(i);
  342. }
  343. subobjects.CreateSubobjectToExportsAssociation(reader.GetName(),
  344. reader.GetSubobjectToExportsAssociation_Subobject(),
  345. Exports.data(), NumExports);
  346. break;
  347. }
  348. case DXIL::SubobjectKind::RaytracingShaderConfig:
  349. subobjects.CreateRaytracingShaderConfig(reader.GetName(),
  350. reader.GetRaytracingShaderConfig_MaxPayloadSizeInBytes(),
  351. reader.GetRaytracingShaderConfig_MaxAttributeSizeInBytes());
  352. break;
  353. case DXIL::SubobjectKind::RaytracingPipelineConfig:
  354. subobjects.CreateRaytracingPipelineConfig(reader.GetName(),
  355. reader.GetRaytracingPipelineConfig_MaxTraceRecursionDepth());
  356. break;
  357. case DXIL::SubobjectKind::HitGroup:
  358. subobjects.CreateHitGroup(reader.GetName(),
  359. reader.GetHitGroup_Type(),
  360. reader.GetHitGroup_AnyHit(),
  361. reader.GetHitGroup_ClosestHit(),
  362. reader.GetHitGroup_Intersection());
  363. break;
  364. case DXIL::SubobjectKind::RaytracingPipelineConfig1:
  365. subobjects.CreateRaytracingPipelineConfig1(
  366. reader.GetName(),
  367. reader.GetRaytracingPipelineConfig1_MaxTraceRecursionDepth(),
  368. reader.GetRaytracingPipelineConfig1_Flags());
  369. break;
  370. }
  371. }
  372. catch (hlsl::Exception &) {
  373. result = false;
  374. }
  375. }
  376. return result;
  377. }
  378. } // namespace hlsl