ShaderProgramCompiler.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982
  1. // Copyright (C) 2009-2023, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/ShaderCompiler/ShaderProgramCompiler.h>
  6. #include <AnKi/ShaderCompiler/ShaderProgramParser.h>
  7. #include <AnKi/ShaderCompiler/Dxc.h>
  8. #include <AnKi/Util/Serializer.h>
  9. #include <AnKi/Util/HashMap.h>
  10. #include <SpirvCross/spirv_cross.hpp>
  11. #define ANKI_DIXL_REFLECTION (ANKI_OS_WINDOWS)
  12. #if ANKI_DIXL_REFLECTION
  13. # include <windows.h>
  14. # include <ThirdParty/Dxc/dxcapi.h>
  15. # include <ThirdParty/Dxc/d3d12shader.h>
  16. # include <wrl.h>
  17. # include <AnKi/Util/CleanupWindows.h>
  18. static HMODULE g_dxcLib = 0;
  19. static DxcCreateInstanceProc g_DxcCreateInstance = nullptr;
  20. #endif
  21. namespace anki {
  22. void freeShaderProgramBinary(ShaderProgramBinary*& binary)
  23. {
  24. if(binary == nullptr)
  25. {
  26. return;
  27. }
  28. BaseMemoryPool& mempool = ShaderCompilerMemoryPool::getSingleton();
  29. for(ShaderProgramBinaryCodeBlock& code : binary->m_codeBlocks)
  30. {
  31. mempool.free(code.m_binary.getBegin());
  32. }
  33. mempool.free(binary->m_codeBlocks.getBegin());
  34. for(ShaderProgramBinaryMutator& mutator : binary->m_mutators)
  35. {
  36. mempool.free(mutator.m_values.getBegin());
  37. }
  38. mempool.free(binary->m_mutators.getBegin());
  39. for(ShaderProgramBinaryMutation& m : binary->m_mutations)
  40. {
  41. mempool.free(m.m_values.getBegin());
  42. }
  43. mempool.free(binary->m_mutations.getBegin());
  44. for(ShaderProgramBinaryVariant& variant : binary->m_variants)
  45. {
  46. mempool.free(variant.m_techniqueCodeBlocks.getBegin());
  47. }
  48. mempool.free(binary->m_variants.getBegin());
  49. mempool.free(binary->m_techniques.getBegin());
  50. for(ShaderProgramBinaryStruct& s : binary->m_structs)
  51. {
  52. mempool.free(s.m_members.getBegin());
  53. }
  54. mempool.free(binary->m_structs.getBegin());
  55. mempool.free(binary);
  56. binary = nullptr;
  57. }
  58. /// Spin the dials. Used to compute all mutator combinations.
  59. static Bool spinDials(ShaderCompilerDynamicArray<U32>& dials, ConstWeakArray<ShaderProgramParserMutator> mutators)
  60. {
  61. ANKI_ASSERT(dials.getSize() == mutators.getSize() && dials.getSize() > 0);
  62. Bool done = true;
  63. U32 crntDial = dials.getSize() - 1;
  64. while(true)
  65. {
  66. // Turn dial
  67. ++dials[crntDial];
  68. if(dials[crntDial] >= mutators[crntDial].m_values.getSize())
  69. {
  70. if(crntDial == 0)
  71. {
  72. // Reached the 1st dial, stop spinning
  73. done = true;
  74. break;
  75. }
  76. else
  77. {
  78. dials[crntDial] = 0;
  79. --crntDial;
  80. }
  81. }
  82. else
  83. {
  84. done = false;
  85. break;
  86. }
  87. }
  88. return done;
  89. }
  90. template<typename TFunc>
  91. static void visitSpirv(ConstWeakArray<U32> spv, TFunc func)
  92. {
  93. ANKI_ASSERT(spv.getSize() > 5);
  94. const U32* it = &spv[5];
  95. do
  96. {
  97. const U32 instructionCount = *it >> 16u;
  98. const U32 opcode = *it & 0xFFFFu;
  99. func(opcode);
  100. it += instructionCount;
  101. } while(it < spv.getEnd());
  102. ANKI_ASSERT(it == spv.getEnd());
  103. }
  104. static Error doReflectionSpirv(ConstWeakArray<U8> spirv, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
  105. {
  106. spirv_cross::Compiler spvc(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32));
  107. spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
  108. spirv_cross::ShaderResources rsrcActive = spvc.get_shader_resources(spvc.get_active_interface_variables());
  109. auto func = [&](const spirv_cross::SmallVector<spirv_cross::Resource>& resources, const DescriptorType origType,
  110. const DescriptorFlag origFlags) -> Error {
  111. for(const spirv_cross::Resource& r : resources)
  112. {
  113. const U32 id = r.id;
  114. const U32 set = spvc.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
  115. const U32 binding = spvc.get_decoration(id, spv::Decoration::DecorationBinding);
  116. if(set >= kMaxDescriptorSets || binding >= kMaxBindingsPerDescriptorSet)
  117. {
  118. errorStr.sprintf("Exceeded set or binding for: %s", r.name.c_str());
  119. return Error::kUserData;
  120. }
  121. const spirv_cross::SPIRType& typeInfo = spvc.get_type(r.type_id);
  122. U32 arraySize = 1;
  123. if(typeInfo.array.size() != 0)
  124. {
  125. if(typeInfo.array.size() != 1 || (arraySize = typeInfo.array[0]) == 0)
  126. {
  127. errorStr.sprintf("Only 1D arrays are supported: %s", r.name.c_str());
  128. return Error::kUserData;
  129. }
  130. }
  131. refl.m_descriptorSetMask.set(set);
  132. // Images are special, they might be texel buffers
  133. DescriptorType type = origType;
  134. DescriptorFlag flags = origFlags;
  135. if(type == DescriptorType::kTexture)
  136. {
  137. if(typeInfo.image.dim == spv::DimBuffer)
  138. {
  139. type = DescriptorType::kTexelBuffer;
  140. if(typeInfo.image.sampled == 1)
  141. {
  142. flags = DescriptorFlag::kRead;
  143. }
  144. else
  145. {
  146. ANKI_ASSERT(typeInfo.image.sampled == 2);
  147. flags = DescriptorFlag::kReadWrite;
  148. }
  149. }
  150. }
  151. // Check that there are no other descriptors with the same binding
  152. if(refl.m_descriptorTypes[set][binding] == DescriptorType::kCount)
  153. {
  154. // New binding, init it
  155. refl.m_descriptorTypes[set][binding] = type;
  156. refl.m_descriptorArraySizes[set][binding] = U16(arraySize);
  157. refl.m_descriptorFlags[set][binding] = flags;
  158. }
  159. else
  160. {
  161. // Same binding, make sure the type is compatible
  162. if(refl.m_descriptorTypes[set][binding] != type || refl.m_descriptorArraySizes[set][binding] != arraySize
  163. || refl.m_descriptorFlags[set][binding] != flags)
  164. {
  165. errorStr.sprintf("Descriptor with same binding but different type or array size: %s", r.name.c_str());
  166. return Error::kUserData;
  167. }
  168. }
  169. }
  170. return Error::kNone;
  171. };
  172. Error err = Error::kNone;
  173. err = func(rsrc.uniform_buffers, DescriptorType::kUniformBuffer, DescriptorFlag::kRead);
  174. if(!err)
  175. {
  176. err = func(rsrc.separate_images, DescriptorType::kTexture, DescriptorFlag::kRead); // This also handles texel buffers
  177. }
  178. if(!err)
  179. {
  180. err = func(rsrc.separate_samplers, DescriptorType::kSampler, DescriptorFlag::kRead);
  181. }
  182. if(!err)
  183. {
  184. err = func(rsrc.storage_buffers, DescriptorType::kStorageBuffer, DescriptorFlag::kReadWrite);
  185. }
  186. if(!err)
  187. {
  188. err = func(rsrc.storage_images, DescriptorType::kTexture, DescriptorFlag::kReadWrite);
  189. }
  190. if(!err)
  191. {
  192. err = func(rsrc.acceleration_structures, DescriptorType::kAccelerationStructure, DescriptorFlag::kRead);
  193. }
  194. // Color attachments
  195. if(type == ShaderType::kFragment)
  196. {
  197. for(const spirv_cross::Resource& r : rsrc.stage_outputs)
  198. {
  199. const U32 id = r.id;
  200. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  201. refl.m_colorAttachmentWritemask.set(location);
  202. }
  203. }
  204. // Push consts
  205. if(rsrc.push_constant_buffers.size() == 1)
  206. {
  207. const U32 blockSize = U32(spvc.get_declared_struct_size(spvc.get_type(rsrc.push_constant_buffers[0].base_type_id)));
  208. if(blockSize == 0 || (blockSize % 16) != 0 || blockSize > kMaxU8)
  209. {
  210. errorStr.sprintf("Incorrect push constants size");
  211. return Error::kUserData;
  212. }
  213. refl.m_pushConstantsSize = U8(blockSize);
  214. }
  215. // Attribs
  216. if(type == ShaderType::kVertex)
  217. {
  218. for(const spirv_cross::Resource& r : rsrcActive.stage_inputs)
  219. {
  220. VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
  221. #define ANKI_ATTRIB_NAME(x) "in.var." #x
  222. if(r.name == ANKI_ATTRIB_NAME(POSITION))
  223. {
  224. a = VertexAttributeSemantic::kPosition;
  225. }
  226. else if(r.name == ANKI_ATTRIB_NAME(NORMAL))
  227. {
  228. a = VertexAttributeSemantic::kNormal;
  229. }
  230. else if(r.name == ANKI_ATTRIB_NAME(TEXCOORD0) || r.name == ANKI_ATTRIB_NAME(TEXCOORD))
  231. {
  232. a = VertexAttributeSemantic::kTexCoord;
  233. }
  234. else if(r.name == ANKI_ATTRIB_NAME(COLOR))
  235. {
  236. a = VertexAttributeSemantic::kColor;
  237. }
  238. else if(r.name == ANKI_ATTRIB_NAME(MISC0) || r.name == ANKI_ATTRIB_NAME(MISC))
  239. {
  240. a = VertexAttributeSemantic::kMisc0;
  241. }
  242. else if(r.name == ANKI_ATTRIB_NAME(MISC1))
  243. {
  244. a = VertexAttributeSemantic::kMisc1;
  245. }
  246. else if(r.name == ANKI_ATTRIB_NAME(MISC2))
  247. {
  248. a = VertexAttributeSemantic::kMisc2;
  249. }
  250. else if(r.name == ANKI_ATTRIB_NAME(MISC3))
  251. {
  252. a = VertexAttributeSemantic::kMisc3;
  253. }
  254. else
  255. {
  256. errorStr.sprintf("Unexpected attribute name: %s", r.name.c_str());
  257. return Error::kUserData;
  258. }
  259. #undef ANKI_ATTRIB_NAME
  260. refl.m_vertexAttributeMask.set(a);
  261. const U32 id = r.id;
  262. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  263. if(location > kMaxU8)
  264. {
  265. errorStr.sprintf("Too high location value for attribute: %s", r.name.c_str());
  266. return Error::kUserData;
  267. }
  268. refl.m_vertexAttributeLocations[a] = U8(location);
  269. }
  270. }
  271. // Discards?
  272. if(type == ShaderType::kFragment)
  273. {
  274. visitSpirv(ConstWeakArray<U32>(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32)), [&](U32 cmd) {
  275. if(cmd == spv::OpKill)
  276. {
  277. refl.m_discards = true;
  278. }
  279. });
  280. }
  281. return Error::kNone;
  282. }
  283. #if ANKI_DIXL_REFLECTION
  284. # define ANKI_REFL_CHECK(x) \
  285. do \
  286. { \
  287. HRESULT rez; \
  288. if((rez = (x)) < 0) [[unlikely]] \
  289. { \
  290. errorStr.sprintf("DXC function failed (HRESULT: %d): %s", rez, #x); \
  291. return Error::kFunctionFailed; \
  292. } \
  293. } while(0)
  294. static Error doReflectionDxil(ConstWeakArray<U8> dxil, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
  295. {
  296. using Microsoft::WRL::ComPtr;
  297. ComPtr<IDxcUtils> utils;
  298. ANKI_REFL_CHECK(g_DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(&utils)));
  299. const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
  300. ComPtr<ID3D12ShaderReflection> dxRefl;
  301. ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&dxRefl)));
  302. D3D12_SHADER_DESC shaderDesc = {};
  303. dxRefl->GetDesc(&shaderDesc);
  304. for(U32 i = 0; i < shaderDesc.BoundResources; ++i)
  305. {
  306. D3D12_SHADER_INPUT_BIND_DESC bind;
  307. ANKI_REFL_CHECK(dxRefl->GetResourceBindingDesc(i, &bind));
  308. ShaderReflectionBinding akBinding;
  309. akBinding.m_type = DescriptorType::kCount;
  310. akBinding.m_flags = DescriptorFlag::kNone;
  311. akBinding.m_arraySize = U16(bind.BindCount);
  312. if(bind.Type == D3D_SIT_CBUFFER)
  313. {
  314. // ConstantBuffer
  315. akBinding.m_type = DescriptorType::kUniformBuffer;
  316. akBinding.m_flags = DescriptorFlag::kRead;
  317. akBinding.m_semantic = BindingSemantic('b', bind.BindPoint);
  318. }
  319. else if(bind.Type == D3D_SIT_TEXTURE && bind.Dimension != D3D_SRV_DIMENSION_BUFFER)
  320. {
  321. // Texture2D etc
  322. akBinding.m_type = DescriptorType::kTexture;
  323. akBinding.m_flags = DescriptorFlag::kRead;
  324. akBinding.m_semantic = BindingSemantic('t', bind.BindPoint);
  325. }
  326. else if(bind.Type == D3D_SIT_TEXTURE && bind.Dimension == D3D_SRV_DIMENSION_BUFFER)
  327. {
  328. // Buffer
  329. akBinding.m_type = DescriptorType::kTexelBuffer;
  330. akBinding.m_flags = DescriptorFlag::kRead;
  331. akBinding.m_semantic = BindingSemantic('t', bind.BindPoint);
  332. }
  333. else if(bind.Type == D3D_SIT_SAMPLER)
  334. {
  335. // SamplerState
  336. akBinding.m_type = DescriptorType::kSampler;
  337. akBinding.m_flags = DescriptorFlag::kRead;
  338. akBinding.m_semantic = BindingSemantic('s', bind.BindPoint);
  339. }
  340. else if(bind.Type == D3D_SIT_UAV_RWTYPED && bind.Dimension == D3D_SRV_DIMENSION_BUFFER)
  341. {
  342. // RWBuffer
  343. akBinding.m_type = DescriptorType::kTexelBuffer;
  344. akBinding.m_flags = DescriptorFlag::kReadWrite;
  345. akBinding.m_semantic = BindingSemantic('u', bind.BindPoint);
  346. }
  347. else if(bind.Type == D3D_SIT_UAV_RWTYPED && bind.Dimension != D3D_SRV_DIMENSION_BUFFER)
  348. {
  349. // RWTexture2D etc
  350. akBinding.m_type = DescriptorType::kTexture;
  351. akBinding.m_flags = DescriptorFlag::kReadWrite;
  352. akBinding.m_semantic = BindingSemantic('u', bind.BindPoint);
  353. }
  354. else if(bind.Type == D3D_SIT_BYTEADDRESS)
  355. {
  356. // ByteAddressBuffer
  357. akBinding.m_type = DescriptorType::kStorageBuffer;
  358. akBinding.m_flags = DescriptorFlag::kRead | DescriptorFlag::kByteAddressBuffer;
  359. akBinding.m_semantic = BindingSemantic('t', bind.BindPoint);
  360. }
  361. else if(bind.Type == D3D_SIT_UAV_RWBYTEADDRESS)
  362. {
  363. // RWByteAddressBuffer
  364. akBinding.m_type = DescriptorType::kStorageBuffer;
  365. akBinding.m_flags = DescriptorFlag::kReadWrite | DescriptorFlag::kByteAddressBuffer;
  366. akBinding.m_semantic = BindingSemantic('u', bind.BindPoint);
  367. }
  368. else if(bind.Type == D3D_SIT_RTACCELERATIONSTRUCTURE)
  369. {
  370. // RaytracingAccelerationStructure
  371. akBinding.m_type = DescriptorType::kAccelerationStructure;
  372. akBinding.m_flags = DescriptorFlag::kRead;
  373. akBinding.m_semantic = BindingSemantic('t', bind.BindPoint);
  374. }
  375. else
  376. {
  377. ANKI_SHADER_COMPILER_LOGE("Unrecognized type for binding: %s", bind.Name);
  378. return Error::kUserData;
  379. }
  380. refl.m_bindings[bind.Space][refl.m_bindingCounts[bind.Space++]] = akBinding;
  381. }
  382. if(type == ShaderType::kVertex)
  383. {
  384. for(U32 i = 0; i < shaderDesc.InputParameters; ++i)
  385. {
  386. D3D12_SIGNATURE_PARAMETER_DESC in;
  387. ANKI_REFL_CHECK(dxRefl->GetInputParameterDesc(i, &in));
  388. VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
  389. # define ANKI_ATTRIB_NAME(x, idx) CString(in.SemanticName) == # x&& in.SemanticIndex == idx
  390. if(ANKI_ATTRIB_NAME(POSITION, 0))
  391. {
  392. a = VertexAttributeSemantic::kPosition;
  393. }
  394. else if(ANKI_ATTRIB_NAME(NORMAL, 0))
  395. {
  396. a = VertexAttributeSemantic::kNormal;
  397. }
  398. else if(ANKI_ATTRIB_NAME(TEXCOORD, 0))
  399. {
  400. a = VertexAttributeSemantic::kTexCoord;
  401. }
  402. else if(ANKI_ATTRIB_NAME(COLOR, 0))
  403. {
  404. a = VertexAttributeSemantic::kColor;
  405. }
  406. else if(ANKI_ATTRIB_NAME(MISC, 0))
  407. {
  408. a = VertexAttributeSemantic::kMisc0;
  409. }
  410. else if(ANKI_ATTRIB_NAME(MISC, 1))
  411. {
  412. a = VertexAttributeSemantic::kMisc1;
  413. }
  414. else if(ANKI_ATTRIB_NAME(MISC, 2))
  415. {
  416. a = VertexAttributeSemantic::kMisc2;
  417. }
  418. else if(ANKI_ATTRIB_NAME(MISC, 3))
  419. {
  420. a = VertexAttributeSemantic::kMisc3;
  421. }
  422. else if(ANKI_ATTRIB_NAME(SV_VERTEXID, 0))
  423. {
  424. // Ignore
  425. }
  426. else
  427. {
  428. errorStr.sprintf("Unexpected attribute name: %s", in.SemanticName);
  429. return Error::kUserData;
  430. }
  431. # undef ANKI_ATTRIB_NAME
  432. refl.m_vertexAttributeMask.set(a);
  433. refl.m_vertexAttributeLocations[a] = U8(i);
  434. }
  435. }
  436. if(type == ShaderType::kFragment)
  437. {
  438. for(U32 i = 0; i < shaderDesc.OutputParameters; ++i)
  439. {
  440. D3D12_SIGNATURE_PARAMETER_DESC desc;
  441. ANKI_REFL_CHECK(dxRefl->GetOutputParameterDesc(i, &desc));
  442. if(CString(desc.SemanticName) == "SV_TARGET")
  443. {
  444. refl.m_colorAttachmentWritemask.set(desc.SemanticIndex);
  445. }
  446. }
  447. }
  448. return Error::kNone;
  449. }
  450. #endif // #if ANKI_DIXL_REFLECTION
