ShaderCompiler.cpp

// Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
// All rights reserved.
// Code licensed under the BSD License.
// http://www.anki3d.org/LICENSE

#include <AnKi/ShaderCompiler/ShaderCompiler.h>
#include <AnKi/ShaderCompiler/ShaderParser.h>
#include <AnKi/ShaderCompiler/Dxc.h>
#include <AnKi/Util/Serializer.h>
#include <AnKi/Util/HashMap.h>

#include <SpirvCross/spirv_cross.hpp>

#define ANKI_DIXL_REFLECTION (ANKI_OS_WINDOWS)

#if ANKI_DIXL_REFLECTION
#	include <windows.h>
#	include <ThirdParty/Dxc/dxcapi.h>
#	include <ThirdParty/Dxc/d3d12shader.h>
#	include <wrl.h>
#	include <AnKi/Util/CleanupWindows.h>
#endif

namespace anki {

#if ANKI_DIXL_REFLECTION
static HMODULE g_dxcLib = 0;
static DxcCreateInstanceProc g_DxcCreateInstance = nullptr;
static Mutex g_dxcLibMtx;
#endif

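/// Free a ShaderBinary allocated by the compiler, releasing every nested array before the binary itself.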
void freeShaderBinary(ShaderBinary*& binary)
{
	if(binary == nullptr)
	{
		return;
	}

	BaseMemoryPool& mempool = ShaderCompilerMemoryPool::getSingleton();

	for(ShaderBinaryCodeBlock& code : binary->m_codeBlocks)
	{
		mempool.free(code.m_binary.getBegin());
	}
	mempool.free(binary->m_codeBlocks.getBegin());

	for(ShaderBinaryMutator& mutator : binary->m_mutators)
	{
		mempool.free(mutator.m_values.getBegin());
	}
	mempool.free(binary->m_mutators.getBegin());

	for(ShaderBinaryMutation& m : binary->m_mutations)
	{
		mempool.free(m.m_values.getBegin());
	}
	mempool.free(binary->m_mutations.getBegin());

	for(ShaderBinaryVariant& variant : binary->m_variants)
	{
		mempool.free(variant.m_techniqueCodeBlocks.getBegin());
	}
	mempool.free(binary->m_variants.getBegin());

	mempool.free(binary->m_techniques.getBegin());

	for(ShaderBinaryStruct& s : binary->m_structs)
	{
		mempool.free(s.m_members.getBegin());
	}
	mempool.free(binary->m_structs.getBegin());

	mempool.free(binary);
	binary = nullptr;
}

/// Spin the dials. Used to compute all mutator combinations.
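/// Returns true once every combination has been exhausted.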
static Bool spinDials(ShaderCompilerDynamicArray<U32>& dials, ConstWeakArray<ShaderParserMutator> mutators)
{
	ANKI_ASSERT(dials.getSize() == mutators.getSize() && dials.getSize() > 0);
	Bool done = true;

	U32 crntDial = dials.getSize() - 1;
	while(true)
	{
		// Turn dial
		++dials[crntDial];

		if(dials[crntDial] >= mutators[crntDial].m_values.getSize())
		{
			if(crntDial == 0)
			{
				// Reached the 1st dial, stop spinning
				done = true;
				break;
			}
			else
			{
				dials[crntDial] = 0;
				--crntDial;
			}
		}
		else
		{
			done = false;
			break;
		}
	}

	return done;
}

/// Does SPIR-V reflection and maps the DXC-shifted bindings back to AnKi register binding points.
Error doReflectionSpirv(ConstWeakArray<U8> spirv, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
{
	spirv_cross::Compiler spvc(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32));
	spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
	spirv_cross::ShaderResources rsrcActive = spvc.get_shader_resources(spvc.get_active_interface_variables());

	auto func = [&](const spirv_cross::SmallVector<spirv_cross::Resource>& resources, const DescriptorType origType,
					const DescriptorFlag origFlags) -> Error {
		for(const spirv_cross::Resource& r : resources)
		{
			const U32 id = r.id;

			const U32 set = spvc.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
			const U32 binding = spvc.get_decoration(id, spv::Decoration::DecorationBinding);
			if(set >= kMaxDescriptorSets && set != kDxcVkBindlessRegisterSpace)
			{
				errorStr.sprintf("Exceeded set for: %s", r.name.c_str());
				return Error::kUserData;
			}

			const spirv_cross::SPIRType& typeInfo = spvc.get_type(r.type_id);
			U32 arraySize = 1;
			if(typeInfo.array.size() != 0)
			{
				if(typeInfo.array.size() != 1)
				{
					errorStr.sprintf("Only 1D arrays are supported: %s", r.name.c_str());
					return Error::kUserData;
				}

				if(set == kDxcVkBindlessRegisterSpace && typeInfo.array[0] != 0)
				{
					errorStr.sprintf("Only the bindless descriptor set can be an unbound array: %s", r.name.c_str());
					return Error::kUserData;
				}

				arraySize = typeInfo.array[0];
			}

			// Images are special, they might be texel buffers
			DescriptorType type = origType;
			DescriptorFlag flags = origFlags;
			if(type == DescriptorType::kTexture && typeInfo.image.dim == spv::DimBuffer)
			{
				type = DescriptorType::kTexelBuffer;
				if(typeInfo.image.sampled == 1)
				{
					flags = DescriptorFlag::kRead;
				}
				else
				{
					ANKI_ASSERT(typeInfo.image.sampled == 2);
					flags = DescriptorFlag::kReadWrite;
				}
			}

			if(set == kDxcVkBindlessRegisterSpace)
			{
				// Bindless dset
				if(arraySize != 0)
				{
					errorStr.sprintf("Unexpected unbound array for bindless: %s", r.name.c_str());
					return Error::kUserData;
				}

				if(type != DescriptorType::kTexture || flags != DescriptorFlag::kRead)
				{
					errorStr.sprintf("Unexpected bindless binding: %s", r.name.c_str());
					return Error::kUserData;
				}

				refl.m_descriptor.m_hasVkBindlessDescriptorSet = true;
			}
			else
			{
				// Regular binding
				if(origType == DescriptorType::kStorageBuffer && binding >= kDxcVkBindingShifts[set][HlslResourceType::kSrv]
				   && binding < kDxcVkBindingShifts[set][HlslResourceType::kSrv] + 1000)
				{
					// Storage buffer is read only
					flags = DescriptorFlag::kRead;
				}

				const HlslResourceType hlslResourceType = descriptorTypeToHlslResourceType(type, flags);
				if(binding < kDxcVkBindingShifts[set][hlslResourceType] || binding >= kDxcVkBindingShifts[set][hlslResourceType] + 1000)
				{
					errorStr.sprintf("Unexpected binding: %s", r.name.c_str());
					return Error::kUserData;
				}

				ShaderReflectionBinding akBinding;
				akBinding.m_registerBindingPoint = binding - kDxcVkBindingShifts[set][hlslResourceType];
				akBinding.m_arraySize = U16(arraySize);
				akBinding.m_type = type;
				akBinding.m_flags = flags;

				refl.m_descriptor.m_bindings[set][refl.m_descriptor.m_bindingCounts[set]] = akBinding;
				++refl.m_descriptor.m_bindingCounts[set];
			}
		}

		return Error::kNone;
	};

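	// Collect the bindings of every resource class the SPIR-V exposes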
	Error err = func(rsrc.uniform_buffers, DescriptorType::kUniformBuffer, DescriptorFlag::kRead);
	if(err)
	{
		return err;
	}

	err = func(rsrc.separate_images, DescriptorType::kTexture, DescriptorFlag::kRead); // This also handles texel buffers
	if(err)
	{
		return err;
	}

	err = func(rsrc.separate_samplers, DescriptorType::kSampler, DescriptorFlag::kRead);
	if(err)
	{
		return err;
	}

	err = func(rsrc.storage_buffers, DescriptorType::kStorageBuffer, DescriptorFlag::kReadWrite);
	if(err)
	{
		return err;
	}

	err = func(rsrc.storage_images, DescriptorType::kTexture, DescriptorFlag::kReadWrite);
	if(err)
	{
		return err;
	}

	err = func(rsrc.acceleration_structures, DescriptorType::kAccelerationStructure, DescriptorFlag::kRead);
	if(err)
	{
		return err;
	}

	for(U32 i = 0; i < kMaxDescriptorSets; ++i)
	{
		std::sort(refl.m_descriptor.m_bindings[i].getBegin(), refl.m_descriptor.m_bindings[i].getBegin() + refl.m_descriptor.m_bindingCounts[i]);
	}

	// Color attachments
	if(type == ShaderType::kFragment)
	{
		for(const spirv_cross::Resource& r : rsrc.stage_outputs)
		{
			const U32 id = r.id;
			const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);

			refl.m_fragment.m_colorAttachmentWritemask.set(location);
		}
	}

	// Push consts
	if(rsrc.push_constant_buffers.size() == 1)
	{
		const U32 blockSize = U32(spvc.get_declared_struct_size(spvc.get_type(rsrc.push_constant_buffers[0].base_type_id)));
		if(blockSize == 0 || (blockSize % 16) != 0 || blockSize > kMaxU8)
		{
			errorStr.sprintf("Incorrect push constants size");
			return Error::kUserData;
		}

		refl.m_descriptor.m_pushConstantsSize = blockSize;
	}

	// Attribs
	if(type == ShaderType::kVertex)
	{
		for(const spirv_cross::Resource& r : rsrcActive.stage_inputs)
		{
			VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
#define ANKI_ATTRIB_NAME(x) "in.var." #x
			if(r.name == ANKI_ATTRIB_NAME(POSITION))
			{
				a = VertexAttributeSemantic::kPosition;
			}
			else if(r.name == ANKI_ATTRIB_NAME(NORMAL))
			{
				a = VertexAttributeSemantic::kNormal;
			}
			else if(r.name == ANKI_ATTRIB_NAME(TEXCOORD0) || r.name == ANKI_ATTRIB_NAME(TEXCOORD))
			{
				a = VertexAttributeSemantic::kTexCoord;
			}
			else if(r.name == ANKI_ATTRIB_NAME(COLOR))
			{
				a = VertexAttributeSemantic::kColor;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC0) || r.name == ANKI_ATTRIB_NAME(MISC))
			{
				a = VertexAttributeSemantic::kMisc0;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC1))
			{
				a = VertexAttributeSemantic::kMisc1;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC2))
			{
				a = VertexAttributeSemantic::kMisc2;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC3))
			{
				a = VertexAttributeSemantic::kMisc3;
			}
			else
			{
				errorStr.sprintf("Unexpected attribute name: %s", r.name.c_str());
				return Error::kUserData;
			}
#undef ANKI_ATTRIB_NAME

			refl.m_vertex.m_vertexAttributeMask.set(a);

			const U32 id = r.id;
			const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
			if(location > kMaxU8)
			{
				errorStr.sprintf("Too high location value for attribute: %s", r.name.c_str());
				return Error::kUserData;
			}

			refl.m_vertex.m_vkVertexAttributeLocations[a] = U8(location);
		}
	}

	// Discards?
	if(type == ShaderType::kFragment)
	{
		visitSpirv(ConstWeakArray<U32>(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32)),
				   [&](U32 cmd, [[maybe_unused]] ConstWeakArray<U32> instructions) {
					   if(cmd == spv::OpKill)
					   {
						   refl.m_fragment.m_discards = true;
					   }
				   });
	}

	return Error::kNone;
}

#if ANKI_DIXL_REFLECTION
#	define ANKI_REFL_CHECK(x) \
		do \
		{ \
			HRESULT rez; \
			if((rez = (x)) < 0) [[unlikely]] \
			{ \
				errorStr.sprintf("DXC function failed (HRESULT: %d): %s", rez, #x); \
				return Error::kFunctionFailed; \
			} \
		} while(0)

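/// Does DXIL reflection through the DXC/D3D12 reflection interfaces. Windows-only because it loads dxcompiler.dll.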
Error doReflectionDxil(ConstWeakArray<U8> dxil, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
{
	using Microsoft::WRL::ComPtr;

	// Lazily load the DXC DLL
	{
		LockGuard lock(g_dxcLibMtx);

		if(g_dxcLib == 0)
		{
			// Init DXC
			g_dxcLib = LoadLibraryA(ANKI_SOURCE_DIRECTORY "/ThirdParty/Bin/Windows64/dxcompiler.dll");
			if(g_dxcLib == 0)
			{
				ANKI_SHADER_COMPILER_LOGE("dxcompiler.dll missing or wrong architecture");
				return Error::kFunctionFailed;
			}

			g_DxcCreateInstance = reinterpret_cast<DxcCreateInstanceProc>(GetProcAddress(g_dxcLib, "DxcCreateInstance"));
			if(g_DxcCreateInstance == nullptr)
			{
				ANKI_SHADER_COMPILER_LOGE("DxcCreateInstance was not found in the dxcompiler.dll");
				return Error::kFunctionFailed;
			}
		}
	}

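	// Ray tracing and work graph shaders are compiled as libraries and need the library reflection path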
	const Bool isLib = (type >= ShaderType::kFirstRayTracing && type <= ShaderType::kLastRayTracing) || type == ShaderType::kWorkGraph;

	ComPtr<IDxcUtils> utils;
	ANKI_REFL_CHECK(g_DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(&utils)));

	ComPtr<ID3D12ShaderReflection> dxRefl;
	ComPtr<ID3D12LibraryReflection> libRefl;
	ShaderCompilerDynamicArray<ID3D12FunctionReflection*> funcReflections;
	D3D12_SHADER_DESC shaderDesc = {};

	if(isLib)
	{
		const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
		ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&libRefl)));

		D3D12_LIBRARY_DESC libDesc = {};
		libRefl->GetDesc(&libDesc);
		if(libDesc.FunctionCount == 0)
		{
			errorStr.sprintf("Expecting at least 1 in D3D12_LIBRARY_DESC::FunctionCount");
			return Error::kUserData;
		}

		funcReflections.resize(libDesc.FunctionCount);
		for(U32 i = 0; i < libDesc.FunctionCount; ++i)
		{
			funcReflections[i] = libRefl->GetFunctionByIndex(i);
		}
	}
	else
	{
		const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
		ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&dxRefl)));
		ANKI_REFL_CHECK(dxRefl->GetDesc(&shaderDesc));
	}

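	// Libraries expose one reflection per exported function; plain shaders have a single reflection object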
	for(U32 ifunc = 0; ifunc < ((isLib) ? funcReflections.getSize() : 1); ++ifunc)
	{
		U32 bindingCount;
		if(isLib)
		{
			D3D12_FUNCTION_DESC funcDesc;
			ANKI_REFL_CHECK(funcReflections[ifunc]->GetDesc(&funcDesc));
			bindingCount = funcDesc.BoundResources;
		}
		else
		{
			bindingCount = shaderDesc.BoundResources;
		}

		for(U32 i = 0; i < bindingCount; ++i)
		{
			D3D12_SHADER_INPUT_BIND_DESC bindDesc;
			if(isLib)
			{
				ANKI_REFL_CHECK(funcReflections[ifunc]->GetResourceBindingDesc(i, &bindDesc));
			}
			else
			{
				ANKI_REFL_CHECK(dxRefl->GetResourceBindingDesc(i, &bindDesc));
			}

			ShaderReflectionBinding akBinding;
			akBinding.m_type = DescriptorType::kCount;
			akBinding.m_flags = DescriptorFlag::kNone;
			akBinding.m_arraySize = U16(bindDesc.BindCount);
			akBinding.m_registerBindingPoint = bindDesc.BindPoint;

			if(bindDesc.Type == D3D_SIT_CBUFFER)
			{
				// ConstantBuffer
				if(bindDesc.Space == 3000 && bindDesc.BindPoint == 0)
				{
					// It's push/root constants
					ID3D12ShaderReflectionConstantBuffer* cbuffer =
						(isLib) ? funcReflections[ifunc]->GetConstantBufferByName(bindDesc.Name) : dxRefl->GetConstantBufferByName(bindDesc.Name);
					D3D12_SHADER_BUFFER_DESC desc;
					ANKI_REFL_CHECK(cbuffer->GetDesc(&desc));
					refl.m_descriptor.m_pushConstantsSize = desc.Size;
					continue;
				}

				akBinding.m_type = DescriptorType::kUniformBuffer;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
			{
				// Texture2D etc
				akBinding.m_type = DescriptorType::kTexture;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
			{
				// Buffer
				akBinding.m_type = DescriptorType::kTexelBuffer;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_SAMPLER)
			{
				// SamplerState
				akBinding.m_type = DescriptorType::kSampler;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
			{
				// RWBuffer
				akBinding.m_type = DescriptorType::kTexelBuffer;
				akBinding.m_flags = DescriptorFlag::kReadWrite;
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
			{
				// RWTexture2D etc
				akBinding.m_type = DescriptorType::kTexture;
				akBinding.m_flags = DescriptorFlag::kReadWrite;
			}
			else if(bindDesc.Type == D3D_SIT_BYTEADDRESS)
			{
				// ByteAddressBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kRead | DescriptorFlag::kByteAddressBuffer;
				akBinding.m_d3dStructuredBufferStride = sizeof(U32);
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWBYTEADDRESS)
			{
				// RWByteAddressBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kReadWrite | DescriptorFlag::kByteAddressBuffer;
				akBinding.m_d3dStructuredBufferStride = sizeof(U32);
			}
			else if(bindDesc.Type == D3D_SIT_RTACCELERATIONSTRUCTURE)
			{
				// RaytracingAccelerationStructure
				akBinding.m_type = DescriptorType::kAccelerationStructure;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_STRUCTURED)
			{
				// StructuredBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kRead;
				akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWSTRUCTURED)
			{
				// RWStructuredBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kReadWrite;
				akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
			}
			else
			{
				errorStr.sprintf("Unrecognized type for binding: %s", bindDesc.Name);
				return Error::kUserData;
			}

			Bool skip = false;
			if(isLib)
			{
				// Search if the binding exists because it may repeat
				for(U32 i = 0; i < refl.m_descriptor.m_bindingCounts[bindDesc.Space]; ++i)
				{
					if(refl.m_descriptor.m_bindings[bindDesc.Space][i] == akBinding)
					{
						skip = true;
						break;
					}
				}
			}

			if(!skip)
			{
				refl.m_descriptor.m_bindings[bindDesc.Space][refl.m_descriptor.m_bindingCounts[bindDesc.Space]] = akBinding;
				++refl.m_descriptor.m_bindingCounts[bindDesc.Space];
			}
		}
	}

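	// Sort the bindings of every register space to make the reflection deterministic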
	for(U32 i = 0; i < kMaxDescriptorSets; ++i)
	{
		std::sort(refl.m_descriptor.m_bindings[i].getBegin(), refl.m_descriptor.m_bindings[i].getBegin() + refl.m_descriptor.m_bindingCounts[i]);
	}

	if(type == ShaderType::kVertex)
	{
		for(U32 i = 0; i < shaderDesc.InputParameters; ++i)
		{
			D3D12_SIGNATURE_PARAMETER_DESC in;
			ANKI_REFL_CHECK(dxRefl->GetInputParameterDesc(i, &in));

			VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
#	define ANKI_ATTRIB_NAME(x, idx) CString(in.SemanticName) == #x && in.SemanticIndex == idx
			if(ANKI_ATTRIB_NAME(POSITION, 0))
			{
				a = VertexAttributeSemantic::kPosition;
			}
			else if(ANKI_ATTRIB_NAME(NORMAL, 0))
			{
				a = VertexAttributeSemantic::kNormal;
			}
			else if(ANKI_ATTRIB_NAME(TEXCOORD, 0))
			{
				a = VertexAttributeSemantic::kTexCoord;
			}
			else if(ANKI_ATTRIB_NAME(COLOR, 0))
			{
				a = VertexAttributeSemantic::kColor;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 0))
			{
				a = VertexAttributeSemantic::kMisc0;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 1))
			{
				a = VertexAttributeSemantic::kMisc1;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 2))
			{
				a = VertexAttributeSemantic::kMisc2;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 3))
			{
				a = VertexAttributeSemantic::kMisc3;
			}
			else if(ANKI_ATTRIB_NAME(SV_VERTEXID, 0) || ANKI_ATTRIB_NAME(SV_INSTANCEID, 0))
			{
				// Ignore
				continue;
			}
			else
			{
				errorStr.sprintf("Unexpected attribute name: %s", in.SemanticName);
				return Error::kUserData;
			}
#	undef ANKI_ATTRIB_NAME

			refl.m_vertex.m_vertexAttributeMask.set(a);
			refl.m_vertex.m_vkVertexAttributeLocations[a] = U8(i); // Just set something
		}
	}

	if(type == ShaderType::kFragment)
	{
		for(U32 i = 0; i < shaderDesc.OutputParameters; ++i)
		{
			D3D12_SIGNATURE_PARAMETER_DESC desc;
			ANKI_REFL_CHECK(dxRefl->GetOutputParameterDesc(i, &desc));

			if(CString(desc.SemanticName) == "SV_TARGET")
			{
				refl.m_fragment.m_colorAttachmentWritemask.set(desc.SemanticIndex);
			}
		}
	}

	return Error::kNone;
}
#endif // #if ANKI_DIXL_REFLECTION

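/// Compile one mutation: every shader stage of every technique. Runs through the task manager (possibly on another
/// thread) and deduplicates identical source and identical IL across variants.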
static void compileVariantAsync(const ShaderParser& parser, Bool spirv, Bool debugInfo, ShaderBinaryMutation& mutation,
								ShaderCompilerDynamicArray<ShaderBinaryVariant>& variants,
								ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>& codeBlocks, ShaderCompilerDynamicArray<U64>& sourceCodeHashes,
								ShaderCompilerAsyncTaskInterface& taskManager, Mutex& mtx, Atomic<I32>& error)
{
	class Ctx
	{
	public:
		const ShaderParser* m_parser;
		ShaderBinaryMutation* m_mutation;
		ShaderCompilerDynamicArray<ShaderBinaryVariant>* m_variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>* m_codeBlocks;
		ShaderCompilerDynamicArray<U64>* m_sourceCodeHashes;
		Mutex* m_mtx;
		Atomic<I32>* m_err;
		Bool m_spirv;
		Bool m_debugInfo;
	};

	Ctx* ctx = newInstance<Ctx>(ShaderCompilerMemoryPool::getSingleton());
	ctx->m_parser = &parser;
	ctx->m_mutation = &mutation;
	ctx->m_variants = &variants;
	ctx->m_codeBlocks = &codeBlocks;
	ctx->m_sourceCodeHashes = &sourceCodeHashes;
	ctx->m_mtx = &mtx;
	ctx->m_err = &error;
	ctx->m_spirv = spirv;
	ctx->m_debugInfo = debugInfo;

	auto callback = [](void* userData) {
		Ctx& ctx = *static_cast<Ctx*>(userData);

		class Cleanup
		{
		public:
			Ctx* m_ctx;

			~Cleanup()
			{
				deleteInstance(ShaderCompilerMemoryPool::getSingleton(), m_ctx);
			}
		} cleanup{&ctx};

		if(ctx.m_err->load() != 0)
		{
			// Don't bother
			return;
		}

		const U32 techniqueCount = ctx.m_parser->getTechniques().getSize();

		// Compile the sources
		ShaderCompilerDynamicArray<ShaderBinaryTechniqueCodeBlocks> codeBlockIndices;
		codeBlockIndices.resize(techniqueCount);
		for(auto& it : codeBlockIndices)
		{
			it.m_codeBlockIndices.fill(kMaxU32);
		}

		ShaderCompilerString compilerErrorLog;
		Error err = Error::kNone;
		U newCodeBlockCount = 0;
		for(U32 t = 0; t < techniqueCount && !err; ++t)
		{
			const ShaderParserTechnique& technique = ctx.m_parser->getTechniques()[t];
			for(ShaderType shaderType : EnumBitsIterable<ShaderType, ShaderTypeBit>(technique.m_shaderTypes))
			{
				ShaderCompilerString source;
				ctx.m_parser->generateVariant(ctx.m_mutation->m_values, technique, shaderType, source);

				// Check if the source code was found before
				const U64 sourceCodeHash = source.computeHash();
				if(technique.m_activeMutators[shaderType] != kMaxU64)
				{
					LockGuard lock(*ctx.m_mtx);
					ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

					Bool found = false;
					for(U32 i = 0; i < ctx.m_sourceCodeHashes->getSize(); ++i)
					{
						if((*ctx.m_sourceCodeHashes)[i] == sourceCodeHash)
						{
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = i;
							found = true;
							break;
						}
					}

					if(found)
					{
						continue;
					}
				}

				ShaderCompilerDynamicArray<U8> il;
				if(ctx.m_spirv)
				{
					err = compileHlslToSpirv(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo,
											 ctx.m_parser->getExtraCompilerArgs(), il, compilerErrorLog);
				}
				else
				{
					err = compileHlslToDxil(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo,
											ctx.m_parser->getExtraCompilerArgs(), il, compilerErrorLog);
				}

				if(err)
				{
					break;
				}

				const U64 newHash = computeHash(il.getBegin(), il.getSizeInBytes());

				ShaderReflection refl;
				if(ctx.m_spirv)
				{
					err = doReflectionSpirv(il, shaderType, refl, compilerErrorLog);
				}
				else
				{
#if ANKI_DIXL_REFLECTION
					err = doReflectionDxil(il, shaderType, refl, compilerErrorLog);
#else
					ANKI_SHADER_COMPILER_LOGE("Can't do DXIL reflection on non-Windows platforms");
					err = Error::kFunctionFailed;
#endif
				}

				if(err)
				{
					break;
				}

				// Add the binary if not already there
				{
					LockGuard lock(*ctx.m_mtx);

					Bool found = false;
					for(U32 j = 0; j < ctx.m_codeBlocks->getSize(); ++j)
					{
						if((*ctx.m_codeBlocks)[j].m_hash == newHash)
						{
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = j;
							found = true;
							break;
						}
					}

					if(!found)
					{
						codeBlockIndices[t].m_codeBlockIndices[shaderType] = ctx.m_codeBlocks->getSize();

						auto& codeBlock = *ctx.m_codeBlocks->emplaceBack();
						il.moveAndReset(codeBlock.m_binary);
						codeBlock.m_hash = newHash;
						codeBlock.m_reflection = refl;

						ctx.m_sourceCodeHashes->emplaceBack(sourceCodeHash);
						ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

						++newCodeBlockCount;
					}
				}
			}
		}

		if(err)
		{
			I32 expectedErr = 0;
			const Bool isFirstError = ctx.m_err->compareExchange(expectedErr, err._getCode());
			if(isFirstError)
			{
				ANKI_SHADER_COMPILER_LOGE("Shader compilation failed:\n%s", compilerErrorLog.cstr());
				return;
			}

			return;
		}

		// Do variant stuff
		{
			LockGuard lock(*ctx.m_mtx);

			Bool createVariant = true;
			if(newCodeBlockCount == 0)
			{
				// No new code blocks generated, search all variants to see if there is a duplicate
				for(U32 i = 0; i < ctx.m_variants->getSize(); ++i)
				{
					Bool same = true;
					for(U32 t = 0; t < techniqueCount; ++t)
					{
						const ShaderBinaryTechniqueCodeBlocks& a = (*ctx.m_variants)[i].m_techniqueCodeBlocks[t];
						const ShaderBinaryTechniqueCodeBlocks& b = codeBlockIndices[t];
						if(memcmp(&a, &b, sizeof(a)) != 0)
						{
							// Not the same
							same = false;
							break;
						}
					}

					if(same)
					{
						createVariant = false;
						ctx.m_mutation->m_variantIndex = i;
						break;
					}
				}
			}

			// Create a new variant
			if(createVariant)
			{
				ctx.m_mutation->m_variantIndex = ctx.m_variants->getSize();
				ShaderBinaryVariant* variant = ctx.m_variants->emplaceBack();
				codeBlockIndices.moveAndReset(variant->m_techniqueCodeBlocks);
			}
		}
	};

	taskManager.enqueueTask(callback, ctx);
}

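/// Drives a full compilation: parses the source, enumerates all mutations, kicks variant compilation and fills the ShaderBinary.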
static Error compileShaderProgramInternal(CString fname, Bool spirv, Bool debugInfo, ShaderCompilerFilesystemInterface& fsystem,
										  ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager_,
										  ConstWeakArray<ShaderCompilerDefine> defines_, ShaderBinary*& binary)
{
	ShaderCompilerMemoryPool& memPool = ShaderCompilerMemoryPool::getSingleton();

	ShaderCompilerDynamicArray<ShaderCompilerDefine> defines;
	for(const ShaderCompilerDefine& d : defines_)
	{
		defines.emplaceBack(d);
	}

	// Initialize the binary
	binary = newInstance<ShaderBinary>(memPool);
	memcpy(&binary->m_magic[0], kShaderBinaryMagic, 8);

	// Parse source
	ShaderParser parser(fname, &fsystem, defines);
	ANKI_CHECK(parser.parse());

	if(postParseCallback && postParseCallback->skipCompilation(parser.getHash()))
	{
		return Error::kNone;
	}

	// Get mutators
	U32 mutationCount = 0;
	if(parser.getMutators().getSize() > 0)
	{
		newArray(memPool, parser.getMutators().getSize(), binary->m_mutators);

		for(U32 i = 0; i < binary->m_mutators.getSize(); ++i)
		{
			ShaderBinaryMutator& out = binary->m_mutators[i];
			const ShaderParserMutator& in = parser.getMutators()[i];

			zeroMemory(out);

			newArray(memPool, in.m_values.getSize(), out.m_values);
			memcpy(out.m_values.getBegin(), in.m_values.getBegin(), in.m_values.getSizeInBytes());

			memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

			// Update the count
			mutationCount = (i == 0) ? out.m_values.getSize() : mutationCount * out.m_values.getSize();
		}
	}
	else
	{
		ANKI_ASSERT(binary->m_mutators.getSize() == 0);
	}

	// Create all variants
	Mutex mtx;
	Atomic<I32> errorAtomic(0);

	class SyncronousShaderCompilerAsyncTaskInterface : public ShaderCompilerAsyncTaskInterface
	{
	public:
		void enqueueTask(void (*callback)(void* userData), void* userData) final
		{
			callback(userData);
		}

		Error joinTasks() final
		{
			// Nothing
			return Error::kNone;
		}
	} syncTaskManager;

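	// Fall back to running tasks synchronously when the caller doesn't provide a task manager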
	ShaderCompilerAsyncTaskInterface& taskManager = (taskManager_) ? *taskManager_ : syncTaskManager;

	if(parser.getMutators().getSize() > 0)
	{
		// Initialize
		ShaderCompilerDynamicArray<MutatorValue> mutationValues;
		mutationValues.resize(parser.getMutators().getSize());
		ShaderCompilerDynamicArray<U32> dials;
		dials.resize(parser.getMutators().getSize(), 0);
		ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;
		ShaderCompilerDynamicArray<ShaderBinaryMutation> mutations;
		mutations.resize(mutationCount);
		ShaderCompilerHashMap<U64, U32> mutationHashToIdx;

		// Grow the storage of the variants array. Can't have it resize; threads would work on stale data
		variants.resizeStorage(mutationCount);

		mutationCount = 0;

		// Spin for all possible combinations of mutators and
		// - Create the spirv
		// - Populate the binary variant
		do
		{
			// Create the mutation
			for(U32 i = 0; i < parser.getMutators().getSize(); ++i)
			{
				mutationValues[i] = parser.getMutators()[i].m_values[dials[i]];
			}

			ShaderBinaryMutation& mutation = mutations[mutationCount++];
			newArray(memPool, mutationValues.getSize(), mutation.m_values);
			memcpy(mutation.m_values.getBegin(), mutationValues.getBegin(), mutationValues.getSizeInBytes());

			mutation.m_hash = computeHash(mutationValues.getBegin(), mutationValues.getSizeInBytes());
			ANKI_ASSERT(mutation.m_hash > 0);

			if(parser.skipMutation(mutationValues))
			{
				mutation.m_variantIndex = kMaxU32;
			}
			else
			{
				// New and unique mutation and thus variant, add it
				compileVariantAsync(parser, spirv, debugInfo, mutation, variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

				ANKI_ASSERT(mutationHashToIdx.find(mutation.m_hash) == mutationHashToIdx.getEnd());
				mutationHashToIdx.emplace(mutation.m_hash, mutationCount - 1);
			}
		} while(!spinDials(dials, parser.getMutators()));

		ANKI_ASSERT(mutationCount == mutations.getSize());

		// Done, wait for the threads
		ANKI_CHECK(taskManager.joinTasks());

		// Now error out
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		// Store temp containers to binary
		codeBlocks.moveAndReset(binary->m_codeBlocks);
		mutations.moveAndReset(binary->m_mutations);
		variants.moveAndReset(binary->m_variants);
	}
	else
	{
		newArray(memPool, 1, binary->m_mutations);

		ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;
		compileVariantAsync(parser, spirv, debugInfo, binary->m_mutations[0], variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

		ANKI_CHECK(taskManager.joinTasks());
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		ANKI_ASSERT(codeBlocks.getSize() >= parser.getTechniques().getSize());
		ANKI_ASSERT(binary->m_mutations[0].m_variantIndex == 0);
		ANKI_ASSERT(variants.getSize() == 1);
		binary->m_mutations[0].m_hash = 1;

		codeBlocks.moveAndReset(binary->m_codeBlocks);
		variants.moveAndReset(binary->m_variants);
	}

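	// Order the mutations by hash so consumers can look them up by hash and the binary layout stays deterministic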
	// Sort the mutations
	std::sort(binary->m_mutations.getBegin(), binary->m_mutations.getEnd(), [](const ShaderBinaryMutation& a, const ShaderBinaryMutation& b) {
		return a.m_hash < b.m_hash;
	});

	// Techniques
	newArray(memPool, parser.getTechniques().getSize(), binary->m_techniques);
	for(U32 i = 0; i < parser.getTechniques().getSize(); ++i)
	{
		zeroMemory(binary->m_techniques[i].m_name);
		memcpy(binary->m_techniques[i].m_name.getBegin(), parser.getTechniques()[i].m_name.cstr(), parser.getTechniques()[i].m_name.getLength() + 1);
		binary->m_techniques[i].m_shaderTypes = parser.getTechniques()[i].m_shaderTypes;
		binary->m_shaderTypes |= parser.getTechniques()[i].m_shaderTypes;
	}

	// Structs
	if(parser.getGhostStructs().getSize())
	{
		newArray(memPool, parser.getGhostStructs().getSize(), binary->m_structs);
	}

	for(U32 i = 0; i < parser.getGhostStructs().getSize(); ++i)
	{
		const ShaderParserGhostStruct& in = parser.getGhostStructs()[i];
		ShaderBinaryStruct& out = binary->m_structs[i];

		zeroMemory(out);
		memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

		ANKI_ASSERT(in.m_members.getSize());
		newArray(memPool, in.m_members.getSize(), out.m_members);

		for(U32 j = 0; j < in.m_members.getSize(); ++j)
		{
			const ShaderParserGhostStructMember& inm = in.m_members[j];
			ShaderBinaryStructMember& outm = out.m_members[j];

			zeroMemory(outm.m_name);
			memcpy(outm.m_name.getBegin(), inm.m_name.cstr(), inm.m_name.getLength() + 1);
			outm.m_offset = inm.m_offset;
			outm.m_type = inm.m_type;
		}

		out.m_size = in.m_members.getBack().m_offset + getShaderVariableDataTypeInfo(in.m_members.getBack().m_type).m_size;
	}

	return Error::kNone;
}

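/// Entry point of the compiler. On failure it logs the filename and frees whatever part of the binary was built.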
Error compileShaderProgram(CString fname, Bool spirv, Bool debugInfo, ShaderCompilerFilesystemInterface& fsystem,
						   ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager,
						   ConstWeakArray<ShaderCompilerDefine> defines, ShaderBinary*& binary)
{
	const Error err = compileShaderProgramInternal(fname, spirv, debugInfo, fsystem, postParseCallback, taskManager, defines, binary);
	if(err)
	{
		ANKI_SHADER_COMPILER_LOGE("Failed to compile: %s", fname.cstr());
		freeShaderBinary(binary);
	}

	return err;
}

} // end namespace anki