ShaderCompiler.cpp 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101
  1. // Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/ShaderCompiler/ShaderCompiler.h>
  6. #include <AnKi/ShaderCompiler/ShaderParser.h>
  7. #include <AnKi/ShaderCompiler/Dxc.h>
  8. #include <AnKi/Util/Serializer.h>
  9. #include <AnKi/Util/HashMap.h>
  10. #include <SpirvCross/spirv_cross.hpp>
  11. #define ANKI_DIXL_REFLECTION (ANKI_OS_WINDOWS)
  12. #if ANKI_DIXL_REFLECTION
  13. # include <windows.h>
  14. # include <ThirdParty/Dxc/dxcapi.h>
  15. # include <ThirdParty/Dxc/d3d12shader.h>
  16. # include <wrl.h>
  17. # include <AnKi/Util/CleanupWindows.h>
  18. #endif
  19. namespace anki {
  20. #if ANKI_DIXL_REFLECTION
  21. static HMODULE g_dxcLib = 0;
  22. static DxcCreateInstanceProc g_DxcCreateInstance = nullptr;
  23. static Mutex g_dxcLibMtx;
  24. #endif
  25. void freeShaderBinary(ShaderBinary*& binary)
  26. {
  27. if(binary == nullptr)
  28. {
  29. return;
  30. }
  31. BaseMemoryPool& mempool = ShaderCompilerMemoryPool::getSingleton();
  32. for(ShaderBinaryCodeBlock& code : binary->m_codeBlocks)
  33. {
  34. mempool.free(code.m_binary.getBegin());
  35. }
  36. mempool.free(binary->m_codeBlocks.getBegin());
  37. for(ShaderBinaryMutator& mutator : binary->m_mutators)
  38. {
  39. mempool.free(mutator.m_values.getBegin());
  40. }
  41. mempool.free(binary->m_mutators.getBegin());
  42. for(ShaderBinaryMutation& m : binary->m_mutations)
  43. {
  44. mempool.free(m.m_values.getBegin());
  45. }
  46. mempool.free(binary->m_mutations.getBegin());
  47. for(ShaderBinaryVariant& variant : binary->m_variants)
  48. {
  49. mempool.free(variant.m_techniqueCodeBlocks.getBegin());
  50. }
  51. mempool.free(binary->m_variants.getBegin());
  52. mempool.free(binary->m_techniques.getBegin());
  53. for(ShaderBinaryStruct& s : binary->m_structs)
  54. {
  55. mempool.free(s.m_members.getBegin());
  56. }
  57. mempool.free(binary->m_structs.getBegin());
  58. mempool.free(binary);
  59. binary = nullptr;
  60. }
  61. /// Spin the dials. Used to compute all mutator combinations.
  62. static Bool spinDials(ShaderCompilerDynamicArray<U32>& dials, ConstWeakArray<ShaderParserMutator> mutators)
  63. {
  64. ANKI_ASSERT(dials.getSize() == mutators.getSize() && dials.getSize() > 0);
  65. Bool done = true;
  66. U32 crntDial = dials.getSize() - 1;
  67. while(true)
  68. {
  69. // Turn dial
  70. ++dials[crntDial];
  71. if(dials[crntDial] >= mutators[crntDial].m_values.getSize())
  72. {
  73. if(crntDial == 0)
  74. {
  75. // Reached the 1st dial, stop spinning
  76. done = true;
  77. break;
  78. }
  79. else
  80. {
  81. dials[crntDial] = 0;
  82. --crntDial;
  83. }
  84. }
  85. else
  86. {
  87. done = false;
  88. break;
  89. }
  90. }
  91. return done;
  92. }
  93. /// Does SPIR-V reflection and re-writes the SPIR-V binary's bindings
  94. Error doReflectionSpirv(ConstWeakArray<U8> spirv, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
  95. {
  96. spirv_cross::Compiler spvc(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32));
  97. spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
  98. spirv_cross::ShaderResources rsrcActive = spvc.get_shader_resources(spvc.get_active_interface_variables());
  99. auto func = [&](const spirv_cross::SmallVector<spirv_cross::Resource>& resources, const DescriptorType origType) -> Error {
  100. for(const spirv_cross::Resource& r : resources)
  101. {
  102. const U32 id = r.id;
  103. const U32 set = spvc.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
  104. const U32 binding = spvc.get_decoration(id, spv::Decoration::DecorationBinding);
  105. if(set >= kMaxRegisterSpaces && set != kDxcVkBindlessRegisterSpace)
  106. {
  107. errorStr.sprintf("Exceeded set for: %s", r.name.c_str());
  108. return Error::kUserData;
  109. }
  110. const spirv_cross::SPIRType& typeInfo = spvc.get_type(r.type_id);
  111. U32 arraySize = 1;
  112. if(typeInfo.array.size() != 0)
  113. {
  114. if(typeInfo.array.size() != 1)
  115. {
  116. errorStr.sprintf("Only 1D arrays are supported: %s", r.name.c_str());
  117. return Error::kUserData;
  118. }
  119. if(set == kDxcVkBindlessRegisterSpace && typeInfo.array[0] != 0)
  120. {
  121. errorStr.sprintf("Only the bindless descriptor set can be an unbound array: %s", r.name.c_str());
  122. return Error::kUserData;
  123. }
  124. arraySize = typeInfo.array[0];
  125. }
  126. // Images are special, they might be texel buffers
  127. DescriptorType type = origType;
  128. if((type == DescriptorType::kSrvTexture || type == DescriptorType::kUavTexture) && typeInfo.image.dim == spv::DimBuffer)
  129. {
  130. if(typeInfo.image.sampled == 1)
  131. {
  132. type = DescriptorType::kSrvTexelBuffer;
  133. }
  134. else
  135. {
  136. ANKI_ASSERT(typeInfo.image.sampled == 2);
  137. type = DescriptorType::kUavTexelBuffer;
  138. }
  139. }
  140. if(set == kDxcVkBindlessRegisterSpace)
  141. {
  142. // Bindless dset
  143. if(arraySize != 0)
  144. {
  145. errorStr.sprintf("Unexpected unbound array for bindless: %s", r.name.c_str());
  146. }
  147. if(type != DescriptorType::kSrvTexture)
  148. {
  149. errorStr.sprintf("Unexpected bindless binding: %s", r.name.c_str());
  150. return Error::kUserData;
  151. }
  152. refl.m_descriptor.m_hasVkBindlessDescriptorSet = true;
  153. }
  154. else
  155. {
  156. // Regular binding
  157. // Use the binding to find out if it's a read or write storage buffer, there is no other way
  158. if(origType == DescriptorType::kSrvStructuredBuffer
  159. && (binding < kDxcVkBindingShifts[set][HlslResourceType::kSrv]
  160. || binding >= kDxcVkBindingShifts[set][HlslResourceType::kSrv] + 1000))
  161. {
  162. type = DescriptorType::kUavStructuredBuffer;
  163. }
  164. const HlslResourceType hlslResourceType = descriptorTypeToHlslResourceType(type);
  165. if(binding < kDxcVkBindingShifts[set][hlslResourceType] || binding >= kDxcVkBindingShifts[set][hlslResourceType] + 1000)
  166. {
  167. errorStr.sprintf("Unexpected binding: %s", r.name.c_str());
  168. return Error::kUserData;
  169. }
  170. ShaderReflectionBinding akBinding;
  171. akBinding.m_registerBindingPoint = binding - kDxcVkBindingShifts[set][hlslResourceType];
  172. akBinding.m_arraySize = U16(arraySize);
  173. akBinding.m_type = type;
  174. refl.m_descriptor.m_bindings[set][refl.m_descriptor.m_bindingCounts[set]] = akBinding;
  175. ++refl.m_descriptor.m_bindingCounts[set];
  176. }
  177. }
  178. return Error::kNone;
  179. };
  180. Error err = func(rsrc.uniform_buffers, DescriptorType::kConstantBuffer);
  181. if(err)
  182. {
  183. return err;
  184. }
  185. err = func(rsrc.separate_images, DescriptorType::kSrvTexture); // This also handles texel buffers
  186. if(err)
  187. {
  188. return err;
  189. }
  190. err = func(rsrc.separate_samplers, DescriptorType::kSampler);
  191. if(err)
  192. {
  193. return err;
  194. }
  195. err = func(rsrc.storage_buffers, DescriptorType::kSrvStructuredBuffer);
  196. if(err)
  197. {
  198. return err;
  199. }
  200. err = func(rsrc.storage_images, DescriptorType::kUavTexture);
  201. if(err)
  202. {
  203. return err;
  204. }
  205. err = func(rsrc.acceleration_structures, DescriptorType::kAccelerationStructure);
  206. if(err)
  207. {
  208. return err;
  209. }
  210. for(U32 i = 0; i < kMaxRegisterSpaces; ++i)
  211. {
  212. std::sort(refl.m_descriptor.m_bindings[i].getBegin(), refl.m_descriptor.m_bindings[i].getBegin() + refl.m_descriptor.m_bindingCounts[i]);
  213. }
  214. // Color attachments
  215. if(type == ShaderType::kPixel)
  216. {
  217. for(const spirv_cross::Resource& r : rsrc.stage_outputs)
  218. {
  219. const U32 id = r.id;
  220. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  221. refl.m_pixel.m_colorRenderTargetWritemask.set(location);
  222. }
  223. }
  224. // Push consts
  225. if(rsrc.push_constant_buffers.size() == 1)
  226. {
  227. const U32 blockSize = U32(spvc.get_declared_struct_size(spvc.get_type(rsrc.push_constant_buffers[0].base_type_id)));
  228. if(blockSize == 0 || (blockSize % 16) != 0 || blockSize > kMaxFastConstantsSize)
  229. {
  230. errorStr.sprintf("Incorrect push constants size");
  231. return Error::kUserData;
  232. }
  233. refl.m_descriptor.m_fastConstantsSize = blockSize;
  234. }
  235. // Attribs
  236. if(type == ShaderType::kVertex)
  237. {
  238. for(const spirv_cross::Resource& r : rsrcActive.stage_inputs)
  239. {
  240. VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
  241. #define ANKI_ATTRIB_NAME(x) "in.var." #x
  242. if(r.name == ANKI_ATTRIB_NAME(POSITION))
  243. {
  244. a = VertexAttributeSemantic::kPosition;
  245. }
  246. else if(r.name == ANKI_ATTRIB_NAME(NORMAL))
  247. {
  248. a = VertexAttributeSemantic::kNormal;
  249. }
  250. else if(r.name == ANKI_ATTRIB_NAME(TEXCOORD0) || r.name == ANKI_ATTRIB_NAME(TEXCOORD))
  251. {
  252. a = VertexAttributeSemantic::kTexCoord;
  253. }
  254. else if(r.name == ANKI_ATTRIB_NAME(COLOR))
  255. {
  256. a = VertexAttributeSemantic::kColor;
  257. }
  258. else if(r.name == ANKI_ATTRIB_NAME(MISC0) || r.name == ANKI_ATTRIB_NAME(MISC))
  259. {
  260. a = VertexAttributeSemantic::kMisc0;
  261. }
  262. else if(r.name == ANKI_ATTRIB_NAME(MISC1))
  263. {
  264. a = VertexAttributeSemantic::kMisc1;
  265. }
  266. else if(r.name == ANKI_ATTRIB_NAME(MISC2))
  267. {
  268. a = VertexAttributeSemantic::kMisc2;
  269. }
  270. else if(r.name == ANKI_ATTRIB_NAME(MISC3))
  271. {
  272. a = VertexAttributeSemantic::kMisc3;
  273. }
  274. else
  275. {
  276. errorStr.sprintf("Unexpected attribute name: %s", r.name.c_str());
  277. return Error::kUserData;
  278. }
  279. #undef ANKI_ATTRIB_NAME
  280. refl.m_vertex.m_vertexAttributeMask.set(a);
  281. const U32 id = r.id;
  282. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  283. if(location > kMaxU8)
  284. {
  285. errorStr.sprintf("Too high location value for attribute: %s", r.name.c_str());
  286. return Error::kUserData;
  287. }
  288. refl.m_vertex.m_vkVertexAttributeLocations[a] = U8(location);
  289. }
  290. }
  291. // Discards?
  292. if(type == ShaderType::kPixel)
  293. {
  294. visitSpirv(ConstWeakArray<U32>(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32)),
  295. [&](U32 cmd, [[maybe_unused]] ConstWeakArray<U32> instructions) {
  296. if(cmd == spv::OpKill)
  297. {
  298. refl.m_pixel.m_discards = true;
  299. }
  300. });
  301. }
  302. return Error::kNone;
  303. }
  304. #if ANKI_DIXL_REFLECTION
  305. # define ANKI_REFL_CHECK(x) \
  306. do \
  307. { \
  308. HRESULT rez; \
  309. if((rez = (x)) < 0) [[unlikely]] \
  310. { \
  311. errorStr.sprintf("DXC function failed (HRESULT: %d): %s", rez, #x); \
  312. return Error::kFunctionFailed; \
  313. } \
  314. } while(0)
/// Does reflection on a DXIL binary using the D3D12/DXC reflection COM interfaces. Windows only.
/// @param dxil The DXIL binary to reflect.
/// @param type The shader stage. Ray tracing stages and work graphs are reflected as libraries.
/// @param refl [out] Receives the reflection data (bindings, fast constants, attributes, color writemask).
/// @param errorStr [out] Human readable error message on failure.
Error doReflectionDxil(ConstWeakArray<U8> dxil, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
{
	using Microsoft::WRL::ComPtr;

	// Lazyly load the DXC DLL. Guarded by a mutex because multiple compilation tasks may reach this concurrently
	{
		LockGuard lock(g_dxcLibMtx);
		if(g_dxcLib == 0)
		{
			// Init DXC
			g_dxcLib = LoadLibraryA(ANKI_SOURCE_DIRECTORY "/ThirdParty/Bin/Windows64/dxcompiler.dll");
			if(g_dxcLib == 0)
			{
				ANKI_SHADER_COMPILER_LOGE("dxcompiler.dll missing or wrong architecture");
				return Error::kFunctionFailed;
			}

			g_DxcCreateInstance = reinterpret_cast<DxcCreateInstanceProc>(GetProcAddress(g_dxcLib, "DxcCreateInstance"));
			if(g_DxcCreateInstance == nullptr)
			{
				ANKI_SHADER_COMPILER_LOGE("DxcCreateInstance was not found in the dxcompiler.dll");
				return Error::kFunctionFailed;
			}
		}
	}

	// Libraries (ray tracing, work graphs) are reflected via ID3D12LibraryReflection, other stages via ID3D12ShaderReflection
	const Bool isLib = (type >= ShaderType::kFirstRayTracing && type <= ShaderType::kLastRayTracing) || type == ShaderType::kWorkGraph;

	ComPtr<IDxcUtils> utils;
	ANKI_REFL_CHECK(g_DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(&utils)));

	ComPtr<ID3D12ShaderReflection> dxRefl;
	ComPtr<ID3D12LibraryReflection> libRefl;
	ShaderCompilerDynamicArray<ID3D12FunctionReflection*> funcReflections;
	D3D12_SHADER_DESC shaderDesc = {};
	if(isLib)
	{
		// A library exposes one function-reflection object per exported function; gather them all
		const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
		ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&libRefl)));

		D3D12_LIBRARY_DESC libDesc = {};
		libRefl->GetDesc(&libDesc);
		if(libDesc.FunctionCount == 0)
		{
			errorStr.sprintf("Expecting at least 1 in D3D12_LIBRARY_DESC::FunctionCount");
			return Error::kUserData;
		}

		funcReflections.resize(libDesc.FunctionCount);
		for(U32 i = 0; i < libDesc.FunctionCount; ++i)
		{
			funcReflections[i] = libRefl->GetFunctionByIndex(i);
		}
	}
	else
	{
		const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
		ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&dxRefl)));
		ANKI_REFL_CHECK(dxRefl->GetDesc(&shaderDesc));
	}

	// Gather the resource bindings of every library function (or of the single shader)
	for(U32 ifunc = 0; ifunc < ((isLib) ? funcReflections.getSize() : 1); ++ifunc)
	{
		U32 bindingCount;
		if(isLib)
		{
			D3D12_FUNCTION_DESC funcDesc;
			ANKI_REFL_CHECK(funcReflections[ifunc]->GetDesc(&funcDesc));
			bindingCount = funcDesc.BoundResources;
		}
		else
		{
			bindingCount = shaderDesc.BoundResources;
		}

		for(U32 i = 0; i < bindingCount; ++i)
		{
			D3D12_SHADER_INPUT_BIND_DESC bindDesc;
			if(isLib)
			{
				ANKI_REFL_CHECK(funcReflections[ifunc]->GetResourceBindingDesc(i, &bindDesc));
			}
			else
			{
				ANKI_REFL_CHECK(dxRefl->GetResourceBindingDesc(i, &bindDesc));
			}

			// Translate the D3D input type to the engine's DescriptorType
			ShaderReflectionBinding akBinding;
			akBinding.m_type = DescriptorType::kCount;
			akBinding.m_arraySize = U16(bindDesc.BindCount);
			akBinding.m_registerBindingPoint = bindDesc.BindPoint;

			if(bindDesc.Type == D3D_SIT_CBUFFER)
			{
				// ConstantBuffer
				if(bindDesc.Space == 3000 && bindDesc.BindPoint == 0)
				{
					// It's push/root constants. NOTE(review): space 3000, slot 0 appears to be reserved by the
					// compiler for fast constants — confirm against the DXC invocation in Dxc.cpp
					ID3D12ShaderReflectionConstantBuffer* cbuffer =
						(isLib) ? funcReflections[ifunc]->GetConstantBufferByName(bindDesc.Name) : dxRefl->GetConstantBufferByName(bindDesc.Name);
					D3D12_SHADER_BUFFER_DESC desc;
					ANKI_REFL_CHECK(cbuffer->GetDesc(&desc));
					refl.m_descriptor.m_fastConstantsSize = desc.Size;
					continue; // Fast constants are not a regular binding
				}

				akBinding.m_type = DescriptorType::kConstantBuffer;
			}
			else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
			{
				// Texture2D etc
				akBinding.m_type = DescriptorType::kSrvTexture;
			}
			else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
			{
				// Buffer
				akBinding.m_type = DescriptorType::kSrvTexelBuffer;
			}
			else if(bindDesc.Type == D3D_SIT_SAMPLER)
			{
				// SamplerState
				akBinding.m_type = DescriptorType::kSampler;
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
			{
				// RWBuffer
				akBinding.m_type = DescriptorType::kUavTexelBuffer;
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
			{
				// RWTexture2D etc
				akBinding.m_type = DescriptorType::kUavTexture;
			}
			else if(bindDesc.Type == D3D_SIT_BYTEADDRESS)
			{
				// ByteAddressBuffer. Byte-address buffers are addressed in 32bit units, hence the U32 stride
				akBinding.m_type = DescriptorType::kSrvByteAddressBuffer;
				akBinding.m_d3dStructuredBufferStride = sizeof(U32);
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWBYTEADDRESS)
			{
				// RWByteAddressBuffer
				akBinding.m_type = DescriptorType::kUavByteAddressBuffer;
				akBinding.m_d3dStructuredBufferStride = sizeof(U32);
			}
			else if(bindDesc.Type == D3D_SIT_RTACCELERATIONSTRUCTURE)
			{
				// RaytracingAccelerationStructure
				akBinding.m_type = DescriptorType::kAccelerationStructure;
			}
			else if(bindDesc.Type == D3D_SIT_STRUCTURED)
			{
				// StructuredBuffer. D3D reflection reports the element stride in NumSamples for structured buffers
				akBinding.m_type = DescriptorType::kSrvStructuredBuffer;
				akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWSTRUCTURED)
			{
				// RWStructuredBuffer
				akBinding.m_type = DescriptorType::kUavStructuredBuffer;
				akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
			}
			else
			{
				errorStr.sprintf("Unrecognized type for binding: %s", bindDesc.Name);
				return Error::kUserData;
			}

			Bool skip = false;
			if(isLib)
			{
				// Search if the binding exists because it may repeat (several library functions can use the same one)
				for(U32 i = 0; i < refl.m_descriptor.m_bindingCounts[bindDesc.Space]; ++i)
				{
					if(refl.m_descriptor.m_bindings[bindDesc.Space][i] == akBinding)
					{
						skip = true;
						break;
					}
				}
			}

			if(!skip)
			{
				refl.m_descriptor.m_bindings[bindDesc.Space][refl.m_descriptor.m_bindingCounts[bindDesc.Space]] = akBinding;
				++refl.m_descriptor.m_bindingCounts[bindDesc.Space];
			}
		}
	}

	// Keep the bindings of every register space sorted
	for(U32 i = 0; i < kMaxRegisterSpaces; ++i)
	{
		std::sort(refl.m_descriptor.m_bindings[i].getBegin(), refl.m_descriptor.m_bindings[i].getBegin() + refl.m_descriptor.m_bindingCounts[i]);
	}

	// Vertex attributes. Vertex shaders are never libraries so dxRefl/shaderDesc are valid here
	if(type == ShaderType::kVertex)
	{
		for(U32 i = 0; i < shaderDesc.InputParameters; ++i)
		{
			D3D12_SIGNATURE_PARAMETER_DESC in;
			ANKI_REFL_CHECK(dxRefl->GetInputParameterDesc(i, &in));

			VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
// Match on HLSL semantic name + index
#	define ANKI_ATTRIB_NAME(x, idx) CString(in.SemanticName) == #x && in.SemanticIndex == idx
			if(ANKI_ATTRIB_NAME(POSITION, 0))
			{
				a = VertexAttributeSemantic::kPosition;
			}
			else if(ANKI_ATTRIB_NAME(NORMAL, 0))
			{
				a = VertexAttributeSemantic::kNormal;
			}
			else if(ANKI_ATTRIB_NAME(TEXCOORD, 0))
			{
				a = VertexAttributeSemantic::kTexCoord;
			}
			else if(ANKI_ATTRIB_NAME(COLOR, 0))
			{
				a = VertexAttributeSemantic::kColor;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 0))
			{
				a = VertexAttributeSemantic::kMisc0;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 1))
			{
				a = VertexAttributeSemantic::kMisc1;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 2))
			{
				a = VertexAttributeSemantic::kMisc2;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 3))
			{
				a = VertexAttributeSemantic::kMisc3;
			}
			else if(ANKI_ATTRIB_NAME(SV_VERTEXID, 0) || ANKI_ATTRIB_NAME(SV_INSTANCEID, 0))
			{
				// Ignore: system-generated values, not vertex buffer attributes
				continue;
			}
			else
			{
				errorStr.sprintf("Unexpected attribute name: %s", in.SemanticName);
				return Error::kUserData;
			}
