DxilContainerAssembler.cpp 53 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilContainerAssembler.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides support for serializing a module into DXIL container structures. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/ADT/MapVector.h"
  12. #include "llvm/IR/Module.h"
  13. #include "llvm/IR/DebugInfo.h"
  14. #include "llvm/Bitcode/ReaderWriter.h"
  15. #include "llvm/Support/MD5.h"
  16. #include "dxc/HLSL/DxilContainer.h"
  17. #include "dxc/HLSL/DxilModule.h"
  18. #include "dxc/HLSL/DxilShaderModel.h"
  19. #include "dxc/HLSL/DxilRootSignature.h"
  20. #include "dxc/HLSL/DxilUtil.h"
  21. #include "dxc/HLSL/DxilFunctionProps.h"
  22. #include "dxc/Support/Global.h"
  23. #include "dxc/Support/Unicode.h"
  24. #include "dxc/Support/WinIncludes.h"
  25. #include "dxc/Support/FileIOHelper.h"
  26. #include "dxc/Support/dxcapi.impl.h"
  27. #include "dxc/HLSL/DxilPipelineStateValidation.h"
  28. #include <algorithm>
  29. #include <functional>
  30. using namespace llvm;
  31. using namespace hlsl;
  32. using namespace hlsl::DXIL::PSV;
  33. static DxilProgramSigSemantic KindToSystemValue(Semantic::Kind kind, DXIL::TessellatorDomain domain) {
  34. switch (kind) {
  35. case Semantic::Kind::Arbitrary: return DxilProgramSigSemantic::Undefined;
  36. case Semantic::Kind::VertexID: return DxilProgramSigSemantic::VertexID;
  37. case Semantic::Kind::InstanceID: return DxilProgramSigSemantic::InstanceID;
  38. case Semantic::Kind::Position: return DxilProgramSigSemantic::Position;
  39. case Semantic::Kind::Coverage: return DxilProgramSigSemantic::Coverage;
  40. case Semantic::Kind::InnerCoverage: return DxilProgramSigSemantic::InnerCoverage;
  41. case Semantic::Kind::PrimitiveID: return DxilProgramSigSemantic::PrimitiveID;
  42. case Semantic::Kind::SampleIndex: return DxilProgramSigSemantic::SampleIndex;
  43. case Semantic::Kind::IsFrontFace: return DxilProgramSigSemantic::IsFrontFace;
  44. case Semantic::Kind::RenderTargetArrayIndex: return DxilProgramSigSemantic::RenderTargetArrayIndex;
  45. case Semantic::Kind::ViewPortArrayIndex: return DxilProgramSigSemantic::ViewPortArrayIndex;
  46. case Semantic::Kind::ClipDistance: return DxilProgramSigSemantic::ClipDistance;
  47. case Semantic::Kind::CullDistance: return DxilProgramSigSemantic::CullDistance;
  48. case Semantic::Kind::Barycentrics: return DxilProgramSigSemantic::Barycentrics;
  49. case Semantic::Kind::TessFactor: {
  50. switch (domain) {
  51. case DXIL::TessellatorDomain::IsoLine:
  52. // Will bu updated to DetailTessFactor in next row.
  53. return DxilProgramSigSemantic::FinalLineDensityTessfactor;
  54. case DXIL::TessellatorDomain::Tri:
  55. return DxilProgramSigSemantic::FinalTriEdgeTessfactor;
  56. case DXIL::TessellatorDomain::Quad:
  57. return DxilProgramSigSemantic::FinalQuadEdgeTessfactor;
  58. }
  59. }
  60. case Semantic::Kind::InsideTessFactor: {
  61. switch (domain) {
  62. case DXIL::TessellatorDomain::IsoLine:
  63. DXASSERT(0, "invalid semantic");
  64. return DxilProgramSigSemantic::Undefined;
  65. case DXIL::TessellatorDomain::Tri:
  66. return DxilProgramSigSemantic::FinalTriInsideTessfactor;
  67. case DXIL::TessellatorDomain::Quad:
  68. return DxilProgramSigSemantic::FinalQuadInsideTessfactor;
  69. }
  70. }
  71. case Semantic::Kind::Invalid:
  72. return DxilProgramSigSemantic::Undefined;
  73. case Semantic::Kind::Target: return DxilProgramSigSemantic::Target;
  74. case Semantic::Kind::Depth: return DxilProgramSigSemantic::Depth;
  75. case Semantic::Kind::DepthLessEqual: return DxilProgramSigSemantic::DepthLE;
  76. case Semantic::Kind::DepthGreaterEqual: return DxilProgramSigSemantic::DepthGE;
  77. case Semantic::Kind::StencilRef:
  78. __fallthrough;
  79. default:
  80. DXASSERT(kind == Semantic::Kind::StencilRef, "else Invalid or switch is missing a case");
  81. return DxilProgramSigSemantic::StencilRef;
  82. }
  83. // TODO: Final_* values need mappings
  84. }
  85. static DxilProgramSigCompType CompTypeToSigCompType(hlsl::CompType value) {
  86. switch (value.GetKind()) {
  87. case CompType::Kind::I32: return DxilProgramSigCompType::SInt32;
  88. case CompType::Kind::U32: return DxilProgramSigCompType::UInt32;
  89. case CompType::Kind::F32: return DxilProgramSigCompType::Float32;
  90. case CompType::Kind::I16: return DxilProgramSigCompType::SInt16;
  91. case CompType::Kind::I64: return DxilProgramSigCompType::SInt64;
  92. case CompType::Kind::U16: return DxilProgramSigCompType::UInt16;
  93. case CompType::Kind::U64: return DxilProgramSigCompType::UInt64;
  94. case CompType::Kind::F16: return DxilProgramSigCompType::Float16;
  95. case CompType::Kind::F64: return DxilProgramSigCompType::Float64;
  96. case CompType::Kind::Invalid: __fallthrough;
  97. case CompType::Kind::I1: __fallthrough;
  98. default:
  99. return DxilProgramSigCompType::Unknown;
  100. }
  101. }
  102. static DxilProgramSigMinPrecision CompTypeToSigMinPrecision(hlsl::CompType value) {
  103. switch (value.GetKind()) {
  104. case CompType::Kind::I32: return DxilProgramSigMinPrecision::Default;
  105. case CompType::Kind::U32: return DxilProgramSigMinPrecision::Default;
  106. case CompType::Kind::F32: return DxilProgramSigMinPrecision::Default;
  107. case CompType::Kind::I1: return DxilProgramSigMinPrecision::Default;
  108. case CompType::Kind::U64: __fallthrough;
  109. case CompType::Kind::I64: __fallthrough;
  110. case CompType::Kind::F64: return DxilProgramSigMinPrecision::Default;
  111. case CompType::Kind::I16: return DxilProgramSigMinPrecision::SInt16;
  112. case CompType::Kind::U16: return DxilProgramSigMinPrecision::UInt16;
  113. case CompType::Kind::F16: return DxilProgramSigMinPrecision::Float16; // Float2_8 is not supported in DXIL.
  114. case CompType::Kind::Invalid: __fallthrough;
  115. default:
  116. return DxilProgramSigMinPrecision::Default;
  117. }
  118. }
  // Comparison functor ordering pair-like values by their `second` member
  // (ascending, via std::less). The call operator is const-qualified so the
  // functor can be invoked through const objects and const references.
  template <typename T>
  struct sort_second {
    bool operator()(const T &a, const T &b) const {
      return std::less<decltype(a.second)>()(a.second, b.second);
    }
  };
  125. struct sort_sig {
  126. bool operator()(const DxilProgramSignatureElement &a,
  127. const DxilProgramSignatureElement &b) {
  128. return (a.Stream < b.Stream) |
  129. ((a.Stream == b.Stream) & (a.Register < b.Register));
  130. }
  131. };
  132. class DxilProgramSignatureWriter : public DxilPartWriter {
  133. private:
  134. const DxilSignature &m_signature;
  135. DXIL::TessellatorDomain m_domain;
  136. bool m_isInput;
  137. bool m_useMinPrecision;
  138. size_t m_fixedSize;
  139. typedef std::pair<const char *, uint32_t> NameOffsetPair;
  140. typedef llvm::SmallMapVector<const char *, uint32_t, 8> NameOffsetMap;
  141. uint32_t m_lastOffset;
  142. NameOffsetMap m_semanticNameOffsets;
  143. unsigned m_paramCount;
  144. const char *GetSemanticName(const hlsl::DxilSignatureElement *pElement) {
  145. DXASSERT_NOMSG(pElement != nullptr);
  146. DXASSERT(pElement->GetName() != nullptr, "else sig is malformed");
  147. return pElement->GetName();
  148. }
  149. uint32_t GetSemanticOffset(const hlsl::DxilSignatureElement *pElement) {
  150. const char *pName = GetSemanticName(pElement);
  151. NameOffsetMap::iterator nameOffset = m_semanticNameOffsets.find(pName);
  152. uint32_t result;
  153. if (nameOffset == m_semanticNameOffsets.end()) {
  154. result = m_lastOffset;
  155. m_semanticNameOffsets.insert(NameOffsetPair(pName, result));
  156. m_lastOffset += strlen(pName) + 1;
  157. }
  158. else {
  159. result = nameOffset->second;
  160. }
  161. return result;
  162. }
  163. void write(std::vector<DxilProgramSignatureElement> &orderedSig,
  164. const hlsl::DxilSignatureElement *pElement) {
  165. const std::vector<unsigned> &indexVec = pElement->GetSemanticIndexVec();
  166. unsigned eltCount = pElement->GetSemanticIndexVec().size();
  167. unsigned eltRows = 1;
  168. if (eltCount)
  169. eltRows = pElement->GetRows() / eltCount;
  170. DXASSERT_NOMSG(eltRows == 1);
  171. DxilProgramSignatureElement sig;
  172. memset(&sig, 0, sizeof(DxilProgramSignatureElement));
  173. sig.Stream = pElement->GetOutputStream();
  174. sig.SemanticName = GetSemanticOffset(pElement);
  175. sig.SystemValue = KindToSystemValue(pElement->GetKind(), m_domain);
  176. sig.CompType = CompTypeToSigCompType(pElement->GetCompType());
  177. sig.Register = pElement->GetStartRow();
  178. sig.Mask = pElement->GetColsAsMask();
  179. // Only mark exist channel write for output.
  180. // All channel not used for input.
  181. if (!m_isInput)
  182. sig.NeverWrites_Mask = ~(sig.Mask);
  183. else
  184. sig.AlwaysReads_Mask = 0;
  185. sig.MinPrecision = m_useMinPrecision
  186. ? CompTypeToSigMinPrecision(pElement->GetCompType())
  187. : DxilProgramSigMinPrecision::Default;
  188. for (unsigned i = 0; i < eltCount; ++i) {
  189. sig.SemanticIndex = indexVec[i];
  190. orderedSig.emplace_back(sig);
  191. if (pElement->IsAllocated())
  192. sig.Register += eltRows;
  193. if (sig.SystemValue == DxilProgramSigSemantic::FinalLineDensityTessfactor)
  194. sig.SystemValue = DxilProgramSigSemantic::FinalLineDetailTessfactor;
  195. }
  196. }
  197. void calcSizes() {
  198. // Calculate size for signature elements.
  199. const std::vector<std::unique_ptr<hlsl::DxilSignatureElement>> &elements = m_signature.GetElements();
  200. uint32_t result = sizeof(DxilProgramSignature);
  201. m_paramCount = 0;
  202. for (size_t i = 0; i < elements.size(); ++i) {
  203. DXIL::SemanticInterpretationKind I = elements[i]->GetInterpretation();
  204. if (I == DXIL::SemanticInterpretationKind::NA || I == DXIL::SemanticInterpretationKind::NotInSig)
  205. continue;
  206. unsigned semanticCount = elements[i]->GetSemanticIndexVec().size();
  207. result += semanticCount * sizeof(DxilProgramSignatureElement);
  208. m_paramCount += semanticCount;
  209. }
  210. m_fixedSize = result;
  211. m_lastOffset = m_fixedSize;
  212. // Calculate size for semantic strings.
  213. for (size_t i = 0; i < elements.size(); ++i) {
  214. GetSemanticOffset(elements[i].get());
  215. }
  216. }
  217. public:
  218. DxilProgramSignatureWriter(const DxilSignature &signature,
  219. DXIL::TessellatorDomain domain, bool isInput, bool UseMinPrecision)
  220. : m_signature(signature), m_domain(domain), m_isInput(isInput), m_useMinPrecision(UseMinPrecision) {
  221. calcSizes();
  222. }
  223. __override uint32_t size() const {
  224. return m_lastOffset;
  225. }
  226. __override void write(AbstractMemoryStream *pStream) {
  227. UINT64 startPos = pStream->GetPosition();
  228. const std::vector<std::unique_ptr<hlsl::DxilSignatureElement>> &elements = m_signature.GetElements();
  229. DxilProgramSignature programSig;
  230. programSig.ParamCount = m_paramCount;
  231. programSig.ParamOffset = sizeof(DxilProgramSignature);
  232. IFT(WriteStreamValue(pStream, programSig));
  233. // Write structures in register order.
  234. std::vector<DxilProgramSignatureElement> orderedSig;
  235. for (size_t i = 0; i < elements.size(); ++i) {
  236. DXIL::SemanticInterpretationKind I = elements[i]->GetInterpretation();
  237. if (I == DXIL::SemanticInterpretationKind::NA || I == DXIL::SemanticInterpretationKind::NotInSig)
  238. continue;
  239. write(orderedSig, elements[i].get());
  240. }
  241. std::sort(orderedSig.begin(), orderedSig.end(), sort_sig());
  242. for (size_t i = 0; i < orderedSig.size(); ++i) {
  243. DxilProgramSignatureElement &sigElt = orderedSig[i];
  244. IFT(WriteStreamValue(pStream, sigElt));
  245. }
  246. // Write strings in the offset order.
  247. std::vector<NameOffsetPair> ordered;
  248. ordered.assign(m_semanticNameOffsets.begin(), m_semanticNameOffsets.end());
  249. std::sort(ordered.begin(), ordered.end(), sort_second<NameOffsetPair>());
  250. for (size_t i = 0; i < ordered.size(); ++i) {
  251. const char *pName = ordered[i].first;
  252. ULONG cbWritten;
  253. UINT64 offsetPos = pStream->GetPosition();
  254. DXASSERT_LOCALVAR(offsetPos, offsetPos - startPos == ordered[i].second, "else str offset is incorrect");
  255. IFT(pStream->Write(pName, strlen(pName) + 1, &cbWritten));
  256. }
  257. // Verify we wrote the bytes we though we would.
  258. UINT64 endPos = pStream->GetPosition();
  259. DXASSERT_LOCALVAR(endPos - startPos, endPos - startPos == size(), "else size is incorrect");
  260. }
  261. };
  262. DxilPartWriter *hlsl::NewProgramSignatureWriter(const DxilModule &M, DXIL::SignatureKind Kind) {
  263. switch (Kind) {
  264. case DXIL::SignatureKind::Input:
  265. return new DxilProgramSignatureWriter(
  266. M.GetInputSignature(), M.GetTessellatorDomain(), true,
  267. M.GetUseMinPrecision());
  268. case DXIL::SignatureKind::Output:
  269. return new DxilProgramSignatureWriter(
  270. M.GetOutputSignature(), M.GetTessellatorDomain(), false,
  271. M.GetUseMinPrecision());
  272. case DXIL::SignatureKind::PatchConstant:
  273. return new DxilProgramSignatureWriter(
  274. M.GetPatchConstantSignature(), M.GetTessellatorDomain(),
  275. /*IsInput*/ M.GetShaderModel()->IsDS(),
  276. /*UseMinPrecision*/M.GetUseMinPrecision());
  277. }
  278. return nullptr;
  279. }
  280. class DxilProgramRootSignatureWriter : public DxilPartWriter {
  281. private:
  282. const RootSignatureHandle &m_Sig;
  283. public:
  284. DxilProgramRootSignatureWriter(const RootSignatureHandle &S) : m_Sig(S) {}
  285. uint32_t size() const {
  286. return m_Sig.GetSerializedSize();
  287. }
  288. void write(AbstractMemoryStream *pStream) {
  289. ULONG cbWritten;
  290. IFT(pStream->Write(m_Sig.GetSerializedBytes(), size(), &cbWritten));
  291. }
  292. };
  293. DxilPartWriter *hlsl::NewRootSignatureWriter(const RootSignatureHandle &S) {
  294. return new DxilProgramRootSignatureWriter(S);
  295. }
  296. class DxilFeatureInfoWriter : public DxilPartWriter {
  297. private:
  298. // Only save the shader properties after create class for it.
  299. DxilShaderFeatureInfo featureInfo;
  300. public:
  301. DxilFeatureInfoWriter(const DxilModule &M) {
  302. featureInfo.FeatureFlags = M.m_ShaderFlags.GetFeatureInfo();
  303. }
  304. __override uint32_t size() const {
  305. return sizeof(DxilShaderFeatureInfo);
  306. }
  307. __override void write(AbstractMemoryStream *pStream) {
  308. IFT(WriteStreamValue(pStream, featureInfo.FeatureFlags));
  309. }
  310. };
  311. DxilPartWriter *hlsl::NewFeatureInfoWriter(const DxilModule &M) {
  312. return new DxilFeatureInfoWriter(M);
  313. }
  314. class DxilPSVWriter : public DxilPartWriter {
  315. private:
  316. const DxilModule &m_Module;
  317. PSVInitInfo m_PSVInitInfo;
  318. DxilPipelineStateValidation m_PSV;
  319. uint32_t m_PSVBufferSize;
  320. SmallVector<char, 512> m_PSVBuffer;
  321. SmallVector<char, 256> m_StringBuffer;
  322. SmallVector<uint32_t, 8> m_SemanticIndexBuffer;
  323. std::vector<PSVSignatureElement0> m_SigInputElements;
  324. std::vector<PSVSignatureElement0> m_SigOutputElements;
  325. std::vector<PSVSignatureElement0> m_SigPatchConstantElements;
  326. void SetPSVSigElement(PSVSignatureElement0 &E, const DxilSignatureElement &SE) {
  327. memset(&E, 0, sizeof(PSVSignatureElement0));
  328. if (SE.GetKind() == DXIL::SemanticKind::Arbitrary && strlen(SE.GetName()) > 0) {
  329. E.SemanticName = (uint32_t)m_StringBuffer.size();
  330. StringRef Name(SE.GetName());
  331. m_StringBuffer.append(Name.size()+1, '\0');
  332. memcpy(m_StringBuffer.data() + E.SemanticName, Name.data(), Name.size());
  333. } else {
  334. // m_StringBuffer always starts with '\0' so offset 0 is empty string:
  335. E.SemanticName = 0;
  336. }
  337. // Search index buffer for matching semantic index sequence
  338. DXASSERT_NOMSG(SE.GetRows() == SE.GetSemanticIndexVec().size());
  339. auto &SemIdx = SE.GetSemanticIndexVec();
  340. bool match = false;
  341. for (uint32_t offset = 0; offset + SE.GetRows() - 1 < m_SemanticIndexBuffer.size(); offset++) {
  342. match = true;
  343. for (uint32_t row = 0; row < SE.GetRows(); row++) {
  344. if ((uint32_t)SemIdx[row] != m_SemanticIndexBuffer[offset + row]) {
  345. match = false;
  346. break;
  347. }
  348. }
  349. if (match) {
  350. E.SemanticIndexes = offset;
  351. break;
  352. }
  353. }
  354. if (!match) {
  355. E.SemanticIndexes = m_SemanticIndexBuffer.size();
  356. for (uint32_t row = 0; row < SemIdx.size(); row++) {
  357. m_SemanticIndexBuffer.push_back((uint32_t)SemIdx[row]);
  358. }
  359. }
  360. DXASSERT_NOMSG(SE.GetRows() <= 32);
  361. E.Rows = (uint8_t)SE.GetRows();
  362. DXASSERT_NOMSG(SE.GetCols() <= 4);
  363. E.ColsAndStart = (uint8_t)SE.GetCols() & 0xF;
  364. if (SE.IsAllocated()) {
  365. DXASSERT_NOMSG(SE.GetStartCol() < 4);
  366. DXASSERT_NOMSG(SE.GetStartRow() < 32);
  367. E.ColsAndStart |= 0x40 | (SE.GetStartCol() << 4);
  368. E.StartRow = (uint8_t)SE.GetStartRow();
  369. }
  370. E.SemanticKind = (uint8_t)SE.GetKind();
  371. E.ComponentType = (uint8_t)CompTypeToSigCompType(SE.GetCompType());
  372. E.InterpolationMode = (uint8_t)SE.GetInterpolationMode()->GetKind();
  373. DXASSERT_NOMSG(SE.GetOutputStream() < 4);
  374. E.DynamicMaskAndStream = (uint8_t)((SE.GetOutputStream() & 0x3) << 4);
  375. E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF;
  376. }
  377. const uint32_t *CopyViewIDState(const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars, PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) {
  378. unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
  379. if (ViewIDMask.IsValid()) {
  380. DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
  381. memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords);
  382. pSrc += MaskDwords;
  383. }
  384. if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
  385. DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
  386. DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
  387. memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars);
  388. pSrc += MaskDwords * InputScalars;
  389. }
  390. return pSrc;
  391. }
  392. public:
  393. DxilPSVWriter(const DxilModule &module, uint32_t PSVVersion = 0)
  394. : m_Module(module),
  395. m_PSVInitInfo(PSVVersion)
  396. {
  397. unsigned ValMajor, ValMinor;
  398. m_Module.GetValidatorVersion(ValMajor, ValMinor);
  399. // Allow PSVVersion to be upgraded
  400. if (m_PSVInitInfo.PSVVersion < 1 && (ValMajor > 1 || (ValMajor == 1 && ValMinor >= 1)))
  401. m_PSVInitInfo.PSVVersion = 1;
  402. const ShaderModel *SM = m_Module.GetShaderModel();
  403. UINT uCBuffers = m_Module.GetCBuffers().size();
  404. UINT uSamplers = m_Module.GetSamplers().size();
  405. UINT uSRVs = m_Module.GetSRVs().size();
  406. UINT uUAVs = m_Module.GetUAVs().size();
  407. m_PSVInitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs;
  408. // TODO: for >= 6.2 version, create more efficient structure
  409. if (m_PSVInitInfo.PSVVersion > 0) {
  410. m_PSVInitInfo.ShaderStage = (PSVShaderKind)SM->GetKind();
  411. // Copy Dxil Signatures
  412. m_StringBuffer.push_back('\0'); // For empty semantic name (system value)
  413. m_PSVInitInfo.SigInputElements = m_Module.GetInputSignature().GetElements().size();
  414. m_SigInputElements.resize(m_PSVInitInfo.SigInputElements);
  415. m_PSVInitInfo.SigOutputElements = m_Module.GetOutputSignature().GetElements().size();
  416. m_SigOutputElements.resize(m_PSVInitInfo.SigOutputElements);
  417. m_PSVInitInfo.SigPatchConstantElements = m_Module.GetPatchConstantSignature().GetElements().size();
  418. m_SigPatchConstantElements.resize(m_PSVInitInfo.SigPatchConstantElements);
  419. uint32_t i = 0;
  420. for (auto &SE : m_Module.GetInputSignature().GetElements()) {
  421. SetPSVSigElement(m_SigInputElements[i++], *(SE.get()));
  422. }
  423. i = 0;
  424. for (auto &SE : m_Module.GetOutputSignature().GetElements()) {
  425. SetPSVSigElement(m_SigOutputElements[i++], *(SE.get()));
  426. }
  427. i = 0;
  428. for (auto &SE : m_Module.GetPatchConstantSignature().GetElements()) {
  429. SetPSVSigElement(m_SigPatchConstantElements[i++], *(SE.get()));
  430. }
  431. // Set String and SemanticInput Tables
  432. m_PSVInitInfo.StringTable.Table = m_StringBuffer.data();
  433. m_PSVInitInfo.StringTable.Size = m_StringBuffer.size();
  434. m_PSVInitInfo.SemanticIndexTable.Table = m_SemanticIndexBuffer.data();
  435. m_PSVInitInfo.SemanticIndexTable.Entries = m_SemanticIndexBuffer.size();
  436. // Set up ViewID and signature dependency info
  437. m_PSVInitInfo.UsesViewID = m_Module.m_ShaderFlags.GetViewID() ? true : false;
  438. m_PSVInitInfo.SigInputVectors = m_Module.GetInputSignature().NumVectorsUsed(0);
  439. for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
  440. m_PSVInitInfo.SigOutputVectors[streamIndex] = m_Module.GetOutputSignature().NumVectorsUsed(streamIndex);
  441. }
  442. m_PSVInitInfo.SigPatchConstantVectors = m_PSVInitInfo.SigPatchConstantVectors = 0;
  443. if (SM->IsHS()) {
  444. m_PSVInitInfo.SigPatchConstantVectors = m_Module.GetPatchConstantSignature().NumVectorsUsed(0);
  445. }
  446. if (SM->IsDS()) {
  447. m_PSVInitInfo.SigPatchConstantVectors = m_Module.GetPatchConstantSignature().NumVectorsUsed(0);
  448. }
  449. }
  450. if (!m_PSV.InitNew(m_PSVInitInfo, nullptr, &m_PSVBufferSize)) {
  451. DXASSERT(false, "PSV InitNew failed computing size!");
  452. }
  453. }
  454. __override uint32_t size() const {
  455. return m_PSVBufferSize;
  456. }
  457. __override void write(AbstractMemoryStream *pStream) {
  458. m_PSVBuffer.resize(m_PSVBufferSize);
  459. if (!m_PSV.InitNew(m_PSVInitInfo, m_PSVBuffer.data(), &m_PSVBufferSize)) {
  460. DXASSERT(false, "PSV InitNew failed!");
  461. }
  462. DXASSERT_NOMSG(m_PSVBuffer.size() == m_PSVBufferSize);
  463. // Set DxilRuntimInfo
  464. PSVRuntimeInfo0* pInfo = m_PSV.GetPSVRuntimeInfo0();
  465. PSVRuntimeInfo1* pInfo1 = m_PSV.GetPSVRuntimeInfo1();
  466. const ShaderModel* SM = m_Module.GetShaderModel();
  467. pInfo->MinimumExpectedWaveLaneCount = 0;
  468. pInfo->MaximumExpectedWaveLaneCount = (UINT)-1;
  469. switch (SM->GetKind()) {
  470. case ShaderModel::Kind::Vertex: {
  471. pInfo->VS.OutputPositionPresent = 0;
  472. const DxilSignature &S = m_Module.GetOutputSignature();
  473. for (auto &&E : S.GetElements()) {
  474. if (E->GetKind() == Semantic::Kind::Position) {
  475. // Ideally, we might check never writes mask here,
  476. // but this is not yet part of the signature element in Dxil
  477. pInfo->VS.OutputPositionPresent = 1;
  478. break;
  479. }
  480. }
  481. break;
  482. }
  483. case ShaderModel::Kind::Hull: {
  484. pInfo->HS.InputControlPointCount = (UINT)m_Module.GetInputControlPointCount();
  485. pInfo->HS.OutputControlPointCount = (UINT)m_Module.GetOutputControlPointCount();
  486. pInfo->HS.TessellatorDomain = (UINT)m_Module.GetTessellatorDomain();
  487. pInfo->HS.TessellatorOutputPrimitive = (UINT)m_Module.GetTessellatorOutputPrimitive();
  488. break;
  489. }
  490. case ShaderModel::Kind::Domain: {
  491. pInfo->DS.InputControlPointCount = (UINT)m_Module.GetInputControlPointCount();
  492. pInfo->DS.OutputPositionPresent = 0;
  493. const DxilSignature &S = m_Module.GetOutputSignature();
  494. for (auto &&E : S.GetElements()) {
  495. if (E->GetKind() == Semantic::Kind::Position) {
  496. // Ideally, we might check never writes mask here,
  497. // but this is not yet part of the signature element in Dxil
  498. pInfo->DS.OutputPositionPresent = 1;
  499. break;
  500. }
  501. }
  502. pInfo->DS.TessellatorDomain = (UINT)m_Module.GetTessellatorDomain();
  503. break;
  504. }
  505. case ShaderModel::Kind::Geometry: {
  506. pInfo->GS.InputPrimitive = (UINT)m_Module.GetInputPrimitive();
  507. // NOTE: For OutputTopology, pick one from a used stream, or if none
  508. // are used, use stream 0, and set OutputStreamMask to 1.
  509. pInfo->GS.OutputTopology = (UINT)m_Module.GetStreamPrimitiveTopology();
  510. pInfo->GS.OutputStreamMask = m_Module.GetActiveStreamMask();
  511. if (pInfo->GS.OutputStreamMask == 0) {
  512. pInfo->GS.OutputStreamMask = 1; // This is what runtime expects.
  513. }
  514. pInfo->GS.OutputPositionPresent = 0;
  515. const DxilSignature &S = m_Module.GetOutputSignature();
  516. for (auto &&E : S.GetElements()) {
  517. if (E->GetKind() == Semantic::Kind::Position) {
  518. // Ideally, we might check never writes mask here,
  519. // but this is not yet part of the signature element in Dxil
  520. pInfo->GS.OutputPositionPresent = 1;
  521. break;
  522. }
  523. }
  524. break;
  525. }
  526. case ShaderModel::Kind::Pixel: {
  527. pInfo->PS.DepthOutput = 0;
  528. pInfo->PS.SampleFrequency = 0;
  529. {
  530. const DxilSignature &S = m_Module.GetInputSignature();
  531. for (auto &&E : S.GetElements()) {
  532. if (E->GetInterpolationMode()->IsAnySample() ||
  533. E->GetKind() == Semantic::Kind::SampleIndex) {
  534. pInfo->PS.SampleFrequency = 1;
  535. }
  536. }
  537. }
  538. {
  539. const DxilSignature &S = m_Module.GetOutputSignature();
  540. for (auto &&E : S.GetElements()) {
  541. if (E->IsAnyDepth()) {
  542. pInfo->PS.DepthOutput = 1;
  543. break;
  544. }
  545. }
  546. }
  547. break;
  548. }
  549. }
  550. // Set resource binding information
  551. UINT uResIndex = 0;
  552. for (auto &&R : m_Module.GetCBuffers()) {
  553. DXASSERT_NOMSG(uResIndex < m_PSVInitInfo.ResourceCount);
  554. PSVResourceBindInfo0* pBindInfo = m_PSV.GetPSVResourceBindInfo0(uResIndex);
  555. DXASSERT_NOMSG(pBindInfo);
  556. pBindInfo->ResType = (UINT)PSVResourceType::CBV;
  557. pBindInfo->Space = R->GetSpaceID();
  558. pBindInfo->LowerBound = R->GetLowerBound();
  559. pBindInfo->UpperBound = R->GetUpperBound();
  560. uResIndex++;
  561. }
  562. for (auto &&R : m_Module.GetSamplers()) {
  563. DXASSERT_NOMSG(uResIndex < m_PSVInitInfo.ResourceCount);
  564. PSVResourceBindInfo0* pBindInfo = m_PSV.GetPSVResourceBindInfo0(uResIndex);
  565. DXASSERT_NOMSG(pBindInfo);
  566. pBindInfo->ResType = (UINT)PSVResourceType::Sampler;
  567. pBindInfo->Space = R->GetSpaceID();
  568. pBindInfo->LowerBound = R->GetLowerBound();
  569. pBindInfo->UpperBound = R->GetUpperBound();
  570. uResIndex++;
  571. }
  572. for (auto &&R : m_Module.GetSRVs()) {
  573. DXASSERT_NOMSG(uResIndex < m_PSVInitInfo.ResourceCount);
  574. PSVResourceBindInfo0* pBindInfo = m_PSV.GetPSVResourceBindInfo0(uResIndex);
  575. DXASSERT_NOMSG(pBindInfo);
  576. if (R->IsStructuredBuffer()) {
  577. pBindInfo->ResType = (UINT)PSVResourceType::SRVStructured;
  578. } else if (R->IsRawBuffer()) {
  579. pBindInfo->ResType = (UINT)PSVResourceType::SRVRaw;
  580. } else {
  581. pBindInfo->ResType = (UINT)PSVResourceType::SRVTyped;
  582. }
  583. pBindInfo->Space = R->GetSpaceID();
  584. pBindInfo->LowerBound = R->GetLowerBound();
  585. pBindInfo->UpperBound = R->GetUpperBound();
  586. uResIndex++;
  587. }
  588. for (auto &&R : m_Module.GetUAVs()) {
  589. DXASSERT_NOMSG(uResIndex < m_PSVInitInfo.ResourceCount);
  590. PSVResourceBindInfo0* pBindInfo = m_PSV.GetPSVResourceBindInfo0(uResIndex);
  591. DXASSERT_NOMSG(pBindInfo);
  592. if (R->IsStructuredBuffer()) {
  593. if (R->HasCounter())
  594. pBindInfo->ResType = (UINT)PSVResourceType::UAVStructuredWithCounter;
  595. else
  596. pBindInfo->ResType = (UINT)PSVResourceType::UAVStructured;
  597. } else if (R->IsRawBuffer()) {
  598. pBindInfo->ResType = (UINT)PSVResourceType::UAVRaw;
  599. } else {
  600. pBindInfo->ResType = (UINT)PSVResourceType::UAVTyped;
  601. }
  602. pBindInfo->Space = R->GetSpaceID();
  603. pBindInfo->LowerBound = R->GetLowerBound();
  604. pBindInfo->UpperBound = R->GetUpperBound();
  605. uResIndex++;
  606. }
  607. DXASSERT_NOMSG(uResIndex == m_PSVInitInfo.ResourceCount);
  608. if (m_PSVInitInfo.PSVVersion > 0) {
  609. DXASSERT_NOMSG(pInfo1);
  610. // Write MaxVertexCount
  611. if (SM->IsGS()) {
  612. DXASSERT_NOMSG(m_Module.GetMaxVertexCount() <= 1024);
  613. pInfo1->MaxVertexCount = (uint16_t)m_Module.GetMaxVertexCount();
  614. }
  615. // Write Dxil Signature Elements
  616. for (unsigned i = 0; i < m_PSV.GetSigInputElements(); i++) {
  617. PSVSignatureElement0 *pInputElement = m_PSV.GetInputElement0(i);
  618. DXASSERT_NOMSG(pInputElement);
  619. memcpy(pInputElement, &m_SigInputElements[i], sizeof(PSVSignatureElement0));
  620. }
  621. for (unsigned i = 0; i < m_PSV.GetSigOutputElements(); i++) {
  622. PSVSignatureElement0 *pOutputElement = m_PSV.GetOutputElement0(i);
  623. DXASSERT_NOMSG(pOutputElement);
  624. memcpy(pOutputElement, &m_SigOutputElements[i], sizeof(PSVSignatureElement0));
  625. }
  626. for (unsigned i = 0; i < m_PSV.GetSigPatchConstantElements(); i++) {
  627. PSVSignatureElement0 *pPatchConstantElement = m_PSV.GetPatchConstantElement0(i);
  628. DXASSERT_NOMSG(pPatchConstantElement);
  629. memcpy(pPatchConstantElement, &m_SigPatchConstantElements[i], sizeof(PSVSignatureElement0));
  630. }
  631. // Gather ViewID dependency information
  632. auto &viewState = m_Module.GetViewIdState().GetSerialized();
  633. if (!viewState.empty()) {
  634. const uint32_t *pSrc = viewState.data();
  635. const uint32_t InputScalars = *(pSrc++);
  636. uint32_t OutputScalars[4];
  637. for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
  638. OutputScalars[streamIndex] = *(pSrc++);
  639. pSrc = CopyViewIDState(pSrc, InputScalars, OutputScalars[streamIndex], m_PSV.GetViewIDOutputMask(streamIndex), m_PSV.GetInputToOutputTable(streamIndex));
  640. if (!SM->IsGS())
  641. break;
  642. }
  643. if (SM->IsHS()) {
  644. const uint32_t PCScalars = *(pSrc++);
  645. pSrc = CopyViewIDState(pSrc, InputScalars, PCScalars, m_PSV.GetViewIDPCOutputMask(), m_PSV.GetInputToPCOutputTable());
  646. } else if (SM->IsDS()) {
  647. const uint32_t PCScalars = *(pSrc++);
  648. pSrc = CopyViewIDState(pSrc, PCScalars, OutputScalars[0], PSVComponentMask(), m_PSV.GetPCInputToOutputTable());
  649. }
  650. DXASSERT_NOMSG(viewState.data() + viewState.size() == pSrc);
  651. }
  652. }
  653. ULONG cbWritten;
  654. IFT(pStream->Write(m_PSVBuffer.data(), m_PSVBufferSize, &cbWritten));
  655. DXASSERT_NOMSG(cbWritten == m_PSVBufferSize);
  656. }
  657. };
  658. class RDATTable {
  659. public:
  660. virtual uint32_t GetBlobSize() const { return 0; }
  661. virtual void write(void *ptr) {}
  662. virtual RuntimeDataTableType GetType() const { return RuntimeDataTableType::Invalid; }
  663. virtual ~RDATTable() {}
  664. };
  665. class ResourceTable : public RDATTable {
  666. private:
  667. uint32_t m_Version;
  668. std::vector<std::pair<const DxilCBuffer*, uint32_t>> CBufferToOffset;
  669. std::vector<std::pair<const DxilSampler*, uint32_t>> SamplerToOffset;
  670. std::vector<std::pair<const DxilResource*, uint32_t>> SRVToOffset;
  671. std::vector<std::pair<const DxilResource*, uint32_t>> UAVToOffset;
  672. void UpdateResourceInfo(const DxilResourceBase *res, uint32_t offset,
  673. RuntimeDataResourceInfo *info, char **pCur) {
  674. info->Kind = static_cast<uint32_t>(res->GetKind());
  675. info->Space = res->GetSpaceID();
  676. info->LowerBound = res->GetLowerBound();
  677. info->UpperBound = res->GetUpperBound();
  678. info->Name = offset;
  679. memcpy(*pCur, info, sizeof(RuntimeDataResourceInfo));
  680. *pCur += sizeof(RuntimeDataResourceInfo);
  681. }
  682. public:
  683. ResourceTable(uint32_t version) : m_Version(version), CBufferToOffset(), SamplerToOffset(), SRVToOffset(), UAVToOffset() {}
  684. void AddCBuffer(const DxilCBuffer *resource, uint32_t offset) {
  685. CBufferToOffset.emplace_back(
  686. std::pair<const DxilCBuffer *, uint32_t>(resource, offset));
  687. }
  688. void AddSampler(const DxilSampler *resource, uint32_t offset) {
  689. SamplerToOffset.emplace_back(
  690. std::pair<const DxilSampler *, uint32_t>(resource, offset));
  691. }
  692. void AddSRV(const DxilResource *resource, uint32_t offset) {
  693. SRVToOffset.emplace_back(
  694. std::pair<const DxilResource *, uint32_t>(resource, offset));
  695. }
  696. void AddUAV(const DxilResource *resource, uint32_t offset) {
  697. UAVToOffset.emplace_back(
  698. std::pair<const DxilResource *, uint32_t>(resource, offset));
  699. }
  700. uint32_t NumResources() const {
  701. return CBufferToOffset.size() + SamplerToOffset.size() +
  702. SRVToOffset.size() + UAVToOffset.size();
  703. }
  704. RuntimeDataTableType GetType() const { return RuntimeDataTableType::Resource; }
  705. uint32_t GetBlobSize() const {
  706. return NumResources() * sizeof(RuntimeDataResourceInfo) +
  707. 4 * sizeof(uint32_t);
  708. }
  709. void write(void *ptr) {
  710. // Only impelemented for RDAT for now
  711. if (m_Version == 0) {
  712. char *pCur = (char*)ptr;
  713. // count for each resource class
  714. uint32_t cBufferCount = CBufferToOffset.size();
  715. uint32_t samplerCount = SamplerToOffset.size();
  716. uint32_t srvCount = SRVToOffset.size();
  717. uint32_t uavCount = UAVToOffset.size();
  718. memcpy(pCur, &cBufferCount, sizeof(uint32_t));
  719. pCur += sizeof(uint32_t);
  720. memcpy(pCur, &samplerCount, sizeof(uint32_t));
  721. pCur += sizeof(uint32_t);
  722. memcpy(pCur, &srvCount, sizeof(uint32_t));
  723. pCur += sizeof(uint32_t);
  724. memcpy(pCur, &uavCount, sizeof(uint32_t));
  725. pCur += sizeof(uint32_t);
  726. for (auto pair : CBufferToOffset) {
  727. RuntimeDataResourceInfo info = {};
  728. info.ResType = static_cast<uint32_t>(PSVResourceType::CBV);
  729. UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
  730. }
  731. for (auto pair : SamplerToOffset) {
  732. RuntimeDataResourceInfo info = {};
  733. info.ResType = static_cast<uint32_t>(PSVResourceType::Sampler);
  734. UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
  735. }
  736. for (auto pair : SRVToOffset) {
  737. RuntimeDataResourceInfo info = {};
  738. auto res = pair.first;
  739. if (res->IsStructuredBuffer()) {
  740. info.ResType = (UINT)PSVResourceType::SRVStructured;
  741. } else if (res->IsRawBuffer()) {
  742. info.ResType = (UINT)PSVResourceType::SRVRaw;
  743. } else {
  744. info.ResType = (UINT)PSVResourceType::SRVTyped;
  745. }
  746. UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
  747. }
  748. for (auto pair : UAVToOffset) {
  749. RuntimeDataResourceInfo info = {};
  750. auto res = pair.first;
  751. if (res->IsStructuredBuffer()) {
  752. if (res->HasCounter())
  753. info.ResType = (UINT)PSVResourceType::UAVStructuredWithCounter;
  754. else
  755. info.ResType = (UINT)PSVResourceType::UAVStructured;
  756. } else if (res->IsRawBuffer()) {
  757. info.ResType = (UINT)PSVResourceType::UAVRaw;
  758. } else {
  759. info.ResType = (UINT)PSVResourceType::UAVTyped;
  760. }
  761. UpdateResourceInfo(res, pair.second, &info, &pCur);
  762. }
  763. }
  764. }
  765. };
  766. class FunctionTable : public RDATTable {
  767. private:
  768. std::vector<std::pair<const llvm::Function *, RuntimeDataFunctionInfo>> FuncToInfo;
  769. public:
  770. FunctionTable(): FuncToInfo() {}
  771. uint32_t NumFunctions() const { return FuncToInfo.size(); }
  772. void AddFunction(const llvm::Function *func, uint32_t mangledOfffset,
  773. uint32_t unmangledOffset, uint32_t shaderKind, uint32_t resourceIndex,
  774. uint32_t payloadSizeInBytes, uint32_t attrSizeInBytes, ShaderFlags flags) {
  775. RuntimeDataFunctionInfo info = {};
  776. info.Name = mangledOfffset;
  777. info.UnmangledName = unmangledOffset;
  778. info.ShaderKind = shaderKind;
  779. info.Resources = resourceIndex;
  780. info.PayloadSizeInBytes = payloadSizeInBytes;
  781. info.AttributeSizeInBytes = attrSizeInBytes;
  782. uint64_t rawFlags = flags.GetShaderFlagsRaw();
  783. info.FeatureInfo1 = rawFlags & 0xffffffff;
  784. info.FeatureInfo2 = (rawFlags >> 32) & 0xffffffff;
  785. FuncToInfo.push_back({ func, info });
  786. }
  787. uint32_t GetBlobSize() const { return NumFunctions() * sizeof(RuntimeDataFunctionInfo); }
  788. RuntimeDataTableType GetType() const { return RuntimeDataTableType::Function; }
  789. void write(void *ptr) {
  790. char *cur = (char *)ptr;
  791. for (auto &&pair : FuncToInfo) {
  792. auto offset = pair.second;
  793. memcpy(cur, &offset, sizeof(RuntimeDataFunctionInfo));
  794. cur += sizeof(RuntimeDataFunctionInfo);
  795. }
  796. }
  797. };
  798. class StringTable : public RDATTable {
  799. private:
  800. SmallVector<char, 256> m_StringBuffer;
  801. uint32_t curIndex;
  802. public:
  803. StringTable() : m_StringBuffer(), curIndex(0) {}
  804. // returns the offset of the name inserted
  805. uint32_t Insert(StringRef name) {
  806. for (auto iter = name.begin(), End = name.end(); iter != End; ++iter) {
  807. m_StringBuffer.push_back(*iter);
  808. }
  809. m_StringBuffer.push_back('\0');
  810. uint32_t prevIndex = curIndex;
  811. curIndex += name.size() + 1;
  812. return prevIndex;
  813. }
  814. RuntimeDataTableType GetType() const { return RuntimeDataTableType::String; }
  815. uint32_t GetBlobSize() const { return m_StringBuffer.size(); }
  816. void write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); }
  817. };
  818. template <class T>
  819. struct IndexTable : public RDATTable {
  820. private:
  821. std::vector<std::vector<T>> m_IndicesList;
  822. uint32_t m_curOffset;
  823. public:
  824. IndexTable() : m_IndicesList(), m_curOffset(0) {}
  825. uint32_t AddIndex(const std::vector<T> &Indices) {
  826. uint32_t prevOffset = m_curOffset;
  827. m_curOffset += Indices.size() + 1;
  828. m_IndicesList.emplace_back(std::move(Indices));
  829. return prevOffset;
  830. }
  831. RuntimeDataTableType GetType() const { return RuntimeDataTableType::Index; }
  832. uint32_t GetBlobSize() const {
  833. uint32_t size = 0;
  834. for (auto Indices : m_IndicesList) {
  835. size += Indices.size() + 1;
  836. }
  837. return sizeof(T) * size;
  838. }
  839. void write(void *ptr) {
  840. T *cur = (T*)ptr;
  841. for (auto Indices : m_IndicesList) {
  842. uint32_t count = Indices.size();
  843. memcpy(cur, &count, 4);
  844. std::copy(Indices.data(), Indices.data() + Indices.size(), cur + 1);
  845. cur += sizeof(T)/sizeof(4) + Indices.size();
  846. }
  847. }
  848. };
  849. class DxilRDATWriter : public DxilPartWriter {
  850. private:
  851. const DxilModule &m_Module;
  852. SmallVector<char, 1024> m_RDATBuffer;
  853. std::vector<std::unique_ptr<RDATTable>> m_tables;
  854. std::map<llvm::Function *, std::vector<uint32_t>> m_FuncToResNameOffset;
  855. void UpdateFunctionToResourceInfo(const DxilResourceBase *resource, uint32_t offset) {
  856. Constant *var = resource->GetGlobalSymbol();
  857. if (var) {
  858. for (auto user : var->users()) {
  859. if (llvm::Instruction *I = dyn_cast<llvm::Instruction>(user)) {
  860. if (llvm::Function *F = dyn_cast<llvm::Function>(I->getParent()->getParent())) {
  861. if (m_FuncToResNameOffset.find(F) != m_FuncToResNameOffset.end()) {
  862. m_FuncToResNameOffset[F].emplace_back(offset);
  863. }
  864. else {
  865. m_FuncToResNameOffset[F] = std::vector<uint32_t>({offset});
  866. }
  867. }
  868. }
  869. }
  870. }
  871. }
  872. void UpdateResourceInfo(StringTable &stringTable) {
  873. // Try to allocate string table for resources. String table is a sequence
  874. // of strings delimited by \0
  875. m_tables.emplace_back(std::make_unique<ResourceTable>(0));
  876. ResourceTable &resourceTable = *(ResourceTable*)m_tables.back().get();
  877. uint32_t stringIndex;
  878. uint32_t resourceIndex = 0;
  879. for (auto &resource : m_Module.GetCBuffers()) {
  880. stringIndex = stringTable.Insert(resource->GetGlobalName());
  881. UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
  882. resourceTable.AddCBuffer(resource.get(), stringIndex);
  883. }
  884. for (auto &resource : m_Module.GetSamplers()) {
  885. stringIndex = stringTable.Insert(resource->GetGlobalName());
  886. UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
  887. resourceTable.AddSampler(resource.get(), stringIndex);
  888. }
  889. for (auto &resource : m_Module.GetSRVs()) {
  890. stringIndex = stringTable.Insert(resource->GetGlobalName());
  891. UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
  892. resourceTable.AddSRV(resource.get(), stringIndex);
  893. }
  894. for (auto &resource : m_Module.GetUAVs()) {
  895. stringIndex = stringTable.Insert(resource->GetGlobalName());
  896. UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
  897. resourceTable.AddUAV(resource.get(), stringIndex);
  898. }
  899. }
  900. void UpdateFunctionInfo(StringTable &stringTable) {
  901. // TODO: get a list of required features
  902. // TODO: get a list of valid shader flags
  903. // TODO: get a minimum shader version
  904. m_tables.emplace_back(std::make_unique<FunctionTable>());
  905. FunctionTable &functionTable = *(FunctionTable*)(m_tables.back().get());
  906. m_tables.emplace_back(std::make_unique<IndexTable<uint32_t>>());
  907. IndexTable<uint32_t> &indexTable = *(IndexTable<uint32_t>*)(m_tables.back().get());
  908. for (auto &function : m_Module.GetModule()->getFunctionList()) {
  909. if (!function.isDeclaration()) {
  910. StringRef mangled = function.getName();
  911. StringRef unmangled = hlsl::dxilutil::DemangleFunctionName(function.getName());
  912. uint32_t mangledIndex = stringTable.Insert(mangled);
  913. uint32_t unmangledIndex = stringTable.Insert(unmangled);
  914. // Update resource Index
  915. uint32_t resourceIndex = UINT_MAX;
  916. uint32_t payloadSizeInBytes = 0;
  917. uint32_t attrSizeInBytes = 0;
  918. uint32_t shaderKind = (uint32_t)PSVShaderKind::Library;
  919. if (m_FuncToResNameOffset.find(&function) != m_FuncToResNameOffset.end())
  920. resourceIndex = indexTable.AddIndex(m_FuncToResNameOffset[&function]);
  921. if (m_Module.HasDxilFunctionProps(&function)) {
  922. auto props = m_Module.GetDxilFunctionProps(&function);
  923. if (props.IsClosestHit() || props.IsAnyHit()) {
  924. payloadSizeInBytes = props.ShaderProps.Ray.payloadSizeInBytes;
  925. attrSizeInBytes = props.ShaderProps.Ray.attributeSizeInBytes;
  926. }
  927. else if (props.IsMiss()) {
  928. payloadSizeInBytes = props.ShaderProps.Ray.payloadSizeInBytes;
  929. }
  930. else if (props.IsCallable()) {
  931. payloadSizeInBytes = props.ShaderProps.Ray.paramSizeInBytes;
  932. }
  933. shaderKind = (uint32_t)props.shaderKind;
  934. }
  935. ShaderFlags flags = ShaderFlags::CollectShaderFlags(&function, &m_Module);
  936. functionTable.AddFunction(&function, mangledIndex, unmangledIndex,
  937. shaderKind, resourceIndex,
  938. payloadSizeInBytes, attrSizeInBytes, flags);
  939. }
  940. }
  941. }
  942. public:
  943. DxilRDATWriter(const DxilModule &module, uint32_t InfoVersion = 0)
  944. : m_Module(module), m_RDATBuffer(), m_tables(), m_FuncToResNameOffset() {
  945. // It's important to keep the order of this update
  946. m_tables.emplace_back(std::make_unique<StringTable>());
  947. StringTable &stringTable = *(StringTable*)m_tables.back().get();
  948. UpdateResourceInfo(stringTable);
  949. UpdateFunctionInfo(stringTable);
  950. }
  951. __override uint32_t size() const {
  952. // one variable to count the number of blobs and two blobs
  953. uint32_t total = 4 + m_tables.size() * sizeof(RuntimeDataTableHeader);
  954. for (auto &&table : m_tables)
  955. total += table->GetBlobSize();
  956. return total;
  957. }
  958. __override void write(AbstractMemoryStream *pStream) {
  959. m_RDATBuffer.resize(size());
  960. char *pCur = m_RDATBuffer.data();
  961. // write number of tables
  962. uint32_t size = m_tables.size();
  963. memcpy(pCur, &size, sizeof(uint32_t));
  964. pCur += sizeof(uint32_t);
  965. // write records
  966. uint32_t curTableOffset = size * sizeof(RuntimeDataTableHeader) + 4;
  967. for (auto &&table : m_tables) {
  968. RuntimeDataTableHeader record = { table->GetType(), table->GetBlobSize(), curTableOffset };
  969. memcpy(pCur, &record, sizeof(RuntimeDataTableHeader));
  970. pCur += sizeof(RuntimeDataTableHeader);
  971. curTableOffset += record.size;
  972. }
  973. // write tables
  974. for (auto &&table : m_tables) {
  975. table->write(pCur);
  976. pCur += table->GetBlobSize();
  977. }
  978. ULONG cbWritten;
  979. IFT(pStream->Write(m_RDATBuffer.data(), m_RDATBuffer.size(), &cbWritten));
  980. DXASSERT_NOMSG(cbWritten == m_RDATBuffer.size());
  981. }
  982. };
  983. DxilPartWriter *hlsl::NewPSVWriter(const DxilModule &M, uint32_t PSVVersion) {
  984. return new DxilPSVWriter(M, PSVVersion);
  985. }
  986. class DxilContainerWriter_impl : public DxilContainerWriter {
  987. private:
  988. class DxilPart {
  989. public:
  990. DxilPartHeader Header;
  991. WriteFn Write;
  992. DxilPart(uint32_t fourCC, uint32_t size, WriteFn write) : Write(write) {
  993. Header.PartFourCC = fourCC;
  994. Header.PartSize = size;
  995. }
  996. };
  997. llvm::SmallVector<DxilPart, 8> m_Parts;
  998. public:
  999. __override void AddPart(uint32_t FourCC, uint32_t Size, WriteFn Write) {
  1000. m_Parts.emplace_back(FourCC, Size, Write);
  1001. }
  1002. __override uint32_t size() const {
  1003. uint32_t partSize = 0;
  1004. for (auto &part : m_Parts) {
  1005. partSize += part.Header.PartSize;
  1006. }
  1007. return (uint32_t)GetDxilContainerSizeFromParts((uint32_t)m_Parts.size(), partSize);
  1008. }
  1009. __override void write(AbstractMemoryStream *pStream) {
  1010. DxilContainerHeader header;
  1011. const uint32_t PartCount = (uint32_t)m_Parts.size();
  1012. uint32_t containerSizeInBytes = size();
  1013. InitDxilContainer(&header, PartCount, containerSizeInBytes);
  1014. IFT(pStream->Reserve(header.ContainerSizeInBytes));
  1015. IFT(WriteStreamValue(pStream, header));
  1016. uint32_t offset = sizeof(header) + (uint32_t)GetOffsetTableSize(PartCount);
  1017. for (auto &&part : m_Parts) {
  1018. IFT(WriteStreamValue(pStream, offset));
  1019. offset += sizeof(DxilPartHeader) + part.Header.PartSize;
  1020. }
  1021. for (auto &&part : m_Parts) {
  1022. IFT(WriteStreamValue(pStream, part.Header));
  1023. size_t start = pStream->GetPosition();
  1024. part.Write(pStream);
  1025. DXASSERT_LOCALVAR(start, pStream->GetPosition() - start == (size_t)part.Header.PartSize, "out of bound");
  1026. }
  1027. DXASSERT(containerSizeInBytes == (uint32_t)pStream->GetPosition(), "else stream size is incorrect");
  1028. }
  1029. };
  1030. DxilContainerWriter *hlsl::NewDxilContainerWriter() {
  1031. return new DxilContainerWriter_impl();
  1032. }
  1033. static bool HasDebugInfo(const Module &M) {
  1034. for (Module::const_named_metadata_iterator NMI = M.named_metadata_begin(),
  1035. NME = M.named_metadata_end();
  1036. NMI != NME; ++NMI) {
  1037. if (NMI->getName().startswith("llvm.dbg.")) {
  1038. return true;
  1039. }
  1040. }
  1041. return false;
  1042. }
  1043. static void GetPaddedProgramPartSize(AbstractMemoryStream *pStream,
  1044. uint32_t &bitcodeInUInt32,
  1045. uint32_t &bitcodePaddingBytes) {
  1046. bitcodeInUInt32 = pStream->GetPtrSize();
  1047. bitcodePaddingBytes = (bitcodeInUInt32 % 4);
  1048. bitcodeInUInt32 = (bitcodeInUInt32 / 4) + (bitcodePaddingBytes ? 1 : 0);
  1049. }
  1050. static void WriteProgramPart(const ShaderModel *pModel,
  1051. AbstractMemoryStream *pModuleBitcode,
  1052. AbstractMemoryStream *pStream) {
  1053. DXASSERT(pModel != nullptr, "else generation should have failed");
  1054. DxilProgramHeader programHeader;
  1055. uint32_t shaderVersion =
  1056. EncodeVersion(pModel->GetKind(), pModel->GetMajor(), pModel->GetMinor());
  1057. unsigned dxilMajor, dxilMinor;
  1058. pModel->GetDxilVersion(dxilMajor, dxilMinor);
  1059. uint32_t dxilVersion = DXIL::MakeDxilVersion(dxilMajor, dxilMinor);
  1060. InitProgramHeader(programHeader, shaderVersion, dxilVersion, pModuleBitcode->GetPtrSize());
  1061. uint32_t programInUInt32, programPaddingBytes;
  1062. GetPaddedProgramPartSize(pModuleBitcode, programInUInt32,
  1063. programPaddingBytes);
  1064. ULONG cbWritten;
  1065. IFT(WriteStreamValue(pStream, programHeader));
  1066. IFT(pStream->Write(pModuleBitcode->GetPtr(), pModuleBitcode->GetPtrSize(),
  1067. &cbWritten));
  1068. if (programPaddingBytes) {
  1069. uint32_t paddingValue = 0;
  1070. IFT(pStream->Write(&paddingValue, programPaddingBytes, &cbWritten));
  1071. }
  1072. }
  1073. void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
  1074. AbstractMemoryStream *pModuleBitcode,
  1075. AbstractMemoryStream *pFinalStream,
  1076. SerializeDxilFlags Flags) {
  1077. // TODO: add a flag to update the module and remove information that is not part
  1078. // of DXIL proper and is used only to assemble the container.
  1079. DXASSERT_NOMSG(pModule != nullptr);
  1080. DXASSERT_NOMSG(pModuleBitcode != nullptr);
  1081. DXASSERT_NOMSG(pFinalStream != nullptr);
  1082. unsigned ValMajor, ValMinor;
  1083. pModule->GetValidatorVersion(ValMajor, ValMinor);
  1084. if (ValMajor == 1 && ValMinor == 0)
  1085. Flags &= ~SerializeDxilFlags::IncludeDebugNamePart;
  1086. DxilProgramSignatureWriter inputSigWriter(
  1087. pModule->GetInputSignature(), pModule->GetTessellatorDomain(),
  1088. /*IsInput*/ true,
  1089. /*UseMinPrecision*/ pModule->GetUseMinPrecision());
  1090. DxilProgramSignatureWriter outputSigWriter(
  1091. pModule->GetOutputSignature(), pModule->GetTessellatorDomain(),
  1092. /*IsInput*/ false,
  1093. /*UseMinPrecision*/ pModule->GetUseMinPrecision());
  1094. DxilContainerWriter_impl writer;
  1095. // Write the feature part.
  1096. DxilFeatureInfoWriter featureInfoWriter(*pModule);
  1097. writer.AddPart(DFCC_FeatureInfo, featureInfoWriter.size(), [&](AbstractMemoryStream *pStream) {
  1098. featureInfoWriter.write(pStream);
  1099. });
  1100. // Write the input and output signature parts.
  1101. writer.AddPart(DFCC_InputSignature, inputSigWriter.size(), [&](AbstractMemoryStream *pStream) {
  1102. inputSigWriter.write(pStream);
  1103. });
  1104. writer.AddPart(DFCC_OutputSignature, outputSigWriter.size(), [&](AbstractMemoryStream *pStream) {
  1105. outputSigWriter.write(pStream);
  1106. });
  1107. DxilProgramSignatureWriter patchConstantSigWriter(
  1108. pModule->GetPatchConstantSignature(), pModule->GetTessellatorDomain(),
  1109. /*IsInput*/ pModule->GetShaderModel()->IsDS(),
  1110. /*UseMinPrecision*/ pModule->GetUseMinPrecision());
  1111. if (pModule->GetPatchConstantSignature().GetElements().size()) {
  1112. writer.AddPart(DFCC_PatchConstantSignature, patchConstantSigWriter.size(),
  1113. [&](AbstractMemoryStream *pStream) {
  1114. patchConstantSigWriter.write(pStream);
  1115. });
  1116. }
  1117. // Write the DxilPipelineStateValidation (PSV0) part.
  1118. DxilRDATWriter RDATWriter(*pModule);
  1119. DxilPSVWriter PSVWriter(*pModule);
  1120. unsigned int major, minor;
  1121. pModule->GetDxilVersion(major, minor);
  1122. if (pModule->GetShaderModel()->IsLib() && (major >= 1 || minor == 1 && minor >= 3)) {
  1123. writer.AddPart(DFCC_RuntimeData, RDATWriter.size(), [&](AbstractMemoryStream *pStream) {
  1124. RDATWriter.write(pStream);
  1125. });
  1126. }
  1127. else {
  1128. writer.AddPart(DFCC_PipelineStateValidation, PSVWriter.size(), [&](AbstractMemoryStream *pStream) {
  1129. PSVWriter.write(pStream);
  1130. });
  1131. }
  1132. // Write the root signature (RTS0) part.
  1133. DxilProgramRootSignatureWriter rootSigWriter(pModule->GetRootSignature());
  1134. CComPtr<AbstractMemoryStream> pInputProgramStream = pModuleBitcode;
  1135. if (!pModule->GetRootSignature().IsEmpty()) {
  1136. writer.AddPart(
  1137. DFCC_RootSignature, rootSigWriter.size(),
  1138. [&](AbstractMemoryStream *pStream) { rootSigWriter.write(pStream); });
  1139. pModule->StripRootSignatureFromMetadata();
  1140. pInputProgramStream.Release();
  1141. IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pInputProgramStream));
  1142. raw_stream_ostream outStream(pInputProgramStream.p);
  1143. WriteBitcodeToFile(pModule->GetModule(), outStream, true);
  1144. }
  1145. // If we have debug information present, serialize it to a debug part, then use the stripped version as the canonical program version.
  1146. CComPtr<AbstractMemoryStream> pProgramStream = pInputProgramStream;
  1147. if (HasDebugInfo(*pModule->GetModule())) {
  1148. uint32_t debugInUInt32, debugPaddingBytes;
  1149. GetPaddedProgramPartSize(pInputProgramStream, debugInUInt32, debugPaddingBytes);
  1150. if (Flags & SerializeDxilFlags::IncludeDebugInfoPart) {
  1151. writer.AddPart(DFCC_ShaderDebugInfoDXIL, debugInUInt32 * sizeof(uint32_t) + sizeof(DxilProgramHeader), [&](AbstractMemoryStream *pStream) {
  1152. WriteProgramPart(pModule->GetShaderModel(), pInputProgramStream, pStream);
  1153. });
  1154. }
  1155. pProgramStream.Release();
  1156. llvm::StripDebugInfo(*pModule->GetModule());
  1157. pModule->StripDebugRelatedCode();
  1158. IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pProgramStream));
  1159. raw_stream_ostream outStream(pProgramStream.p);
  1160. WriteBitcodeToFile(pModule->GetModule(), outStream, true);
  1161. if (Flags & SerializeDxilFlags::IncludeDebugNamePart) {
  1162. CComPtr<AbstractMemoryStream> pHashStream;
  1163. // If the debug name should be specific to the sources, base the name on the debug
  1164. // bitcode, which will include the source references, line numbers, etc. Otherwise,
  1165. // do it exclusively on the target shader bitcode.
  1166. pHashStream = (int)(Flags & SerializeDxilFlags::DebugNameDependOnSource) ? pModuleBitcode : pProgramStream;
  1167. const uint32_t DebugInfoNameHashLen = 32; // 32 chars of MD5
  1168. const uint32_t DebugInfoNameSuffix = 4; // '.lld'
  1169. const uint32_t DebugInfoNameNullAndPad = 4; // '\0\0\0\0'
  1170. const uint32_t DebugInfoContentLen =
  1171. sizeof(DxilShaderDebugName) + DebugInfoNameHashLen +
  1172. DebugInfoNameSuffix + DebugInfoNameNullAndPad;
  1173. writer.AddPart(DFCC_ShaderDebugName, DebugInfoContentLen, [&](AbstractMemoryStream *pStream) {
  1174. DxilShaderDebugName NameContent;
  1175. NameContent.Flags = 0;
  1176. NameContent.NameLength = DebugInfoNameHashLen + DebugInfoNameSuffix;
  1177. IFT(WriteStreamValue(pStream, NameContent));
  1178. ArrayRef<uint8_t> Data((uint8_t *)pHashStream->GetPtr(), pHashStream->GetPtrSize());
  1179. llvm::MD5 md5;
  1180. llvm::MD5::MD5Result md5Result;
  1181. SmallString<32> Hash;
  1182. md5.update(Data);
  1183. md5.final(md5Result);
  1184. md5.stringifyResult(md5Result, Hash);
  1185. ULONG cbWritten;
  1186. IFT(pStream->Write(Hash.data(), Hash.size(), &cbWritten));
  1187. const char SuffixAndPad[] = ".lld\0\0\0";
  1188. IFT(pStream->Write(SuffixAndPad, _countof(SuffixAndPad), &cbWritten));
  1189. });
  1190. }
  1191. }
  1192. // Compute padded bitcode size.
  1193. uint32_t programInUInt32, programPaddingBytes;
  1194. GetPaddedProgramPartSize(pProgramStream, programInUInt32, programPaddingBytes);
  1195. // Write the program part.
  1196. writer.AddPart(DFCC_DXIL, programInUInt32 * sizeof(uint32_t) + sizeof(DxilProgramHeader), [&](AbstractMemoryStream *pStream) {
  1197. WriteProgramPart(pModule->GetShaderModel(), pProgramStream, pStream);
  1198. });
  1199. writer.write(pFinalStream);
  1200. }
  1201. void hlsl::SerializeDxilContainerForRootSignature(hlsl::RootSignatureHandle *pRootSigHandle,
  1202. AbstractMemoryStream *pFinalStream) {
  1203. DXASSERT_NOMSG(pRootSigHandle != nullptr);
  1204. DXASSERT_NOMSG(pFinalStream != nullptr);
  1205. DxilContainerWriter_impl writer;
  1206. // Write the root signature (RTS0) part.
  1207. DxilProgramRootSignatureWriter rootSigWriter(*pRootSigHandle);
  1208. if (!pRootSigHandle->IsEmpty()) {
  1209. writer.AddPart(
  1210. DFCC_RootSignature, rootSigWriter.size(),
  1211. [&](AbstractMemoryStream *pStream) { rootSigWriter.write(pStream); });
  1212. }
  1213. writer.write(pFinalStream);
  1214. }