ShaderProgramCompiler.cpp 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060
  1. // Copyright (C) 2009-2023, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/ShaderCompiler/ShaderProgramCompiler.h>
  6. #include <AnKi/ShaderCompiler/ShaderProgramParser.h>
  7. #include <AnKi/ShaderCompiler/Dxc.h>
  8. #include <AnKi/Util/Serializer.h>
  9. #include <AnKi/Util/HashMap.h>
  10. #include <SpirvCross/spirv_cross.hpp>
  11. #define ANKI_DIXL_REFLECTION (ANKI_OS_WINDOWS)
  12. #if ANKI_DIXL_REFLECTION
  13. # include <windows.h>
  14. # include <ThirdParty/Dxc/dxcapi.h>
  15. # include <ThirdParty/Dxc/d3d12shader.h>
  16. # include <wrl.h>
  17. # include <AnKi/Util/CleanupWindows.h>
  18. #endif
  19. namespace anki {
  20. #if ANKI_DIXL_REFLECTION
  21. static HMODULE g_dxcLib = 0;
  22. static DxcCreateInstanceProc g_DxcCreateInstance = nullptr;
  23. static Mutex g_dxcLibMtx;
  24. #endif
  25. void freeShaderProgramBinary(ShaderProgramBinary*& binary)
  26. {
  27. if(binary == nullptr)
  28. {
  29. return;
  30. }
  31. BaseMemoryPool& mempool = ShaderCompilerMemoryPool::getSingleton();
  32. for(ShaderProgramBinaryCodeBlock& code : binary->m_codeBlocks)
  33. {
  34. mempool.free(code.m_binary.getBegin());
  35. }
  36. mempool.free(binary->m_codeBlocks.getBegin());
  37. for(ShaderProgramBinaryMutator& mutator : binary->m_mutators)
  38. {
  39. mempool.free(mutator.m_values.getBegin());
  40. }
  41. mempool.free(binary->m_mutators.getBegin());
  42. for(ShaderProgramBinaryMutation& m : binary->m_mutations)
  43. {
  44. mempool.free(m.m_values.getBegin());
  45. }
  46. mempool.free(binary->m_mutations.getBegin());
  47. for(ShaderProgramBinaryVariant& variant : binary->m_variants)
  48. {
  49. mempool.free(variant.m_techniqueCodeBlocks.getBegin());
  50. }
  51. mempool.free(binary->m_variants.getBegin());
  52. mempool.free(binary->m_techniques.getBegin());
  53. for(ShaderProgramBinaryStruct& s : binary->m_structs)
  54. {
  55. mempool.free(s.m_members.getBegin());
  56. }
  57. mempool.free(binary->m_structs.getBegin());
  58. mempool.free(binary);
  59. binary = nullptr;
  60. }
  61. /// Spin the dials. Used to compute all mutator combinations.
  62. static Bool spinDials(ShaderCompilerDynamicArray<U32>& dials, ConstWeakArray<ShaderProgramParserMutator> mutators)
  63. {
  64. ANKI_ASSERT(dials.getSize() == mutators.getSize() && dials.getSize() > 0);
  65. Bool done = true;
  66. U32 crntDial = dials.getSize() - 1;
  67. while(true)
  68. {
  69. // Turn dial
  70. ++dials[crntDial];
  71. if(dials[crntDial] >= mutators[crntDial].m_values.getSize())
  72. {
  73. if(crntDial == 0)
  74. {
  75. // Reached the 1st dial, stop spinning
  76. done = true;
  77. break;
  78. }
  79. else
  80. {
  81. dials[crntDial] = 0;
  82. --crntDial;
  83. }
  84. }
  85. else
  86. {
  87. done = false;
  88. break;
  89. }
  90. }
  91. return done;
  92. }
  93. template<typename TFunc>
  94. static void visitSpirv(ConstWeakArray<U32> spv, TFunc func)
  95. {
  96. ANKI_ASSERT(spv.getSize() > 5);
  97. const U32* it = &spv[5];
  98. do
  99. {
  100. const U32 instructionCount = *it >> 16u;
  101. const U32 opcode = *it & 0xFFFFu;
  102. func(opcode);
  103. it += instructionCount;
  104. } while(it < spv.getEnd());
  105. ANKI_ASSERT(it == spv.getEnd());
  106. }
  107. Error doReflectionSpirv(ConstWeakArray<U8> spirv, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
  108. {
  109. spirv_cross::Compiler spvc(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32));
  110. spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
  111. spirv_cross::ShaderResources rsrcActive = spvc.get_shader_resources(spvc.get_active_interface_variables());
  112. auto func = [&](const spirv_cross::SmallVector<spirv_cross::Resource>& resources, const DescriptorType origType,
  113. const DescriptorFlag origFlags) -> Error {
  114. for(const spirv_cross::Resource& r : resources)
  115. {
  116. const U32 id = r.id;
  117. const U32 set = spvc.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
  118. const U32 binding = spvc.get_decoration(id, spv::Decoration::DecorationBinding);
  119. if(set >= kMaxDescriptorSets || binding >= kMaxBindingsPerDescriptorSet)
  120. {
  121. errorStr.sprintf("Exceeded set or binding for: %s", r.name.c_str());
  122. return Error::kUserData;
  123. }
  124. const spirv_cross::SPIRType& typeInfo = spvc.get_type(r.type_id);
  125. U32 arraySize = 1;
  126. if(typeInfo.array.size() != 0)
  127. {
  128. if(typeInfo.array.size() != 1 || (arraySize = typeInfo.array[0]) == 0)
  129. {
  130. errorStr.sprintf("Only 1D arrays are supported: %s", r.name.c_str());
  131. return Error::kUserData;
  132. }
  133. }
  134. refl.m_descriptorSetMask.set(set);
  135. // Images are special, they might be texel buffers
  136. DescriptorType type = origType;
  137. DescriptorFlag flags = origFlags;
  138. if(type == DescriptorType::kTexture)
  139. {
  140. if(typeInfo.image.dim == spv::DimBuffer)
  141. {
  142. type = DescriptorType::kTexelBuffer;
  143. if(typeInfo.image.sampled == 1)
  144. {
  145. flags = DescriptorFlag::kRead;
  146. }
  147. else
  148. {
  149. ANKI_ASSERT(typeInfo.image.sampled == 2);
  150. flags = DescriptorFlag::kReadWrite;
  151. }
  152. }
  153. }
  154. // Check that there are no other descriptors with the same binding
  155. if(refl.m_descriptorTypes[set][binding] == DescriptorType::kCount)
  156. {
  157. // New binding, init it
  158. refl.m_descriptorTypes[set][binding] = type;
  159. refl.m_descriptorArraySizes[set][binding] = U16(arraySize);
  160. refl.m_descriptorFlags[set][binding] = flags;
  161. }
  162. else
  163. {
  164. // Same binding, make sure the type is compatible
  165. if(refl.m_descriptorTypes[set][binding] != type || refl.m_descriptorArraySizes[set][binding] != arraySize
  166. || refl.m_descriptorFlags[set][binding] != flags)
  167. {
  168. errorStr.sprintf("Descriptor with same binding but different type or array size: %s", r.name.c_str());
  169. return Error::kUserData;
  170. }
  171. }
  172. }
  173. return Error::kNone;
  174. };
  175. Error err = Error::kNone;
  176. err = func(rsrc.uniform_buffers, DescriptorType::kUniformBuffer, DescriptorFlag::kRead);
  177. if(!err)
  178. {
  179. err = func(rsrc.separate_images, DescriptorType::kTexture, DescriptorFlag::kRead); // This also handles texel buffers
  180. }
  181. if(!err)
  182. {
  183. err = func(rsrc.separate_samplers, DescriptorType::kSampler, DescriptorFlag::kRead);
  184. }
  185. if(!err)
  186. {
  187. err = func(rsrc.storage_buffers, DescriptorType::kStorageBuffer, DescriptorFlag::kReadWrite);
  188. }
  189. if(!err)
  190. {
  191. err = func(rsrc.storage_images, DescriptorType::kTexture, DescriptorFlag::kReadWrite);
  192. }
  193. if(!err)
  194. {
  195. err = func(rsrc.acceleration_structures, DescriptorType::kAccelerationStructure, DescriptorFlag::kRead);
  196. }
  197. // Color attachments
  198. if(type == ShaderType::kFragment)
  199. {
  200. for(const spirv_cross::Resource& r : rsrc.stage_outputs)
  201. {
  202. const U32 id = r.id;
  203. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  204. refl.m_colorAttachmentWritemask.set(location);
  205. }
  206. }
  207. // Push consts
  208. if(rsrc.push_constant_buffers.size() == 1)
  209. {
  210. const U32 blockSize = U32(spvc.get_declared_struct_size(spvc.get_type(rsrc.push_constant_buffers[0].base_type_id)));
  211. if(blockSize == 0 || (blockSize % 16) != 0 || blockSize > kMaxU8)
  212. {
  213. errorStr.sprintf("Incorrect push constants size");
  214. return Error::kUserData;
  215. }
  216. refl.m_pushConstantsSize = U8(blockSize);
  217. }
  218. // Attribs
  219. if(type == ShaderType::kVertex)
  220. {
  221. for(const spirv_cross::Resource& r : rsrcActive.stage_inputs)
  222. {
  223. VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
  224. #define ANKI_ATTRIB_NAME(x) "in.var." #x
  225. if(r.name == ANKI_ATTRIB_NAME(POSITION))
  226. {
  227. a = VertexAttributeSemantic::kPosition;
  228. }
  229. else if(r.name == ANKI_ATTRIB_NAME(NORMAL))
  230. {
  231. a = VertexAttributeSemantic::kNormal;
  232. }
  233. else if(r.name == ANKI_ATTRIB_NAME(TEXCOORD0) || r.name == ANKI_ATTRIB_NAME(TEXCOORD))
  234. {
  235. a = VertexAttributeSemantic::kTexCoord;
  236. }
  237. else if(r.name == ANKI_ATTRIB_NAME(COLOR))
  238. {
  239. a = VertexAttributeSemantic::kColor;
  240. }
  241. else if(r.name == ANKI_ATTRIB_NAME(MISC0) || r.name == ANKI_ATTRIB_NAME(MISC))
  242. {
  243. a = VertexAttributeSemantic::kMisc0;
  244. }
  245. else if(r.name == ANKI_ATTRIB_NAME(MISC1))
  246. {
  247. a = VertexAttributeSemantic::kMisc1;
  248. }
  249. else if(r.name == ANKI_ATTRIB_NAME(MISC2))
  250. {
  251. a = VertexAttributeSemantic::kMisc2;
  252. }
  253. else if(r.name == ANKI_ATTRIB_NAME(MISC3))
  254. {
  255. a = VertexAttributeSemantic::kMisc3;
  256. }
  257. else
  258. {
  259. errorStr.sprintf("Unexpected attribute name: %s", r.name.c_str());
  260. return Error::kUserData;
  261. }
  262. #undef ANKI_ATTRIB_NAME
  263. refl.m_vertexAttributeMask.set(a);
  264. const U32 id = r.id;
  265. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  266. if(location > kMaxU8)
  267. {
  268. errorStr.sprintf("Too high location value for attribute: %s", r.name.c_str());
  269. return Error::kUserData;
  270. }
  271. refl.m_vertexAttributeLocations[a] = U8(location);
  272. }
  273. }
  274. // Discards?
  275. if(type == ShaderType::kFragment)
  276. {
  277. visitSpirv(ConstWeakArray<U32>(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32)), [&](U32 cmd) {
  278. if(cmd == spv::OpKill)
  279. {
  280. refl.m_discards = true;
  281. }
  282. });
  283. }
  284. return Error::kNone;
  285. }
  286. #if ANKI_DIXL_REFLECTION
  /// Evaluate an expression that yields an HRESULT. If the result is negative write a message to the errorStr that is
  /// in scope at the point of use and return Error::kFunctionFailed from the enclosing function.
  # define ANKI_REFL_CHECK(x) \
      do \
      { \
          const HRESULT hresult_ = (x); \
          if(hresult_ < 0) [[unlikely]] \
          { \
              errorStr.sprintf("DXC function failed (HRESULT: %d): %s", hresult_, #x); \
              return Error::kFunctionFailed; \
          } \
      } while(0)
  297. Error doReflectionDxil(ConstWeakArray<U8> dxil, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
  298. {
  299. using Microsoft::WRL::ComPtr;
  300. // Lazyly load the DXC DLL
  301. {
  302. LockGuard lock(g_dxcLibMtx);
  303. if(g_dxcLib == 0)
  304. {
  305. // Init DXC
  306. g_dxcLib = LoadLibraryA(ANKI_SOURCE_DIRECTORY "/ThirdParty/Bin/Windows64/dxcompiler.dll");
  307. if(g_dxcLib == 0)
  308. {
  309. ANKI_SHADER_COMPILER_LOGE("dxcompiler.dll missing or wrong architecture");
  310. return Error::kFunctionFailed;
  311. }
  312. g_DxcCreateInstance = reinterpret_cast<DxcCreateInstanceProc>(GetProcAddress(g_dxcLib, "DxcCreateInstance"));
  313. if(g_DxcCreateInstance == nullptr)
  314. {
  315. ANKI_SHADER_COMPILER_LOGE("DxcCreateInstance was not found in the dxcompiler.dll");
  316. return Error::kFunctionFailed;
  317. }
  318. }
  319. }
  320. const Bool isRayTracing = type >= ShaderType::kFirstRayTracing && type <= ShaderType::kLastRayTracing;
  321. if(isRayTracing)
  322. {
  323. // TODO: Skip for now. RT shaders require explicity register()
  324. return Error::kNone;
  325. }
  326. ComPtr<IDxcUtils> utils;
  327. ANKI_REFL_CHECK(g_DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(&utils)));
  328. ComPtr<ID3D12ShaderReflection> dxRefl;
  329. ComPtr<ID3D12LibraryReflection> libRefl;
  330. ID3D12FunctionReflection* funcRefl = nullptr;
  331. D3D12_SHADER_DESC shaderDesc = {};
  332. U32 bindingCount = 0;
  333. if(!isRayTracing)
  334. {
  335. const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
  336. ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&dxRefl)));
  337. ANKI_REFL_CHECK(dxRefl->GetDesc(&shaderDesc));
  338. bindingCount = shaderDesc.BoundResources;
  339. }
  340. else
  341. {
  342. const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
  343. ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&libRefl)));
  344. D3D12_LIBRARY_DESC libDesc = {};
  345. libRefl->GetDesc(&libDesc);
  346. if(libDesc.FunctionCount != 1)
  347. {
  348. errorStr.sprintf("Expecting 1 in D3D12_LIBRARY_DESC::FunctionCount");
  349. return Error::kUserData;
  350. }
  351. funcRefl = libRefl->GetFunctionByIndex(0);
  352. D3D12_FUNCTION_DESC funcDesc;
  353. ANKI_REFL_CHECK(funcRefl->GetDesc(&funcDesc));
  354. bindingCount = funcDesc.BoundResources;
  355. }
  356. for(U32 i = 0; i < bindingCount; ++i)
  357. {
  358. D3D12_SHADER_INPUT_BIND_DESC bindDesc;
  359. if(dxRefl.Get() != nullptr)
  360. {
  361. ANKI_REFL_CHECK(dxRefl->GetResourceBindingDesc(i, &bindDesc));
  362. }
  363. else
  364. {
  365. ANKI_REFL_CHECK(funcRefl->GetResourceBindingDesc(i, &bindDesc));
  366. }
  367. ShaderReflectionBinding akBinding;
  368. akBinding.m_type = DescriptorType::kCount;
  369. akBinding.m_flags = DescriptorFlag::kNone;
  370. akBinding.m_arraySize = U16(bindDesc.BindCount);
  371. akBinding.m_registerBindingPoint = bindDesc.BindPoint;
  372. if(bindDesc.Type == D3D_SIT_CBUFFER)
  373. {
  374. // ConstantBuffer
  375. if(bindDesc.BindPoint == kPushConstantsRegisterBindPoint && bindDesc.Space == kPushConstantsRegisterSpace)
  376. {
  377. // It's push/root constants
  378. ID3D12ShaderReflectionConstantBuffer* cbuffer =
  379. (dxRefl.Get()) ? dxRefl->GetConstantBufferByName(bindDesc.Name) : funcRefl->GetConstantBufferByName(bindDesc.Name);
  380. D3D12_SHADER_BUFFER_DESC desc;
  381. ANKI_REFL_CHECK(cbuffer->GetDesc(&desc));
  382. refl.m_pushConstantsSize = U8(desc.Size);
  383. continue;
  384. }
  385. akBinding.m_type = DescriptorType::kUniformBuffer;
  386. akBinding.m_flags = DescriptorFlag::kRead;
  387. }
  388. else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
  389. {
  390. // Texture2D etc
  391. akBinding.m_type = DescriptorType::kTexture;
  392. akBinding.m_flags = DescriptorFlag::kRead;
  393. }
  394. else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
  395. {
  396. // Buffer
  397. akBinding.m_type = DescriptorType::kTexelBuffer;
  398. akBinding.m_flags = DescriptorFlag::kRead;
  399. }
  400. else if(bindDesc.Type == D3D_SIT_SAMPLER)
  401. {
  402. // SamplerState
  403. akBinding.m_type = DescriptorType::kSampler;
  404. akBinding.m_flags = DescriptorFlag::kRead;
  405. }
  406. else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
  407. {
  408. // RWBuffer
  409. akBinding.m_type = DescriptorType::kTexelBuffer;
  410. akBinding.m_flags = DescriptorFlag::kReadWrite;
  411. }
  412. else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
  413. {
  414. // RWTexture2D etc
  415. akBinding.m_type = DescriptorType::kTexture;
  416. akBinding.m_flags = DescriptorFlag::kReadWrite;
  417. }
  418. else if(bindDesc.Type == D3D_SIT_BYTEADDRESS)
  419. {
  420. // ByteAddressBuffer
  421. akBinding.m_type = DescriptorType::kStorageBuffer;
  422. akBinding.m_flags = DescriptorFlag::kRead | DescriptorFlag::kByteAddressBuffer;
  423. }
  424. else if(bindDesc.Type == D3D_SIT_UAV_RWBYTEADDRESS)
  425. {
  426. // RWByteAddressBuffer
  427. akBinding.m_type = DescriptorType::kStorageBuffer;
  428. akBinding.m_flags = DescriptorFlag::kReadWrite | DescriptorFlag::kByteAddressBuffer;
  429. }
  430. else if(bindDesc.Type == D3D_SIT_RTACCELERATIONSTRUCTURE)
  431. {
  432. // RaytracingAccelerationStructure
  433. akBinding.m_type = DescriptorType::kAccelerationStructure;
  434. akBinding.m_flags = DescriptorFlag::kRead;
  435. }
  436. else if(bindDesc.Type == D3D_SIT_STRUCTURED)
  437. {
  438. // StructuredBuffer
  439. akBinding.m_type = DescriptorType::kStorageBuffer;
  440. akBinding.m_flags = DescriptorFlag::kRead;
  441. akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
  442. }
  443. else if(bindDesc.Type == D3D_SIT_UAV_RWSTRUCTURED)
  444. {
  445. // RWStructuredBuffer
  446. akBinding.m_type = DescriptorType::kStorageBuffer;
  447. akBinding.m_flags = DescriptorFlag::kReadWrite;
  448. akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
  449. }
  450. else
  451. {
  452. errorStr.sprintf("Unrecognized type for binding: %s", bindDesc.Name);
  453. return Error::kUserData;
  454. }
  455. refl.m_bindings[bindDesc.Space][refl.m_bindingCounts[bindDesc.Space]] = akBinding;
  456. ++refl.m_bindingCounts[bindDesc.Space];
  457. refl.m_descriptorSetMask.set(bindDesc.Space);
  458. }
  459. for(U32 i = 0; i < kMaxDescriptorSets; ++i)
  460. {
  461. std::sort(refl.m_bindings[i].getBegin(), refl.m_bindings[i].getBegin() + refl.m_bindingCounts[i]);
  462. }
  463. if(type == ShaderType::kVertex)
  464. {
  465. for(U32 i = 0; i < shaderDesc.InputParameters; ++i)
  466. {
  467. D3D12_SIGNATURE_PARAMETER_DESC in;
  468. ANKI_REFL_CHECK(dxRefl->GetInputParameterDesc(i, &in));
  469. VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
  470. # define ANKI_ATTRIB_NAME(x, idx) CString(in.SemanticName) == # x&& in.SemanticIndex == idx
  471. if(ANKI_ATTRIB_NAME(POSITION, 0))
  472. {
  473. a = VertexAttributeSemantic::kPosition;
  474. }
  475. else if(ANKI_ATTRIB_NAME(NORMAL, 0))
  476. {
  477. a = VertexAttributeSemantic::kNormal;
  478. }
  479. else if(ANKI_ATTRIB_NAME(TEXCOORD, 0))
  480. {
  481. a = VertexAttributeSemantic::kTexCoord;
  482. }
  483. else if(ANKI_ATTRIB_NAME(COLOR, 0))
  484. {
  485. a = VertexAttributeSemantic::kColor;
  486. }
  487. else if(ANKI_ATTRIB_NAME(MISC, 0))
  488. {
  489. a = VertexAttributeSemantic::kMisc0;
  490. }
  491. else if(ANKI_ATTRIB_NAME(MISC, 1))
  492. {
  493. a = VertexAttributeSemantic::kMisc1;
  494. }
  495. else if(ANKI_ATTRIB_NAME(MISC, 2))
  496. {
  497. a = VertexAttributeSemantic::kMisc2;
  498. }
  499. else if(ANKI_ATTRIB_NAME(MISC, 3))
  500. {
  501. a = VertexAttributeSemantic::kMisc3;
  502. }
  503. else if(ANKI_ATTRIB_NAME(SV_VERTEXID, 0) || ANKI_ATTRIB_NAME(SV_INSTANCEID, 0))
  504. {
  505. // Ignore
  506. continue;
  507. }
  508. else
  509. {
  510. errorStr.sprintf("Unexpected attribute name: %s", in.SemanticName);
  511. return Error::kUserData;
  512. }
  513. # undef ANKI_ATTRIB_NAME
  514. refl.m_vertexAttributeMask.set(a);
  515. refl.m_vertexAttributeLocations[a] = U8(i);
  516. }
  517. }
  518. if(type == ShaderType::kFragment)
  519. {
  520. for(U32 i = 0; i < shaderDesc.OutputParameters; ++i)
  521. {
  522. D3D12_SIGNATURE_PARAMETER_DESC desc;
  523. ANKI_REFL_CHECK(dxRefl->GetOutputParameterDesc(i, &desc));
  524. if(CString(desc.SemanticName) == "SV_TARGET")
  525. {
  526. refl.m_colorAttachmentWritemask.set(desc.SemanticIndex);
  527. }
  528. }
  529. }
  530. return Error::kNone;
  531. }
  532. #endif // #if ANKI_DIXL_REFLECTION
  533. static void compileVariantAsync(const ShaderProgramParser& parser, Bool spirv, ShaderProgramBinaryMutation& mutation,
  534. ShaderCompilerDynamicArray<ShaderProgramBinaryVariant>& variants,
  535. ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock>& codeBlocks,
  536. ShaderCompilerDynamicArray<U64>& sourceCodeHashes, ShaderProgramAsyncTaskInterface& taskManager, Mutex& mtx,
  537. Atomic<I32>& error)
  538. {
  539. class Ctx
  540. {
  541. public:
  542. const ShaderProgramParser* m_parser;
  543. ShaderProgramBinaryMutation* m_mutation;
  544. ShaderCompilerDynamicArray<ShaderProgramBinaryVariant>* m_variants;
  545. ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock>* m_codeBlocks;
  546. ShaderCompilerDynamicArray<U64>* m_sourceCodeHashes;
  547. Mutex* m_mtx;
  548. Atomic<I32>* m_err;
  549. Bool m_spirv;
  550. };
  551. Ctx* ctx = newInstance<Ctx>(ShaderCompilerMemoryPool::getSingleton());
  552. ctx->m_parser = &parser;
  553. ctx->m_mutation = &mutation;
  554. ctx->m_variants = &variants;
  555. ctx->m_codeBlocks = &codeBlocks;
  556. ctx->m_sourceCodeHashes = &sourceCodeHashes;
  557. ctx->m_mtx = &mtx;
  558. ctx->m_err = &error;
  559. ctx->m_spirv = spirv;
  560. auto callback = [](void* userData) {
  561. Ctx& ctx = *static_cast<Ctx*>(userData);
  562. class Cleanup
  563. {
  564. public:
  565. Ctx* m_ctx;
  566. ~Cleanup()
  567. {
  568. deleteInstance(ShaderCompilerMemoryPool::getSingleton(), m_ctx);
  569. }
  570. } cleanup{&ctx};
  571. if(ctx.m_err->load() != 0)
  572. {
  573. // Don't bother
  574. return;
  575. }
  576. const U32 techniqueCount = ctx.m_parser->getTechniques().getSize();
  577. // Compile the sources
  578. ShaderCompilerDynamicArray<ShaderProgramBinaryTechniqueCodeBlocks> codeBlockIndices;
  579. codeBlockIndices.resize(techniqueCount);
  580. for(auto& it : codeBlockIndices)
  581. {
  582. it.m_codeBlockIndices.fill(kMaxU32);
  583. }
  584. ShaderCompilerString compilerErrorLog;
  585. Error err = Error::kNone;
  586. U newCodeBlockCount = 0;
  587. for(U32 t = 0; t < techniqueCount && !err; ++t)
  588. {
  589. const ShaderProgramParserTechnique& technique = ctx.m_parser->getTechniques()[t];
  590. for(ShaderType shaderType : EnumBitsIterable<ShaderType, ShaderTypeBit>(technique.m_shaderTypes))
  591. {
  592. ShaderCompilerString source;
  593. ctx.m_parser->generateVariant(ctx.m_mutation->m_values, technique, shaderType, source);
  594. // Check if the source code was found before
  595. const U64 sourceCodeHash = source.computeHash();
  596. if(technique.m_activeMutators[shaderType] != kMaxU64)
  597. {
  598. LockGuard lock(*ctx.m_mtx);
  599. ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());
  600. Bool found = false;
  601. for(U32 i = 0; i < ctx.m_sourceCodeHashes->getSize(); ++i)
  602. {
  603. if((*ctx.m_sourceCodeHashes)[i] == sourceCodeHash)
  604. {
  605. codeBlockIndices[t].m_codeBlockIndices[shaderType] = i;
  606. found = true;
  607. break;
  608. }
  609. }
  610. if(found)
  611. {
  612. continue;
  613. }
  614. }
  615. ShaderCompilerDynamicArray<U8> il;
  616. if(ctx.m_spirv)
  617. {
  618. err = compileHlslToSpirv(source, shaderType, ctx.m_parser->compileWith16bitTypes(), il, compilerErrorLog);
  619. }
  620. else
  621. {
  622. err = compileHlslToDxil(source, shaderType, ctx.m_parser->compileWith16bitTypes(), il, compilerErrorLog);
  623. }
  624. if(err)
  625. {
  626. break;
  627. }
  628. const U64 newHash = computeHash(il.getBegin(), il.getSizeInBytes());
  629. ShaderReflection refl;
  630. if(ctx.m_spirv)
  631. {
  632. err = doReflectionSpirv(il, shaderType, refl, compilerErrorLog);
  633. }
  634. else
  635. {
  636. #if ANKI_DIXL_REFLECTION
  637. err = doReflectionDxil(il, shaderType, refl, compilerErrorLog);
  638. #else
  639. ANKI_SHADER_COMPILER_LOGE("Can't generate shader compilation on non-windows platforms");
  640. err = Error::kFunctionFailed;
  641. #endif
  642. }
  643. if(err)
  644. {
  645. break;
  646. }
  647. // Add the binary if not already there
  648. {
  649. LockGuard lock(*ctx.m_mtx);
  650. Bool found = false;
  651. for(U32 j = 0; j < ctx.m_codeBlocks->getSize(); ++j)
  652. {
  653. if((*ctx.m_codeBlocks)[j].m_hash == newHash)
  654. {
  655. codeBlockIndices[t].m_codeBlockIndices[shaderType] = j;
  656. found = true;
  657. break;
  658. }
  659. }
  660. if(!found)
  661. {
  662. codeBlockIndices[t].m_codeBlockIndices[shaderType] = ctx.m_codeBlocks->getSize();
  663. auto& codeBlock = *ctx.m_codeBlocks->emplaceBack();
  664. il.moveAndReset(codeBlock.m_binary);
  665. codeBlock.m_hash = newHash;
  666. codeBlock.m_reflection = refl;
  667. ctx.m_sourceCodeHashes->emplaceBack(sourceCodeHash);
  668. ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());
  669. ++newCodeBlockCount;
  670. }
  671. }
  672. }
  673. }
  674. if(err)
  675. {
  676. I32 expectedErr = 0;
  677. const Bool isFirstError = ctx.m_err->compareExchange(expectedErr, err._getCode());
  678. if(isFirstError)
  679. {
  680. ANKI_SHADER_COMPILER_LOGE("Shader compilation failed:\n%s", compilerErrorLog.cstr());
  681. return;
  682. }
  683. return;
  684. }
  685. // Do variant stuff
  686. {
  687. LockGuard lock(*ctx.m_mtx);
  688. Bool createVariant = true;
  689. if(newCodeBlockCount == 0)
  690. {
  691. // No new code blocks generated, search all variants to see if there is a duplicate
  692. for(U32 i = 0; i < ctx.m_variants->getSize(); ++i)
  693. {
  694. Bool same = true;
  695. for(U32 t = 0; t < techniqueCount; ++t)
  696. {
  697. const ShaderProgramBinaryTechniqueCodeBlocks& a = (*ctx.m_variants)[i].m_techniqueCodeBlocks[t];
  698. const ShaderProgramBinaryTechniqueCodeBlocks& b = codeBlockIndices[t];
  699. if(memcmp(&a, &b, sizeof(a)) != 0)
  700. {
  701. // Not the same
  702. same = false;
  703. break;
  704. }
  705. }
  706. if(same)
  707. {
  708. createVariant = false;
  709. ctx.m_mutation->m_variantIndex = i;
  710. break;
  711. }
  712. }
  713. }
  714. // Create a new variant
  715. if(createVariant)
  716. {
  717. ctx.m_mutation->m_variantIndex = ctx.m_variants->getSize();
  718. ShaderProgramBinaryVariant* variant = ctx.m_variants->emplaceBack();
  719. codeBlockIndices.moveAndReset(variant->m_techniqueCodeBlocks);
  720. }
  721. }
  722. };
  723. taskManager.enqueueTask(callback, ctx);
  724. }
  725. static Error compileShaderProgramInternal(CString fname, Bool spirv, ShaderProgramFilesystemInterface& fsystem,
  726. ShaderProgramPostParseInterface* postParseCallback, ShaderProgramAsyncTaskInterface* taskManager_,
  727. ConstWeakArray<ShaderCompilerDefine> defines_, ShaderProgramBinary*& binary)
  728. {
  729. ShaderCompilerMemoryPool& memPool = ShaderCompilerMemoryPool::getSingleton();
  730. ShaderCompilerDynamicArray<ShaderCompilerDefine> defines;
  731. for(const ShaderCompilerDefine& d : defines_)
  732. {
  733. defines.emplaceBack(d);
  734. }
  735. // Initialize the binary
  736. binary = newInstance<ShaderProgramBinary>(memPool);
  737. memcpy(&binary->m_magic[0], kShaderBinaryMagic, 8);
  738. // Parse source
  739. ShaderProgramParser parser(fname, &fsystem, defines);
  740. ANKI_CHECK(parser.parse());
  741. if(postParseCallback && postParseCallback->skipCompilation(parser.getHash()))
  742. {
  743. return Error::kNone;
  744. }
  745. // Get mutators
  746. U32 mutationCount = 0;
  747. if(parser.getMutators().getSize() > 0)
  748. {
  749. newArray(memPool, parser.getMutators().getSize(), binary->m_mutators);
  750. for(U32 i = 0; i < binary->m_mutators.getSize(); ++i)
  751. {
  752. ShaderProgramBinaryMutator& out = binary->m_mutators[i];
  753. const ShaderProgramParserMutator& in = parser.getMutators()[i];
  754. zeroMemory(out);
  755. newArray(memPool, in.m_values.getSize(), out.m_values);
  756. memcpy(out.m_values.getBegin(), in.m_values.getBegin(), in.m_values.getSizeInBytes());
  757. memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);
  758. // Update the count
  759. mutationCount = (i == 0) ? out.m_values.getSize() : mutationCount * out.m_values.getSize();
  760. }
  761. }
  762. else
  763. {
  764. ANKI_ASSERT(binary->m_mutators.getSize() == 0);
  765. }
  766. // Create all variants
  767. Mutex mtx;
  768. Atomic<I32> errorAtomic(0);
  769. class SyncronousShaderProgramAsyncTaskInterface : public ShaderProgramAsyncTaskInterface
  770. {
  771. public:
  772. void enqueueTask(void (*callback)(void* userData), void* userData) final
  773. {
  774. callback(userData);
  775. }
  776. Error joinTasks() final
  777. {
  778. // Nothing
  779. return Error::kNone;
  780. }
  781. } syncTaskManager;
  782. ShaderProgramAsyncTaskInterface& taskManager = (taskManager_) ? *taskManager_ : syncTaskManager;
  783. if(parser.getMutators().getSize() > 0)
  784. {
  785. // Initialize
  786. ShaderCompilerDynamicArray<MutatorValue> mutationValues;
  787. mutationValues.resize(parser.getMutators().getSize());
  788. ShaderCompilerDynamicArray<U32> dials;
  789. dials.resize(parser.getMutators().getSize(), 0);
  790. ShaderCompilerDynamicArray<ShaderProgramBinaryVariant> variants;
  791. ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock> codeBlocks;
  792. ShaderCompilerDynamicArray<U64> sourceCodeHashes;
  793. ShaderCompilerDynamicArray<ShaderProgramBinaryMutation> mutations;
  794. mutations.resize(mutationCount);
  795. ShaderCompilerHashMap<U64, U32> mutationHashToIdx;
  796. // Grow the storage of the variants array. Can't have it resize, threads will work on stale data
  797. variants.resizeStorage(mutationCount);
  798. mutationCount = 0;
  799. // Spin for all possible combinations of mutators and
  800. // - Create the spirv
  801. // - Populate the binary variant
  802. do
  803. {
  804. // Create the mutation
  805. for(U32 i = 0; i < parser.getMutators().getSize(); ++i)
  806. {
  807. mutationValues[i] = parser.getMutators()[i].m_values[dials[i]];
  808. }
  809. ShaderProgramBinaryMutation& mutation = mutations[mutationCount++];
  810. newArray(memPool, mutationValues.getSize(), mutation.m_values);
  811. memcpy(mutation.m_values.getBegin(), mutationValues.getBegin(), mutationValues.getSizeInBytes());
  812. mutation.m_hash = computeHash(mutationValues.getBegin(), mutationValues.getSizeInBytes());
  813. ANKI_ASSERT(mutation.m_hash > 0);
  814. if(parser.skipMutation(mutationValues))
  815. {
  816. mutation.m_variantIndex = kMaxU32;
  817. }
  818. else
  819. {
  820. // New and unique mutation and thus variant, add it
  821. compileVariantAsync(parser, spirv, mutation, variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);
  822. ANKI_ASSERT(mutationHashToIdx.find(mutation.m_hash) == mutationHashToIdx.getEnd());
  823. mutationHashToIdx.emplace(mutation.m_hash, mutationCount - 1);
  824. }
  825. } while(!spinDials(dials, parser.getMutators()));
  826. ANKI_ASSERT(mutationCount == mutations.getSize());
  827. // Done, wait the threads
  828. ANKI_CHECK(taskManager.joinTasks());
  829. // Now error out
  830. ANKI_CHECK(Error(errorAtomic.getNonAtomically()));
  831. // Store temp containers to binary
  832. codeBlocks.moveAndReset(binary->m_codeBlocks);
  833. mutations.moveAndReset(binary->m_mutations);
  834. variants.moveAndReset(binary->m_variants);
  835. }
  836. else
  837. {
  838. newArray(memPool, 1, binary->m_mutations);
  839. ShaderCompilerDynamicArray<ShaderProgramBinaryVariant> variants;
  840. ShaderCompilerDynamicArray<ShaderProgramBinaryCodeBlock> codeBlocks;
  841. ShaderCompilerDynamicArray<U64> sourceCodeHashes;
  842. compileVariantAsync(parser, spirv, binary->m_mutations[0], variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);
  843. ANKI_CHECK(taskManager.joinTasks());
  844. ANKI_CHECK(Error(errorAtomic.getNonAtomically()));
  845. ANKI_ASSERT(codeBlocks.getSize() >= parser.getTechniques().getSize());
  846. ANKI_ASSERT(binary->m_mutations[0].m_variantIndex == 0);
  847. ANKI_ASSERT(variants.getSize() == 1);
  848. binary->m_mutations[0].m_hash = 1;
  849. codeBlocks.moveAndReset(binary->m_codeBlocks);
  850. variants.moveAndReset(binary->m_variants);
  851. }
  852. // Sort the mutations
  853. std::sort(binary->m_mutations.getBegin(), binary->m_mutations.getEnd(),
  854. [](const ShaderProgramBinaryMutation& a, const ShaderProgramBinaryMutation& b) {
  855. return a.m_hash < b.m_hash;
  856. });
  857. // Techniques
  858. newArray(memPool, parser.getTechniques().getSize(), binary->m_techniques);
  859. for(U32 i = 0; i < parser.getTechniques().getSize(); ++i)
  860. {
  861. zeroMemory(binary->m_techniques[i].m_name);
  862. memcpy(binary->m_techniques[i].m_name.getBegin(), parser.getTechniques()[i].m_name.cstr(), parser.getTechniques()[i].m_name.getLength() + 1);
  863. binary->m_techniques[i].m_shaderTypes = parser.getTechniques()[i].m_shaderTypes;
  864. binary->m_shaderTypes |= parser.getTechniques()[i].m_shaderTypes;
  865. }
  866. // Structs
  867. if(parser.getGhostStructs().getSize())
  868. {
  869. newArray(memPool, parser.getGhostStructs().getSize(), binary->m_structs);
  870. }
  871. for(U32 i = 0; i < parser.getGhostStructs().getSize(); ++i)
  872. {
  873. const ShaderProgramParserGhostStruct& in = parser.getGhostStructs()[i];
  874. ShaderProgramBinaryStruct& out = binary->m_structs[i];
  875. zeroMemory(out);
  876. memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);
  877. ANKI_ASSERT(in.m_members.getSize());
  878. newArray(memPool, in.m_members.getSize(), out.m_members);
  879. for(U32 j = 0; j < in.m_members.getSize(); ++j)
  880. {
  881. const ShaderProgramParserMember& inm = in.m_members[j];
  882. ShaderProgramBinaryStructMember& outm = out.m_members[j];
  883. zeroMemory(outm.m_name);
  884. memcpy(outm.m_name.getBegin(), inm.m_name.cstr(), inm.m_name.getLength() + 1);
  885. outm.m_offset = inm.m_offset;
  886. outm.m_type = inm.m_type;
  887. }
  888. out.m_size = in.m_members.getBack().m_offset + getShaderVariableDataTypeInfo(in.m_members.getBack().m_type).m_size;
  889. }
  890. return Error::kNone;
  891. }
  892. Error compileShaderProgram(CString fname, Bool spirv, ShaderProgramFilesystemInterface& fsystem, ShaderProgramPostParseInterface* postParseCallback,
  893. ShaderProgramAsyncTaskInterface* taskManager, ConstWeakArray<ShaderCompilerDefine> defines, ShaderProgramBinary*& binary)
  894. {
  895. const Error err = compileShaderProgramInternal(fname, spirv, fsystem, postParseCallback, taskManager, defines, binary);
  896. if(err)
  897. {
  898. ANKI_SHADER_COMPILER_LOGE("Failed to compile: %s", fname.cstr());
  899. freeShaderProgramBinary(binary);
  900. }
  901. return err;
  902. }
  903. } // end namespace anki