#	undef ANKI_ATTRIB_NAME

			refl.m_vertex.m_vertexAttributeMask.set(a);
			refl.m_vertex.m_vkVertexAttributeLocations[a] = U8(i); // Just set something
		}
	}

	// Color render targets written by the pixel shader
	if(type == ShaderType::kPixel)
	{
		for(U32 i = 0; i < shaderDesc.OutputParameters; ++i)
		{
			D3D12_SIGNATURE_PARAMETER_DESC desc;
			ANKI_REFL_CHECK(dxRefl->GetOutputParameterDesc(i, &desc));

			if(CString(desc.SemanticName) == "SV_TARGET")
			{
				refl.m_pixel.m_colorRenderTargetWritemask.set(desc.SemanticIndex);
			}
		}
	}

	return Error::kNone;
}
  563. #endif // #if ANKI_DIXL_REFLECTION
/// Enqueues an async task that compiles one mutation of the shader program. The task generates the final source for
/// every technique/shader-stage pair, compiles it to SPIR-V or DXIL, runs reflection, and de-duplicates both the code
/// blocks (first by source hash, then by binary hash) and the resulting variant.
/// @param parser The parsed program. Must stay alive until the task has run.
/// @param spirv Compile to SPIR-V when true, otherwise to DXIL.
/// @param debugInfo Embed debug info in the compiled code.
/// @param mutation The mutation to compile. Its m_variantIndex is written by the task.
/// @param variants [in,out] Shared array of variants. Protected by mtx.
/// @param codeBlocks [in,out] Shared array of compiled code blocks. Protected by mtx.
/// @param sourceCodeHashes [in,out] Kept parallel to codeBlocks for source-level de-duplication. Protected by mtx.
/// @param taskManager Receives the work item.
/// @param mtx Guards variants, codeBlocks and sourceCodeHashes.
/// @param error Records the first error produced by any task (compare-exchanged below).
static void compileVariantAsync(const ShaderParser& parser, Bool spirv, Bool debugInfo, ShaderBinaryMutation& mutation,
								ShaderCompilerDynamicArray<ShaderBinaryVariant>& variants,
								ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>& codeBlocks, ShaderCompilerDynamicArray<U64>& sourceCodeHashes,
								ShaderCompilerAsyncTaskInterface& taskManager, Mutex& mtx, Atomic<I32>& error)
{
	// Heap-allocated context that carries the arguments into the task. Freed by the task itself (see Cleanup)
	class Ctx
	{
	public:
		const ShaderParser* m_parser;
		ShaderBinaryMutation* m_mutation;
		ShaderCompilerDynamicArray<ShaderBinaryVariant>* m_variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>* m_codeBlocks;
		ShaderCompilerDynamicArray<U64>* m_sourceCodeHashes;
		Mutex* m_mtx;
		Atomic<I32>* m_err;
		Bool m_spirv;
		Bool m_debugInfo;
	};

	Ctx* ctx = newInstance<Ctx>(ShaderCompilerMemoryPool::getSingleton());
	ctx->m_parser = &parser;
	ctx->m_mutation = &mutation;
	ctx->m_variants = &variants;
	ctx->m_codeBlocks = &codeBlocks;
	ctx->m_sourceCodeHashes = &sourceCodeHashes;
	ctx->m_mtx = &mtx;
	ctx->m_err = &error;
	ctx->m_spirv = spirv;
	ctx->m_debugInfo = debugInfo;

	auto callback = [](void* userData) {
		Ctx& ctx = *static_cast<Ctx*>(userData);

		// RAII guard: delete the context on every exit path of the task
		class Cleanup
		{
		public:
			Ctx* m_ctx;

			~Cleanup()
			{
				deleteInstance(ShaderCompilerMemoryPool::getSingleton(), m_ctx);
			}
		} cleanup{&ctx};

		if(ctx.m_err->load() != 0)
		{
			// Don't bother, another task already failed
			return;
		}

		const U32 techniqueCount = ctx.m_parser->getTechniques().getSize();

		// Compile the sources. codeBlockIndices[technique][stage] maps to an index into m_codeBlocks
		ShaderCompilerDynamicArray<ShaderBinaryTechniqueCodeBlocks> codeBlockIndices;
		codeBlockIndices.resize(techniqueCount);
		for(auto& it : codeBlockIndices)
		{
			it.m_codeBlockIndices.fill(kMaxU32);
		}

		ShaderCompilerString compilerErrorLog;
		Error err = Error::kNone;
		U newCodeBlockCount = 0;
		for(U32 t = 0; t < techniqueCount && !err; ++t)
		{
			const ShaderParserTechnique& technique = ctx.m_parser->getTechniques()[t];
			for(ShaderType shaderType : EnumBitsIterable<ShaderType, ShaderTypeBit>(technique.m_shaderTypes))
			{
				ShaderCompilerString source;
				ctx.m_parser->generateVariant(ctx.m_mutation->m_values, technique, shaderType, source);

				// Check if the source code was found before; if so reuse its code block and skip compilation.
				// NOTE(review): the lookup only runs when the stage has active mutators (m_activeMutators != kMaxU64)
				const U64 sourceCodeHash = source.computeHash();
				if(technique.m_activeMutators[shaderType] != kMaxU64)
				{
					LockGuard lock(*ctx.m_mtx);
					ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

					Bool found = false;
					for(U32 i = 0; i < ctx.m_sourceCodeHashes->getSize(); ++i)
					{
						if((*ctx.m_sourceCodeHashes)[i] == sourceCodeHash)
						{
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = i;
							found = true;
							break;
						}
					}

					if(found)
					{
						continue;
					}
				}

				// Compile to the requested intermediate language (IL)
				ShaderCompilerDynamicArray<U8> il;
				if(ctx.m_spirv)
				{
					err = compileHlslToSpirv(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo,
											 ctx.m_parser->getExtraCompilerArgs(), il, compilerErrorLog);
				}
				else
				{
					err = compileHlslToDxil(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo,
											ctx.m_parser->getExtraCompilerArgs(), il, compilerErrorLog);
				}
				if(err)
				{
					break;
				}

				const U64 newHash = computeHash(il.getBegin(), il.getSizeInBytes());

				ShaderReflection refl;
				if(ctx.m_spirv)
				{
					err = doReflectionSpirv(il, shaderType, refl, compilerErrorLog);
				}
				else
				{
#if ANKI_DIXL_REFLECTION
					err = doReflectionDxil(il, shaderType, refl, compilerErrorLog);
#else
					ANKI_SHADER_COMPILER_LOGE("Can't generate shader compilation on non-windows platforms");
					err = Error::kFunctionFailed;
#endif
				}
				if(err)
				{
					break;
				}

				// Add the binary if not already there (a second level of de-duplication, by compiled-binary hash)
				{
					LockGuard lock(*ctx.m_mtx);
					Bool found = false;
					for(U32 j = 0; j < ctx.m_codeBlocks->getSize(); ++j)
					{
						if((*ctx.m_codeBlocks)[j].m_hash == newHash)
						{
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = j;
							found = true;
							break;
						}
					}

					if(!found)
					{
						codeBlockIndices[t].m_codeBlockIndices[shaderType] = ctx.m_codeBlocks->getSize();

						auto& codeBlock = *ctx.m_codeBlocks->emplaceBack();
						il.moveAndReset(codeBlock.m_binary); // Transfer ownership of the IL into the code block
						codeBlock.m_hash = newHash;
						codeBlock.m_reflection = refl;

						// Keep sourceCodeHashes parallel to codeBlocks
						ctx.m_sourceCodeHashes->emplaceBack(sourceCodeHash);
						ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

						++newCodeBlockCount;
					}
				}
			}
		}

		if(err)
		{
			// Only the first failing task gets to log its error
			I32 expectedErr = 0;
			const Bool isFirstError = ctx.m_err->compareExchange(expectedErr, err._getCode());
			if(isFirstError)
			{
				ANKI_SHADER_COMPILER_LOGE("Shader compilation failed:\n%s", compilerErrorLog.cstr());
				return;
			}

			return;
		}

		// Do variant stuff
		{
			LockGuard lock(*ctx.m_mtx);

			Bool createVariant = true;
			if(newCodeBlockCount == 0)
			{
				// No new code blocks generated, search all variants to see if there is a duplicate
				for(U32 i = 0; i < ctx.m_variants->getSize(); ++i)
				{
					Bool same = true;
					for(U32 t = 0; t < techniqueCount; ++t)
					{
						const ShaderBinaryTechniqueCodeBlocks& a = (*ctx.m_variants)[i].m_techniqueCodeBlocks[t];
						const ShaderBinaryTechniqueCodeBlocks& b = codeBlockIndices[t];
						if(memcmp(&a, &b, sizeof(a)) != 0)
						{
							// Not the same
							same = false;
							break;
						}
					}

					if(same)
					{
						createVariant = false;
						ctx.m_mutation->m_variantIndex = i;
						break;
					}
				}
			}

			// Create a new variant
			if(createVariant)
			{
				ctx.m_mutation->m_variantIndex = ctx.m_variants->getSize();
				ShaderBinaryVariant* variant = ctx.m_variants->emplaceBack();
				codeBlockIndices.moveAndReset(variant->m_techniqueCodeBlocks);
			}
		}
	};

	taskManager.enqueueTask(callback, ctx);
}
  759. static Error compileShaderProgramInternal(CString fname, Bool spirv, Bool debugInfo, ShaderCompilerFilesystemInterface& fsystem,
ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager_,
ConstWeakArray<ShaderCompilerDefine> defines_, ShaderBinary*& binary)
{
	ShaderCompilerMemoryPool& memPool = ShaderCompilerMemoryPool::getSingleton();

	// Copy the caller's defines into a pool-backed array (the parser takes this container type)
	ShaderCompilerDynamicArray<ShaderCompilerDefine> defines;
	for(const ShaderCompilerDefine& d : defines_)
	{
		defines.emplaceBack(d);
	}

	// Initialize the binary. Allocated from the pool; on failure the public wrapper frees it via freeShaderBinary()
	binary = newInstance<ShaderBinary>(memPool);
	memcpy(&binary->m_magic[0], kShaderBinaryMagic, 8);

	// Parse source
	ShaderParser parser(fname, &fsystem, defines);
	ANKI_CHECK(parser.parse());

	// Let the caller short-circuit the whole compilation based on the source hash (eg a cached binary is current).
	// NOTE(review): in that case the returned binary contains only the magic — callers presumably ignore it; confirm
	if(postParseCallback && postParseCallback->skipCompilation(parser.getHash()))
	{
		return Error::kNone;
	}

	// Get mutators. mutationCount ends up as the product of all mutators' value counts (every combination)
	U32 mutationCount = 0;
	if(parser.getMutators().getSize() > 0)
	{
		newArray(memPool, parser.getMutators().getSize(), binary->m_mutators);
		for(U32 i = 0; i < binary->m_mutators.getSize(); ++i)
		{
			ShaderBinaryMutator& out = binary->m_mutators[i];
			const ShaderParserMutator& in = parser.getMutators()[i];

			zeroMemory(out);

			newArray(memPool, in.m_values.getSize(), out.m_values);
			memcpy(out.m_values.getBegin(), in.m_values.getBegin(), in.m_values.getSizeInBytes());

			// +1 to copy the NULL terminator as well
			memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

			// Update the count
			mutationCount = (i == 0) ? out.m_values.getSize() : mutationCount * out.m_values.getSize();
		}
	}
	else
	{
		ANKI_ASSERT(binary->m_mutators.getSize() == 0);
	}

	// Create all variants
	Mutex mtx;
	Atomic<I32> errorAtomic(0);

	// Fallback task manager that executes every task inline on the calling thread, used when the caller passes none
	class SyncronousShaderCompilerAsyncTaskInterface : public ShaderCompilerAsyncTaskInterface
	{
	public:
		void enqueueTask(void (*callback)(void* userData), void* userData) final
		{
			// Run immediately, no queue
			callback(userData);
		}

		Error joinTasks() final
		{
			// Nothing to wait for, everything already ran synchronously
			return Error::kNone;
		}
	} syncTaskManager;
	ShaderCompilerAsyncTaskInterface& taskManager = (taskManager_) ? *taskManager_ : syncTaskManager;

	if(parser.getMutators().getSize() > 0)
	{
		// Initialize
		ShaderCompilerDynamicArray<MutatorValue> mutationValues;
		mutationValues.resize(parser.getMutators().getSize());
		ShaderCompilerDynamicArray<U32> dials; // Odometer-style counters, one digit per mutator
		dials.resize(parser.getMutators().getSize(), 0);
		ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;
		ShaderCompilerDynamicArray<ShaderBinaryMutation> mutations;
		mutations.resize(mutationCount);
		ShaderCompilerHashMap<U64, U32> mutationHashToIdx;

		// Grow the storage of the variants array. Can't have it resize, threads will work on stale data
		variants.resizeStorage(mutationCount);

		// Reused below as the running index into "mutations"
		mutationCount = 0;

		// Spin for all possible combinations of mutators and
		// - Create the spirv
		// - Populate the binary variant
		do
		{
			// Create the mutation
			for(U32 i = 0; i < parser.getMutators().getSize(); ++i)
			{
				mutationValues[i] = parser.getMutators()[i].m_values[dials[i]];
			}

			ShaderBinaryMutation& mutation = mutations[mutationCount++];
			newArray(memPool, mutationValues.getSize(), mutation.m_values);
			memcpy(mutation.m_values.getBegin(), mutationValues.getBegin(), mutationValues.getSizeInBytes());

			mutation.m_hash = computeHash(mutationValues.getBegin(), mutationValues.getSizeInBytes());
			ANKI_ASSERT(mutation.m_hash > 0);

			if(parser.skipMutation(mutationValues))
			{
				// Combination rejected by the source's mutation rules: record it with no backing variant
				mutation.m_variantIndex = kMaxU32;
			}
			else
			{
				// New and unique mutation and thus variant, add it
				compileVariantAsync(parser, spirv, debugInfo, mutation, variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

				ANKI_ASSERT(mutationHashToIdx.find(mutation.m_hash) == mutationHashToIdx.getEnd());
				mutationHashToIdx.emplace(mutation.m_hash, mutationCount - 1);
			}
		} while(!spinDials(dials, parser.getMutators()));

		ANKI_ASSERT(mutationCount == mutations.getSize());

		// Done, wait the threads
		ANKI_CHECK(taskManager.joinTasks());

		// Now error out. Tasks report failures through errorAtomic instead of returning errors directly
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		// Store temp containers to binary
		codeBlocks.moveAndReset(binary->m_codeBlocks);
		mutations.moveAndReset(binary->m_mutations);
		variants.moveAndReset(binary->m_variants);
	}
	else
	{
		// No mutators: exactly one (empty) mutation and one variant
		newArray(memPool, 1, binary->m_mutations);
		ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;

		compileVariantAsync(parser, spirv, debugInfo, binary->m_mutations[0], variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

		ANKI_CHECK(taskManager.joinTasks());
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		ANKI_ASSERT(codeBlocks.getSize() >= parser.getTechniques().getSize());
		ANKI_ASSERT(binary->m_mutations[0].m_variantIndex == 0);
		ANKI_ASSERT(variants.getSize() == 1);
		binary->m_mutations[0].m_hash = 1; // Arbitrary non-zero hash for the single mutation

		codeBlocks.moveAndReset(binary->m_codeBlocks);
		variants.moveAndReset(binary->m_variants);
	}

	// Sort the mutations by hash (enables hash-based binary search at load time)
	std::sort(binary->m_mutations.getBegin(), binary->m_mutations.getEnd(), [](const ShaderBinaryMutation& a, const ShaderBinaryMutation& b) {
		return a.m_hash < b.m_hash;
	});

	// Techniques
	newArray(memPool, parser.getTechniques().getSize(), binary->m_techniques);
	for(U32 i = 0; i < parser.getTechniques().getSize(); ++i)
	{
		zeroMemory(binary->m_techniques[i].m_name);
		// +1 to copy the NULL terminator as well
		memcpy(binary->m_techniques[i].m_name.getBegin(), parser.getTechniques()[i].m_name.cstr(), parser.getTechniques()[i].m_name.getLength() + 1);
		binary->m_techniques[i].m_shaderTypes = parser.getTechniques()[i].m_shaderTypes;
		// Accumulate the union of all techniques' shader stages
		binary->m_shaderTypes |= parser.getTechniques()[i].m_shaderTypes;
	}

	// Structs
	if(parser.getGhostStructs().getSize())
	{
		newArray(memPool, parser.getGhostStructs().getSize(), binary->m_structs);
	}
	for(U32 i = 0; i < parser.getGhostStructs().getSize(); ++i)
	{
		const ShaderParserGhostStruct& in = parser.getGhostStructs()[i];
		ShaderBinaryStruct& out = binary->m_structs[i];
		zeroMemory(out);

		memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

		ANKI_ASSERT(in.m_members.getSize());
		newArray(memPool, in.m_members.getSize(), out.m_members);
		for(U32 j = 0; j < in.m_members.getSize(); ++j)
		{
			const ShaderParserGhostStructMember& inm = in.m_members[j];
			ShaderBinaryStructMember& outm = out.m_members[j];

			zeroMemory(outm.m_name);
			memcpy(outm.m_name.getBegin(), inm.m_name.cstr(), inm.m_name.getLength() + 1);
			outm.m_offset = inm.m_offset;
			outm.m_type = inm.m_type;
		}

		// Struct size = last member's offset plus that member's type size (members appear to be offset-ordered)
		out.m_size = in.m_members.getBack().m_offset + getShaderVariableDataTypeInfo(in.m_members.getBack().m_type).m_size;
	}

	return Error::kNone;
}
  925. Error compileShaderProgram(CString fname, Bool spirv, Bool debugInfo, ShaderCompilerFilesystemInterface& fsystem,
  926. ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager,
  927. ConstWeakArray<ShaderCompilerDefine> defines, ShaderBinary*& binary)
  928. {
  929. const Error err = compileShaderProgramInternal(fname, spirv, debugInfo, fsystem, postParseCallback, taskManager, defines, binary);
  930. if(err)
  931. {
  932. ANKI_SHADER_COMPILER_LOGE("Failed to compile: %s", fname.cstr());
  933. freeShaderBinary(binary);
  934. }
  935. return err;
  936. }
  937. } // end namespace anki