/// Enqueue an async task that compiles every technique/stage of one mutation, deduplicates the
/// generated code against what other tasks already produced and registers the resulting variant.
/// The shared containers (variants, codeBlocks, sourceCodeHashes) are mutated concurrently by
/// many tasks and are protected by mtx.
/// @param parser           The parsed program. Must outlive the task.
/// @param spirv            True to emit SPIR-V, false to emit DXIL.
/// @param mutation         The mutation to compile. Its m_variantIndex is written by the task.
/// @param variants         [in,out] Shared variant list (storage pre-grown by the caller).
/// @param codeBlocks       [in,out] Shared list of unique compiled binaries.
/// @param sourceCodeHashes [in,out] Hashes of generated sources, kept parallel to codeBlocks.
/// @param taskManager      Receives the task (may execute it synchronously).
/// @param mtx              Guards variants, codeBlocks and sourceCodeHashes.
/// @param error            First error any task hit; non-zero makes later tasks bail out early.
static void compileVariantAsync(const ShaderProgramParser& parser, Bool spirv, ShaderProgramBinaryMutation& mutation,
								ShaderCompilerDynamicArray<ShaderProgramBinaryVariant>& variants,
								ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock>& codeBlocks,
								ShaderCompilerDynamicArray<U64>& sourceCodeHashes, ShaderProgramAsyncTaskInterface& taskManager, Mutex& mtx,
								Atomic<I32>& error)
{
	// Everything the task needs. Heap-allocated because the task may outlive this call
	class Ctx
	{
	public:
		const ShaderProgramParser* m_parser;
		ShaderProgramBinaryMutation* m_mutation;
		ShaderCompilerDynamicArray<ShaderProgramBinaryVariant>* m_variants;
		ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock>* m_codeBlocks;
		ShaderCompilerDynamicArray<U64>* m_sourceCodeHashes;
		Mutex* m_mtx;
		Atomic<I32>* m_err;
		Bool m_spirv;
	};

	Ctx* ctx = newInstance<Ctx>(ShaderCompilerMemoryPool::getSingleton());
	ctx->m_parser = &parser;
	ctx->m_mutation = &mutation;
	ctx->m_variants = &variants;
	ctx->m_codeBlocks = &codeBlocks;
	ctx->m_sourceCodeHashes = &sourceCodeHashes;
	ctx->m_mtx = &mtx;
	ctx->m_err = &error;
	ctx->m_spirv = spirv;

	auto callback = [](void* userData) {
		Ctx& ctx = *static_cast<Ctx*>(userData);

		// RAII guard: free the Ctx no matter how the callback exits
		class Cleanup
		{
		public:
			Ctx* m_ctx;

			~Cleanup()
			{
				deleteInstance(ShaderCompilerMemoryPool::getSingleton(), m_ctx);
			}
		} cleanup{&ctx};

		if(ctx.m_err->load() != 0)
		{
			// Don't bother, another task already failed
			return;
		}

		const U32 techniqueCount = ctx.m_parser->getTechniques().getSize();

		// Compile the sources. codeBlockIndices[t] maps each stage of technique t to an index
		// into the shared codeBlocks array
		ShaderCompilerDynamicArray<ShaderProgramBinaryTechniqueCodeBlocks> codeBlockIndices;
		codeBlockIndices.resize(techniqueCount);
		for(auto& it : codeBlockIndices)
		{
			// kMaxU32 marks "no code block" for stages a technique doesn't use
			it.m_codeBlockIndices.fill(kMaxU32);
		}

		ShaderCompilerString compilerErrorLog;
		Error err = Error::kNone;
		U newCodeBlockCount = 0;
		for(U32 t = 0; t < techniqueCount && !err; ++t)
		{
			const ShaderProgramParserTechnique& technique = ctx.m_parser->getTechniques()[t];
			for(ShaderType shaderType : EnumBitsIterable<ShaderType, ShaderTypeBit>(technique.m_shaderTypes))
			{
				// Generate the final HLSL for this mutation/technique/stage
				ShaderCompilerString source;
				ctx.m_parser->generateVariant(ctx.m_mutation->m_values, technique, shaderType, source);

				// Check if the source code was found before
				// NOTE(review): the source-hash dedup is only attempted when the stage has active
				// mutators (m_activeMutators[shaderType] != kMaxU64) — confirm this is intended
				const U64 sourceCodeHash = source.computeHash();
				if(technique.m_activeMutators[shaderType] != kMaxU64)
				{
					LockGuard lock(*ctx.m_mtx);
					ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

					Bool found = false;
					for(U32 i = 0; i < ctx.m_sourceCodeHashes->getSize(); ++i)
					{
						if((*ctx.m_sourceCodeHashes)[i] == sourceCodeHash)
						{
							// Identical source already compiled, reuse its code block
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = i;
							found = true;
							break;
						}
					}

					if(found)
					{
						continue;
					}
				}

				// Not seen before: compile to the requested IL
				ShaderCompilerDynamicArray<U8> il;
				if(ctx.m_spirv)
				{
					err = compileHlslToSpirv(source, shaderType, ctx.m_parser->compileWith16bitTypes(), il, compilerErrorLog);
				}
				else
				{
					err = compileHlslToDxil(source, shaderType, ctx.m_parser->compileWith16bitTypes(), il, compilerErrorLog);
				}

				if(err)
				{
					break;
				}

				const U64 newHash = computeHash(il.getBegin(), il.getSizeInBytes());

				// Reflect the compiled IL
				ShaderReflection refl;
				if(ctx.m_spirv)
				{
					err = doReflectionSpirv(il, shaderType, refl, compilerErrorLog);
				}
				else
				{
#if ANKI_DIXL_REFLECTION
					err = doReflectionDxil(il, shaderType, refl, compilerErrorLog);
#else
					ANKI_SHADER_COMPILER_LOGE("Can't generate shader compilation on non-windows platforms");
					err = Error::kFunctionFailed;
#endif
				}

				if(err)
				{
					break;
				}

				// Add the binary if not already there
				{
					LockGuard lock(*ctx.m_mtx);
					Bool found = false;
					for(U32 j = 0; j < ctx.m_codeBlocks->getSize(); ++j)
					{
						if((*ctx.m_codeBlocks)[j].m_hash == newHash)
						{
							// Different source produced identical IL, reuse it
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = j;
							found = true;
							break;
						}
					}

					if(!found)
					{
						codeBlockIndices[t].m_codeBlockIndices[shaderType] = ctx.m_codeBlocks->getSize();

						auto& codeBlock = *ctx.m_codeBlocks->emplaceBack();
						il.moveAndReset(codeBlock.m_binary);
						codeBlock.m_hash = newHash;
						codeBlock.m_reflection = refl;

						ctx.m_sourceCodeHashes->emplaceBack(sourceCodeHash);
						ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

						++newCodeBlockCount;
					}
				}
			}
		}

		if(err)
		{
			// Record the error. Only the first failing task logs; the rest stay silent
			I32 expectedErr = 0;
			const Bool isFirstError = ctx.m_err->compareExchange(expectedErr, err._getCode());
			if(isFirstError)
			{
				ANKI_SHADER_COMPILER_LOGE("Shader compilation failed:\n%s", compilerErrorLog.cstr());
				return;
			}

			return;
		}

		// Do variant stuff
		{
			LockGuard lock(*ctx.m_mtx);

			Bool createVariant = true;
			if(newCodeBlockCount == 0)
			{
				// No new code blocks generated, search all variants to see if there is a duplicate
				for(U32 i = 0; i < ctx.m_variants->getSize(); ++i)
				{
					Bool same = true;
					for(U32 t = 0; t < techniqueCount; ++t)
					{
						const ShaderProgramBinaryTechniqueCodeBlocks& a = (*ctx.m_variants)[i].m_techniqueCodeBlocks[t];
						const ShaderProgramBinaryTechniqueCodeBlocks& b = codeBlockIndices[t];
						if(memcmp(&a, &b, sizeof(a)) != 0)
						{
							// Not the same
							same = false;
							break;
						}
					}

					if(same)
					{
						// Identical variant exists, just point the mutation at it
						createVariant = false;
						ctx.m_mutation->m_variantIndex = i;
						break;
					}
				}
			}

			// Create a new variant
			if(createVariant)
			{
				ctx.m_mutation->m_variantIndex = ctx.m_variants->getSize();
				ShaderProgramBinaryVariant* variant = ctx.m_variants->emplaceBack();
				codeBlockIndices.moveAndReset(variant->m_techniqueCodeBlocks);
			}
		}
	};

	taskManager.enqueueTask(callback, ctx);
}
/// The real work behind compileShaderProgram(): parses the source, spins all mutator
/// combinations, compiles every variant (possibly in parallel) and assembles the final
/// ShaderProgramBinary.
/// @param fname             Path of the shader program source file.
/// @param spirv             True for SPIR-V output, false for DXIL.
/// @param fsystem           Used by the parser to read source files.
/// @param postParseCallback Optional. May short-circuit compilation based on the source hash.
/// @param taskManager_      Optional async task manager. If null compilation runs synchronously.
/// @param defines_          Extra preprocessor defines appended with the target defines.
/// @param binary            [out] Freshly allocated binary. On error the caller must free it.
static Error compileShaderProgramInternal(CString fname, Bool spirv, ShaderProgramFilesystemInterface& fsystem,
										  ShaderProgramPostParseInterface* postParseCallback, ShaderProgramAsyncTaskInterface* taskManager_,
										  ConstWeakArray<ShaderCompilerDefine> defines_, ShaderProgramBinary*& binary)
{
#if ANKI_DIXL_REFLECTION
	if(!spirv)
	{
		// Init DXC
		// NOTE(review): g_dxcLib is never FreeLibrary'd and is re-loaded on every call — confirm
		// this is acceptable for the compiler tool's lifetime
		g_dxcLib = LoadLibraryA(ANKI_SOURCE_DIRECTORY "/ThirdParty/Bin/Windows64/dxcompiler.dll");
		if(g_dxcLib == 0)
		{
			ANKI_SHADER_COMPILER_LOGE("dxcompiler.dll missing or wrong architecture");
			return Error::kFunctionFailed;
		}

		g_DxcCreateInstance = reinterpret_cast<DxcCreateInstanceProc>(GetProcAddress(g_dxcLib, "DxcCreateInstance"));
		if(g_DxcCreateInstance == nullptr)
		{
			ANKI_SHADER_COMPILER_LOGE("DxcCreateInstance was not found in the dxcompiler.dll");
			return Error::kFunctionFailed;
		}
	}
#endif

	ShaderCompilerMemoryPool& memPool = ShaderCompilerMemoryPool::getSingleton();

	// Copy the caller's defines and append the target-identifying ones
	ShaderCompilerDynamicArray<ShaderCompilerDefine> defines;
	for(const ShaderCompilerDefine& d : defines_)
	{
		defines.emplaceBack(d);
	}
	defines.emplaceBack(ShaderCompilerDefine{"ANKI_TARGET_SPIRV", spirv});
	defines.emplaceBack(ShaderCompilerDefine{"ANKI_TARGET_DXIL", !spirv});

	// Initialize the binary
	binary = newInstance<ShaderProgramBinary>(memPool);
	memcpy(&binary->m_magic[0], kShaderBinaryMagic, 8);

	// Parse source
	ShaderProgramParser parser(fname, &fsystem, defines);
	ANKI_CHECK(parser.parse());

	// The callback can decide (eg via a cache) that this program needn't be compiled
	if(postParseCallback && postParseCallback->skipCompilation(parser.getHash()))
	{
		return Error::kNone;
	}

	// Get mutators
	U32 mutationCount = 0;
	if(parser.getMutators().getSize() > 0)
	{
		newArray(memPool, parser.getMutators().getSize(), binary->m_mutators);
		for(U32 i = 0; i < binary->m_mutators.getSize(); ++i)
		{
			ShaderProgramBinaryMutator& out = binary->m_mutators[i];
			const ShaderProgramParserMutator& in = parser.getMutators()[i];

			zeroMemory(out);

			newArray(memPool, in.m_values.getSize(), out.m_values);
			memcpy(out.m_values.getBegin(), in.m_values.getBegin(), in.m_values.getSizeInBytes());

			// +1 to copy the null terminator as well
			memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

			// Update the count. The mutation count is the product of all mutators' value counts
			mutationCount = (i == 0) ? out.m_values.getSize() : mutationCount * out.m_values.getSize();
		}
	}
	else
	{
		ANKI_ASSERT(binary->m_mutators.getSize() == 0);
	}

	// Create all variants
	Mutex mtx;
	Atomic<I32> errorAtomic(0);

	// Fallback manager that runs every "async" task inline on the calling thread
	class SyncronousShaderProgramAsyncTaskInterface : public ShaderProgramAsyncTaskInterface
	{
	public:
		void enqueueTask(void (*callback)(void* userData), void* userData) final
		{
			callback(userData);
		}

		Error joinTasks() final
		{
			// Nothing
			return Error::kNone;
		}
	} syncTaskManager;
	ShaderProgramAsyncTaskInterface& taskManager = (taskManager_) ? *taskManager_ : syncTaskManager;

	if(parser.getMutators().getSize() > 0)
	{
		// Initialize
		ShaderCompilerDynamicArray<MutatorValue> mutationValues;
		mutationValues.resize(parser.getMutators().getSize());
		ShaderCompilerDynamicArray<U32> dials;
		dials.resize(parser.getMutators().getSize(), 0);
		ShaderCompilerDynamicArray<ShaderProgramBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;
		ShaderCompilerDynamicArray<ShaderProgramBinaryMutation> mutations;
		mutations.resize(mutationCount);
		ShaderCompilerHashMap<U64, U32> mutationHashToIdx;

		// Grow the storage of the variants array. Can't have it resize, threads will work on stale data
		variants.resizeStorage(mutationCount);

		// Re-used below as a running index into mutations
		mutationCount = 0;

		// Spin for all possible combinations of mutators and
		// - Create the spirv
		// - Populate the binary variant
		do
		{
			// Create the mutation
			for(U32 i = 0; i < parser.getMutators().getSize(); ++i)
			{
				mutationValues[i] = parser.getMutators()[i].m_values[dials[i]];
			}

			ShaderProgramBinaryMutation& mutation = mutations[mutationCount++];
			newArray(memPool, mutationValues.getSize(), mutation.m_values);
			memcpy(mutation.m_values.getBegin(), mutationValues.getBegin(), mutationValues.getSizeInBytes());

			mutation.m_hash = computeHash(mutationValues.getBegin(), mutationValues.getSizeInBytes());
			ANKI_ASSERT(mutation.m_hash > 0);

			if(parser.skipMutation(mutationValues))
			{
				// Skipped mutation: kMaxU32 marks "no variant"
				mutation.m_variantIndex = kMaxU32;
			}
			else
			{
				// New and unique mutation and thus variant, add it
				compileVariantAsync(parser, spirv, mutation, variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

				ANKI_ASSERT(mutationHashToIdx.find(mutation.m_hash) == mutationHashToIdx.getEnd());
				mutationHashToIdx.emplace(mutation.m_hash, mutationCount - 1);
			}
		} while(!spinDials(dials, parser.getMutators()));

		ANKI_ASSERT(mutationCount == mutations.getSize());

		// Done, wait the threads
		ANKI_CHECK(taskManager.joinTasks());

		// Now error out
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		// Store temp containers to binary
		codeBlocks.moveAndReset(binary->m_codeBlocks);
		mutations.moveAndReset(binary->m_mutations);
		variants.moveAndReset(binary->m_variants);
	}
	else
	{
		// No mutators: exactly one mutation and one variant
		newArray(memPool, 1, binary->m_mutations);

		ShaderCompilerDynamicArray<ShaderProgramBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;

		compileVariantAsync(parser, spirv, binary->m_mutations[0], variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

		ANKI_CHECK(taskManager.joinTasks());
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		ANKI_ASSERT(codeBlocks.getSize() >= parser.getTechniques().getSize());
		ANKI_ASSERT(binary->m_mutations[0].m_variantIndex == 0);
		ANKI_ASSERT(variants.getSize() == 1);
		binary->m_mutations[0].m_hash = 1;

		codeBlocks.moveAndReset(binary->m_codeBlocks);
		variants.moveAndReset(binary->m_variants);
	}

	// Sort the mutations by hash so they can be searched by hash later
	std::sort(binary->m_mutations.getBegin(), binary->m_mutations.getEnd(),
			  [](const ShaderProgramBinaryMutation& a, const ShaderProgramBinaryMutation& b) {
				  return a.m_hash < b.m_hash;
			  });

	// Techniques
	newArray(memPool, parser.getTechniques().getSize(), binary->m_techniques);
	for(U32 i = 0; i < parser.getTechniques().getSize(); ++i)
	{
		zeroMemory(binary->m_techniques[i].m_name);
		memcpy(binary->m_techniques[i].m_name.getBegin(), parser.getTechniques()[i].m_name.cstr(), parser.getTechniques()[i].m_name.getLength() + 1);
		binary->m_techniques[i].m_shaderTypes = parser.getTechniques()[i].m_shaderTypes;
		binary->m_shaderTypes |= parser.getTechniques()[i].m_shaderTypes;
	}

	// Structs
	if(parser.getGhostStructs().getSize())
	{
		newArray(memPool, parser.getGhostStructs().getSize(), binary->m_structs);
	}

	for(U32 i = 0; i < parser.getGhostStructs().getSize(); ++i)
	{
		const ShaderProgramParserGhostStruct& in = parser.getGhostStructs()[i];
		ShaderProgramBinaryStruct& out = binary->m_structs[i];

		zeroMemory(out);
		memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

		ANKI_ASSERT(in.m_members.getSize());
		newArray(memPool, in.m_members.getSize(), out.m_members);
		for(U32 j = 0; j < in.m_members.getSize(); ++j)
		{
			const ShaderProgramParserMember& inm = in.m_members[j];
			ShaderProgramBinaryStructMember& outm = out.m_members[j];

			zeroMemory(outm.m_name);
			memcpy(outm.m_name.getBegin(), inm.m_name.cstr(), inm.m_name.getLength() + 1);
			outm.m_offset = inm.m_offset;
			outm.m_type = inm.m_type;
		}

		// Struct size = last member's offset + that member's type size
		out.m_size = in.m_members.getBack().m_offset + getShaderVariableDataTypeInfo(in.m_members.getBack().m_type).m_size;
	}

	return Error::kNone;
}
  830. Error compileShaderProgram(CString fname, Bool spirv, ShaderProgramFilesystemInterface& fsystem, ShaderProgramPostParseInterface* postParseCallback,
  831. ShaderProgramAsyncTaskInterface* taskManager, ConstWeakArray<ShaderCompilerDefine> defines, ShaderProgramBinary*& binary)
  832. {
  833. const Error err = compileShaderProgramInternal(fname, spirv, fsystem, postParseCallback, taskManager, defines, binary);
  834. if(err)
  835. {
  836. ANKI_SHADER_COMPILER_LOGE("Failed to compile: %s", fname.cstr());
  837. freeShaderProgramBinary(binary);
  838. }
  839. return err;
  840. }
  841. } // end namespace anki