ShaderCompiler.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. // Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/ShaderCompiler/ShaderCompiler.h>
  6. #include <AnKi/ShaderCompiler/ShaderParser.h>
  7. #include <AnKi/ShaderCompiler/Dxc.h>
  8. #include <AnKi/ShaderCompiler/Spirv.h>
  9. #include <AnKi/Util/Serializer.h>
  10. #include <AnKi/Util/HashMap.h>
  11. namespace anki {
  12. void freeShaderBinary(ShaderBinary*& binary)
  13. {
  14. if(binary == nullptr)
  15. {
  16. return;
  17. }
  18. BaseMemoryPool& mempool = ShaderCompilerMemoryPool::getSingleton();
  19. for(ShaderBinaryCodeBlock& code : binary->m_codeBlocks)
  20. {
  21. mempool.free(code.m_binary.getBegin());
  22. }
  23. mempool.free(binary->m_codeBlocks.getBegin());
  24. for(ShaderBinaryMutator& mutator : binary->m_mutators)
  25. {
  26. mempool.free(mutator.m_values.getBegin());
  27. }
  28. mempool.free(binary->m_mutators.getBegin());
  29. for(ShaderBinaryMutation& m : binary->m_mutations)
  30. {
  31. mempool.free(m.m_values.getBegin());
  32. }
  33. mempool.free(binary->m_mutations.getBegin());
  34. for(ShaderBinaryVariant& variant : binary->m_variants)
  35. {
  36. mempool.free(variant.m_techniqueCodeBlocks.getBegin());
  37. }
  38. mempool.free(binary->m_variants.getBegin());
  39. mempool.free(binary->m_techniques.getBegin());
  40. for(ShaderBinaryStruct& s : binary->m_structs)
  41. {
  42. mempool.free(s.m_members.getBegin());
  43. }
  44. mempool.free(binary->m_structs.getBegin());
  45. mempool.free(binary);
  46. binary = nullptr;
  47. }
  48. /// Spin the dials. Used to compute all mutator combinations.
  49. static Bool spinDials(ShaderCompilerDynamicArray<U32>& dials, ConstWeakArray<ShaderParserMutator> mutators)
  50. {
  51. ANKI_ASSERT(dials.getSize() == mutators.getSize() && dials.getSize() > 0);
  52. Bool done = true;
  53. U32 crntDial = dials.getSize() - 1;
  54. while(true)
  55. {
  56. // Turn dial
  57. ++dials[crntDial];
  58. if(dials[crntDial] >= mutators[crntDial].m_values.getSize())
  59. {
  60. if(crntDial == 0)
  61. {
  62. // Reached the 1st dial, stop spinning
  63. done = true;
  64. break;
  65. }
  66. else
  67. {
  68. dials[crntDial] = 0;
  69. --crntDial;
  70. }
  71. }
  72. else
  73. {
  74. done = false;
  75. break;
  76. }
  77. }
  78. return done;
  79. }
  80. static void compileVariantAsync(const ShaderParser& parser, Bool spirv, Bool debugInfo, ShaderModel sm, ShaderBinaryMutation& mutation,
  81. ShaderCompilerDynamicArray<ShaderBinaryVariant>& variants,
  82. ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>& codeBlocks, ShaderCompilerDynamicArray<U64>& sourceCodeHashes,
  83. ShaderCompilerAsyncTaskInterface& taskManager, Mutex& mtx, Atomic<I32>& error)
  84. {
  85. class Ctx
  86. {
  87. public:
  88. const ShaderParser* m_parser;
  89. ShaderBinaryMutation* m_mutation;
  90. ShaderCompilerDynamicArray<ShaderBinaryVariant>* m_variants;
  91. ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>* m_codeBlocks;
  92. ShaderCompilerDynamicArray<U64>* m_sourceCodeHashes;
  93. Mutex* m_mtx;
  94. Atomic<I32>* m_err;
  95. Bool m_spirv;
  96. Bool m_debugInfo;
  97. ShaderModel m_sm;
  98. };
  99. Ctx* ctx = newInstance<Ctx>(ShaderCompilerMemoryPool::getSingleton());
  100. ctx->m_parser = &parser;
  101. ctx->m_mutation = &mutation;
  102. ctx->m_variants = &variants;
  103. ctx->m_codeBlocks = &codeBlocks;
  104. ctx->m_sourceCodeHashes = &sourceCodeHashes;
  105. ctx->m_mtx = &mtx;
  106. ctx->m_err = &error;
  107. ctx->m_spirv = spirv;
  108. ctx->m_debugInfo = debugInfo;
  109. ctx->m_sm = sm;
  110. auto callback = [](void* userData) {
  111. Ctx& ctx = *static_cast<Ctx*>(userData);
  112. class Cleanup
  113. {
  114. public:
  115. Ctx* m_ctx;
  116. ~Cleanup()
  117. {
  118. deleteInstance(ShaderCompilerMemoryPool::getSingleton(), m_ctx);
  119. }
  120. } cleanup{&ctx};
  121. if(ctx.m_err->load() != 0)
  122. {
  123. // Don't bother
  124. return;
  125. }
  126. const U32 techniqueCount = ctx.m_parser->getTechniques().getSize();
  127. // Compile the sources
  128. ShaderCompilerDynamicArray<ShaderBinaryTechniqueCodeBlocks> codeBlockIndices;
  129. codeBlockIndices.resize(techniqueCount);
  130. for(auto& it : codeBlockIndices)
  131. {
  132. it.m_codeBlockIndices.fill(kMaxU32);
  133. }
  134. ShaderCompilerString compilerErrorLog;
  135. Error err = Error::kNone;
  136. U newCodeBlockCount = 0;
  137. for(U32 t = 0; t < techniqueCount && !err; ++t)
  138. {
  139. const ShaderParserTechnique& technique = ctx.m_parser->getTechniques()[t];
  140. for(ShaderType shaderType : EnumBitsIterable<ShaderType, ShaderTypeBit>(technique.m_shaderTypes))
  141. {
  142. ShaderCompilerString source;
  143. ctx.m_parser->generateVariant(ctx.m_mutation->m_values, technique, shaderType, source);
  144. // Check if the source code was found before
  145. const U64 sourceCodeHash = source.computeHash();
  146. if(technique.m_activeMutators[shaderType] != kMaxU64)
  147. {
  148. LockGuard lock(*ctx.m_mtx);
  149. ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());
  150. Bool found = false;
  151. for(U32 i = 0; i < ctx.m_sourceCodeHashes->getSize(); ++i)
  152. {
  153. if((*ctx.m_sourceCodeHashes)[i] == sourceCodeHash)
  154. {
  155. codeBlockIndices[t].m_codeBlockIndices[shaderType] = i;
  156. found = true;
  157. break;
  158. }
  159. }
  160. if(found)
  161. {
  162. continue;
  163. }
  164. }
  165. ShaderCompilerDynamicArray<U8> il;
  166. if(ctx.m_spirv)
  167. {
  168. err = compileHlslToSpirv(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo, ctx.m_sm,
  169. ctx.m_parser->getExtraCompilerArgs(), il, compilerErrorLog);
  170. }
  171. else
  172. {
  173. err = compileHlslToDxil(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo, ctx.m_sm,
  174. ctx.m_parser->getExtraCompilerArgs(), il, compilerErrorLog);
  175. }
  176. if(err)
  177. {
  178. break;
  179. }
  180. const U64 newHash = computeHash(il.getBegin(), il.getSizeInBytes());
  181. ShaderReflection refl;
  182. if(ctx.m_spirv)
  183. {
  184. err = doReflectionSpirv(il, shaderType, refl, compilerErrorLog);
  185. }
  186. else
  187. {
  188. #if ANKI_OS_WINDOWS
  189. err = doReflectionDxil(il, shaderType, refl, compilerErrorLog);
  190. #else
  191. ANKI_SHADER_COMPILER_LOGE("Can't generate shader compilation on non-windows platforms");
  192. err = Error::kFunctionFailed;
  193. #endif
  194. }
  195. if(err)
  196. {
  197. break;
  198. }
  199. // Add the binary if not already there
  200. {
  201. LockGuard lock(*ctx.m_mtx);
  202. Bool found = false;
  203. for(U32 j = 0; j < ctx.m_codeBlocks->getSize(); ++j)
  204. {
  205. if((*ctx.m_codeBlocks)[j].m_hash == newHash)
  206. {
  207. codeBlockIndices[t].m_codeBlockIndices[shaderType] = j;
  208. found = true;
  209. break;
  210. }
  211. }
  212. if(!found)
  213. {
  214. codeBlockIndices[t].m_codeBlockIndices[shaderType] = ctx.m_codeBlocks->getSize();
  215. auto& codeBlock = *ctx.m_codeBlocks->emplaceBack();
  216. il.moveAndReset(codeBlock.m_binary);
  217. codeBlock.m_hash = newHash;
  218. codeBlock.m_reflection = refl;
  219. ctx.m_sourceCodeHashes->emplaceBack(sourceCodeHash);
  220. ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());
  221. ++newCodeBlockCount;
  222. }
  223. }
  224. }
  225. }
  226. if(err)
  227. {
  228. I32 expectedErr = 0;
  229. const Bool isFirstError = ctx.m_err->compareExchange(expectedErr, err._getCode());
  230. if(isFirstError)
  231. {
  232. ANKI_SHADER_COMPILER_LOGE("Shader compilation failed:\n%s", compilerErrorLog.cstr());
  233. return;
  234. }
  235. return;
  236. }
  237. // Do variant stuff
  238. {
  239. LockGuard lock(*ctx.m_mtx);
  240. Bool createVariant = true;
  241. if(newCodeBlockCount == 0)
  242. {
  243. // No new code blocks generated, search all variants to see if there is a duplicate
  244. for(U32 i = 0; i < ctx.m_variants->getSize(); ++i)
  245. {
  246. Bool same = true;
  247. for(U32 t = 0; t < techniqueCount; ++t)
  248. {
  249. const ShaderBinaryTechniqueCodeBlocks& a = (*ctx.m_variants)[i].m_techniqueCodeBlocks[t];
  250. const ShaderBinaryTechniqueCodeBlocks& b = codeBlockIndices[t];
  251. if(memcmp(&a, &b, sizeof(a)) != 0)
  252. {
  253. // Not the same
  254. same = false;
  255. break;
  256. }
  257. }
  258. if(same)
  259. {
  260. createVariant = false;
  261. ctx.m_mutation->m_variantIndex = i;
  262. break;
  263. }
  264. }
  265. }
  266. // Create a new variant
  267. if(createVariant)
  268. {
  269. ctx.m_mutation->m_variantIndex = ctx.m_variants->getSize();
  270. ShaderBinaryVariant* variant = ctx.m_variants->emplaceBack();
  271. codeBlockIndices.moveAndReset(variant->m_techniqueCodeBlocks);
  272. }
  273. }
  274. };
  275. taskManager.enqueueTask(callback, ctx);
  276. }
  277. static Error compileShaderProgramInternal(CString fname, Bool spirv, Bool debugInfo, ShaderModel sm, ShaderCompilerFilesystemInterface& fsystem,
  278. ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager_,
  279. ConstWeakArray<ShaderCompilerDefine> defines_, ShaderBinary*& binary)
  280. {
  281. ShaderCompilerMemoryPool& memPool = ShaderCompilerMemoryPool::getSingleton();
  282. ShaderCompilerDynamicArray<ShaderCompilerDefine> defines;
  283. for(const ShaderCompilerDefine& d : defines_)
  284. {
  285. defines.emplaceBack(d);
  286. }
  287. // Initialize the binary
  288. binary = newInstance<ShaderBinary>(memPool);
  289. memcpy(&binary->m_magic[0], kShaderBinaryMagic, 8);
  290. // Parse source
  291. ShaderParser parser(fname, &fsystem, defines);
  292. ANKI_CHECK(parser.parse());
  293. if(postParseCallback && postParseCallback->skipCompilation(parser.getHash()))
  294. {
  295. return Error::kNone;
  296. }
  297. // Get mutators
  298. U32 mutationCount = 0;
  299. if(parser.getMutators().getSize() > 0)
  300. {
  301. newArray(memPool, parser.getMutators().getSize(), binary->m_mutators);
  302. for(U32 i = 0; i < binary->m_mutators.getSize(); ++i)
  303. {
  304. ShaderBinaryMutator& out = binary->m_mutators[i];
  305. const ShaderParserMutator& in = parser.getMutators()[i];
  306. zeroMemory(out);
  307. newArray(memPool, in.m_values.getSize(), out.m_values);
  308. memcpy(out.m_values.getBegin(), in.m_values.getBegin(), in.m_values.getSizeInBytes());
  309. memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);
  310. // Update the count
  311. mutationCount = (i == 0) ? out.m_values.getSize() : mutationCount * out.m_values.getSize();
  312. }
  313. }
  314. else
  315. {
  316. ANKI_ASSERT(binary->m_mutators.getSize() == 0);
  317. }
  318. // Create all variants
  319. Mutex mtx;
  320. Atomic<I32> errorAtomic(0);
  321. class SyncronousShaderCompilerAsyncTaskInterface : public ShaderCompilerAsyncTaskInterface
  322. {
  323. public:
  324. void enqueueTask(void (*callback)(void* userData), void* userData) final
  325. {
  326. callback(userData);
  327. }
  328. Error joinTasks() final
  329. {
  330. // Nothing
  331. return Error::kNone;
  332. }
  333. } syncTaskManager;
  334. ShaderCompilerAsyncTaskInterface& taskManager = (taskManager_) ? *taskManager_ : syncTaskManager;
  335. if(parser.getMutators().getSize() > 0)
  336. {
  337. // Initialize
  338. ShaderCompilerDynamicArray<MutatorValue> mutationValues;
  339. mutationValues.resize(parser.getMutators().getSize());
  340. ShaderCompilerDynamicArray<U32> dials;
  341. dials.resize(parser.getMutators().getSize(), 0);
  342. ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
  343. ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
  344. ShaderCompilerDynamicArray<U64> sourceCodeHashes;
  345. ShaderCompilerDynamicArray<ShaderBinaryMutation> mutations;
  346. mutations.resize(mutationCount);
  347. ShaderCompilerHashMap<U64, U32> mutationHashToIdx;
  348. // Grow the storage of the variants array. Can't have it resize, threads will work on stale data
  349. variants.resizeStorage(mutationCount);
  350. mutationCount = 0;
  351. // Spin for all possible combinations of mutators and
  352. // - Create the spirv
  353. // - Populate the binary variant
  354. do
  355. {
  356. // Create the mutation
  357. for(U32 i = 0; i < parser.getMutators().getSize(); ++i)
  358. {
  359. mutationValues[i] = parser.getMutators()[i].m_values[dials[i]];
  360. }
  361. ShaderBinaryMutation& mutation = mutations[mutationCount++];
  362. newArray(memPool, mutationValues.getSize(), mutation.m_values);
  363. memcpy(mutation.m_values.getBegin(), mutationValues.getBegin(), mutationValues.getSizeInBytes());
  364. mutation.m_hash = computeHash(mutationValues.getBegin(), mutationValues.getSizeInBytes());
  365. ANKI_ASSERT(mutation.m_hash > 0);
  366. if(parser.skipMutation(mutationValues))
  367. {
  368. mutation.m_variantIndex = kMaxU32;
  369. }
  370. else
  371. {
  372. // New and unique mutation and thus variant, add it
  373. compileVariantAsync(parser, spirv, debugInfo, sm, mutation, variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);
  374. ANKI_ASSERT(mutationHashToIdx.find(mutation.m_hash) == mutationHashToIdx.getEnd());
  375. mutationHashToIdx.emplace(mutation.m_hash, mutationCount - 1);
  376. }
  377. } while(!spinDials(dials, parser.getMutators()));
  378. ANKI_ASSERT(mutationCount == mutations.getSize());
  379. // Done, wait the threads
  380. ANKI_CHECK(taskManager.joinTasks());
  381. // Now error out
  382. ANKI_CHECK(Error(errorAtomic.getNonAtomically()));
  383. // Store temp containers to binary
  384. codeBlocks.moveAndReset(binary->m_codeBlocks);
  385. mutations.moveAndReset(binary->m_mutations);
  386. variants.moveAndReset(binary->m_variants);
  387. }
  388. else
  389. {
  390. newArray(memPool, 1, binary->m_mutations);
  391. ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
  392. ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
  393. ShaderCompilerDynamicArray<U64> sourceCodeHashes;
  394. compileVariantAsync(parser, spirv, debugInfo, sm, binary->m_mutations[0], variants, codeBlocks, sourceCodeHashes, taskManager, mtx,
  395. errorAtomic);
  396. ANKI_CHECK(taskManager.joinTasks());
  397. ANKI_CHECK(Error(errorAtomic.getNonAtomically()));
  398. ANKI_ASSERT(codeBlocks.getSize() >= parser.getTechniques().getSize());
  399. ANKI_ASSERT(binary->m_mutations[0].m_variantIndex == 0);
  400. ANKI_ASSERT(variants.getSize() == 1);
  401. binary->m_mutations[0].m_hash = 1;
  402. codeBlocks.moveAndReset(binary->m_codeBlocks);
  403. variants.moveAndReset(binary->m_variants);
  404. }
  405. // Sort the mutations
  406. std::sort(binary->m_mutations.getBegin(), binary->m_mutations.getEnd(), [](const ShaderBinaryMutation& a, const ShaderBinaryMutation& b) {
  407. return a.m_hash < b.m_hash;
  408. });
  409. // Techniques
  410. newArray(memPool, parser.getTechniques().getSize(), binary->m_techniques);
  411. for(U32 i = 0; i < parser.getTechniques().getSize(); ++i)
  412. {
  413. zeroMemory(binary->m_techniques[i].m_name);
  414. memcpy(binary->m_techniques[i].m_name.getBegin(), parser.getTechniques()[i].m_name.cstr(), parser.getTechniques()[i].m_name.getLength() + 1);
  415. binary->m_techniques[i].m_shaderTypes = parser.getTechniques()[i].m_shaderTypes;
  416. binary->m_shaderTypes |= parser.getTechniques()[i].m_shaderTypes;
  417. }
  418. // Structs
  419. if(parser.getGhostStructs().getSize())
  420. {
  421. newArray(memPool, parser.getGhostStructs().getSize(), binary->m_structs);
  422. }
  423. for(U32 i = 0; i < parser.getGhostStructs().getSize(); ++i)
  424. {
  425. const ShaderParserGhostStruct& in = parser.getGhostStructs()[i];
  426. ShaderBinaryStruct& out = binary->m_structs[i];
  427. zeroMemory(out);
  428. memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);
  429. ANKI_ASSERT(in.m_members.getSize());
  430. newArray(memPool, in.m_members.getSize(), out.m_members);
  431. for(U32 j = 0; j < in.m_members.getSize(); ++j)
  432. {
  433. const ShaderParserGhostStructMember& inm = in.m_members[j];
  434. ShaderBinaryStructMember& outm = out.m_members[j];
  435. zeroMemory(outm.m_name);
  436. memcpy(outm.m_name.getBegin(), inm.m_name.cstr(), inm.m_name.getLength() + 1);
  437. outm.m_offset = inm.m_offset;
  438. outm.m_type = inm.m_type;
  439. }
  440. out.m_size = in.m_members.getBack().m_offset + getShaderVariableDataTypeInfo(in.m_members.getBack().m_type).m_size;
  441. }
  442. return Error::kNone;
  443. }
  444. Error compileShaderProgram(CString fname, Bool spirv, Bool debugInfo, ShaderModel sm, ShaderCompilerFilesystemInterface& fsystem,
  445. ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager,
  446. ConstWeakArray<ShaderCompilerDefine> defines, ShaderBinary*& binary)
  447. {
  448. const Error err = compileShaderProgramInternal(fname, spirv, debugInfo, sm, fsystem, postParseCallback, taskManager, defines, binary);
  449. if(err)
  450. {
  451. ANKI_SHADER_COMPILER_LOGE("Failed to compile: %s", fname.cstr());
  452. freeShaderBinary(binary);
  453. }
  454. return err;
  455. }
  456. } // end namespace anki