ShaderCompiler.cpp

// Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
// All rights reserved.
// Code licensed under the BSD License.
// http://www.anki3d.org/LICENSE

#include <AnKi/ShaderCompiler/ShaderCompiler.h>
#include <AnKi/ShaderCompiler/ShaderParser.h>
#include <AnKi/ShaderCompiler/Dxc.h>
#include <AnKi/Util/Serializer.h>
#include <AnKi/Util/HashMap.h>
#include <SpirvCross/spirv_cross.hpp>

#define ANKI_DIXL_REFLECTION (ANKI_OS_WINDOWS)

#if ANKI_DIXL_REFLECTION
#	include <windows.h>
#	include <ThirdParty/Dxc/dxcapi.h>
#	include <ThirdParty/Dxc/d3d12shader.h>
#	include <wrl.h>
#	include <AnKi/Util/CleanupWindows.h>
#endif

namespace anki {

#if ANKI_DIXL_REFLECTION
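// State for the lazily loaded dxcompiler.dll. The DLL is loaded on first use inside doReflectionDxil() and is never explicitly unloaded.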
static HMODULE g_dxcLib = 0;
static DxcCreateInstanceProc g_DxcCreateInstance = nullptr;
static Mutex g_dxcLibMtx;
#endif
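
/// Frees a ShaderBinary and every nested array it owns. Everything was allocated from the ShaderCompilerMemoryPool singleton, so the nested arrays
/// are released one by one before the binary itself, and the caller's pointer is nulled at the end.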
void freeShaderBinary(ShaderBinary*& binary)
{
	if(binary == nullptr)
	{
		return;
	}

	BaseMemoryPool& mempool = ShaderCompilerMemoryPool::getSingleton();

	for(ShaderBinaryCodeBlock& code : binary->m_codeBlocks)
	{
		mempool.free(code.m_binary.getBegin());
	}
	mempool.free(binary->m_codeBlocks.getBegin());

	for(ShaderBinaryMutator& mutator : binary->m_mutators)
	{
		mempool.free(mutator.m_values.getBegin());
	}
	mempool.free(binary->m_mutators.getBegin());

	for(ShaderBinaryMutation& m : binary->m_mutations)
	{
		mempool.free(m.m_values.getBegin());
	}
	mempool.free(binary->m_mutations.getBegin());

	for(ShaderBinaryVariant& variant : binary->m_variants)
	{
		mempool.free(variant.m_techniqueCodeBlocks.getBegin());
	}
	mempool.free(binary->m_variants.getBegin());

	mempool.free(binary->m_techniques.getBegin());

	for(ShaderBinaryStruct& s : binary->m_structs)
	{
		mempool.free(s.m_members.getBegin());
	}
	mempool.free(binary->m_structs.getBegin());

	mempool.free(binary);
	binary = nullptr;
}

/// Spin the dials. Used to compute all mutator combinations.
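/// Returns true once every combination has been visited. For example, with two mutators whose value lists have sizes 2 and 3 the dials step through
/// [0,0] [0,1] [0,2] [1,0] [1,1] [1,2] and the call made after [1,2] returns true.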
static Bool spinDials(ShaderCompilerDynamicArray<U32>& dials, ConstWeakArray<ShaderParserMutator> mutators)
{
	ANKI_ASSERT(dials.getSize() == mutators.getSize() && dials.getSize() > 0);
	Bool done = true;
	U32 crntDial = dials.getSize() - 1;
	while(true)
	{
		// Turn dial
		++dials[crntDial];

		if(dials[crntDial] >= mutators[crntDial].m_values.getSize())
		{
			if(crntDial == 0)
			{
				// Reached the 1st dial, stop spinning
				done = true;
				break;
			}
			else
			{
				dials[crntDial] = 0;
				--crntDial;
			}
		}
		else
		{
			done = false;
			break;
		}
	}

	return done;
}

/// Does SPIR-V reflection, mapping the DXC register-shifted bindings back to AnKi's binding model.
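/// Note: DXC is assumed to have been invoked with per-set register shifts (kDxcVkBindingShifts) so that each HLSL resource class (CBV/SRV/UAV/
/// sampler) lands in its own 1000-wide binding range; the code below subtracts those shifts to recover the original register indices and rejects
/// bindings that fall outside the expected range.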
Error doReflectionSpirv(ConstWeakArray<U8> spirv, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
{
	spirv_cross::Compiler spvc(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32));
	spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
	spirv_cross::ShaderResources rsrcActive = spvc.get_shader_resources(spvc.get_active_interface_variables());

	auto func = [&](const spirv_cross::SmallVector<spirv_cross::Resource>& resources, const DescriptorType origType,
					const DescriptorFlag origFlags) -> Error {
		for(const spirv_cross::Resource& r : resources)
		{
			const U32 id = r.id;

			const U32 set = spvc.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
			const U32 binding = spvc.get_decoration(id, spv::Decoration::DecorationBinding);
			if(set >= kMaxDescriptorSets && set != kDxcVkBindlessRegisterSpace)
			{
				errorStr.sprintf("Exceeded set for: %s", r.name.c_str());
				return Error::kUserData;
			}

			const spirv_cross::SPIRType& typeInfo = spvc.get_type(r.type_id);
			U32 arraySize = 1;
			if(typeInfo.array.size() != 0)
			{
				if(typeInfo.array.size() != 1)
				{
					errorStr.sprintf("Only 1D arrays are supported: %s", r.name.c_str());
					return Error::kUserData;
				}

				if(set == kDxcVkBindlessRegisterSpace && typeInfo.array[0] != 0)
				{
					errorStr.sprintf("Only the bindless descriptor set can be an unbound array: %s", r.name.c_str());
					return Error::kUserData;
				}

				arraySize = typeInfo.array[0];
			}

			// Images are special, they might be texel buffers
			DescriptorType type = origType;
			DescriptorFlag flags = origFlags;
			if(type == DescriptorType::kTexture && typeInfo.image.dim == spv::DimBuffer)
			{
				type = DescriptorType::kTexelBuffer;
				if(typeInfo.image.sampled == 1)
				{
					flags = DescriptorFlag::kRead;
				}
				else
				{
					ANKI_ASSERT(typeInfo.image.sampled == 2);
					flags = DescriptorFlag::kReadWrite;
				}
			}

			if(set == kDxcVkBindlessRegisterSpace)
			{
				// Bindless dset
				if(arraySize != 0)
				{
					errorStr.sprintf("Unexpected unbound array for bindless: %s", r.name.c_str());
				}

				if(type != DescriptorType::kTexture || flags != DescriptorFlag::kRead)
				{
					errorStr.sprintf("Unexpected bindless binding: %s", r.name.c_str());
					return Error::kUserData;
				}

				refl.m_descriptor.m_hasVkBindlessDescriptorSet = true;
			}
			else
			{
				// Regular binding
				if(origType == DescriptorType::kStorageBuffer && binding >= kDxcVkBindingShifts[set][HlslResourceType::kSrv]
				   && binding < kDxcVkBindingShifts[set][HlslResourceType::kSrv] + 1000)
				{
					// Storage buffer is read only
					flags = DescriptorFlag::kRead;
				}

				const HlslResourceType hlslResourceType = descriptorTypeToHlslResourceType(type, flags);
				if(binding < kDxcVkBindingShifts[set][hlslResourceType] || binding >= kDxcVkBindingShifts[set][hlslResourceType] + 1000)
				{
					errorStr.sprintf("Unexpected binding: %s", r.name.c_str());
					return Error::kUserData;
				}

				ShaderReflectionBinding akBinding;
				akBinding.m_registerBindingPoint = binding - kDxcVkBindingShifts[set][hlslResourceType];
				akBinding.m_arraySize = U16(arraySize);
				akBinding.m_type = type;
				akBinding.m_flags = flags;

				refl.m_descriptor.m_bindings[set][refl.m_descriptor.m_bindingCounts[set]] = akBinding;
				++refl.m_descriptor.m_bindingCounts[set];
			}
		}

		return Error::kNone;
	};

	Error err = func(rsrc.uniform_buffers, DescriptorType::kUniformBuffer, DescriptorFlag::kRead);
	if(err)
	{
		return err;
	}

	err = func(rsrc.separate_images, DescriptorType::kTexture, DescriptorFlag::kRead); // This also handles texel buffers
	if(err)
	{
		return err;
	}

	err = func(rsrc.separate_samplers, DescriptorType::kSampler, DescriptorFlag::kRead);
	if(err)
	{
		return err;
	}

	err = func(rsrc.storage_buffers, DescriptorType::kStorageBuffer, DescriptorFlag::kReadWrite);
	if(err)
	{
		return err;
	}

	err = func(rsrc.storage_images, DescriptorType::kTexture, DescriptorFlag::kReadWrite);
	if(err)
	{
		return err;
	}

	err = func(rsrc.acceleration_structures, DescriptorType::kAccelerationStructure, DescriptorFlag::kRead);
	if(err)
	{
		return err;
	}

	for(U32 i = 0; i < kMaxDescriptorSets; ++i)
	{
		std::sort(refl.m_descriptor.m_bindings[i].getBegin(), refl.m_descriptor.m_bindings[i].getBegin() + refl.m_descriptor.m_bindingCounts[i]);
	}

	// Color attachments
	if(type == ShaderType::kFragment)
	{
		for(const spirv_cross::Resource& r : rsrc.stage_outputs)
		{
			const U32 id = r.id;
			const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);

			refl.m_fragment.m_colorAttachmentWritemask.set(location);
		}
	}

	// Push consts
	if(rsrc.push_constant_buffers.size() == 1)
	{
		const U32 blockSize = U32(spvc.get_declared_struct_size(spvc.get_type(rsrc.push_constant_buffers[0].base_type_id)));
		if(blockSize == 0 || (blockSize % 16) != 0 || blockSize > kMaxU8)
		{
			errorStr.sprintf("Incorrect push constants size");
			return Error::kUserData;
		}

		refl.m_descriptor.m_pushConstantsSize = blockSize;
	}

	// Attribs
	if(type == ShaderType::kVertex)
	{
		for(const spirv_cross::Resource& r : rsrcActive.stage_inputs)
		{
			VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
#define ANKI_ATTRIB_NAME(x) "in.var." #x
			if(r.name == ANKI_ATTRIB_NAME(POSITION))
			{
				a = VertexAttributeSemantic::kPosition;
			}
			else if(r.name == ANKI_ATTRIB_NAME(NORMAL))
			{
				a = VertexAttributeSemantic::kNormal;
			}
			else if(r.name == ANKI_ATTRIB_NAME(TEXCOORD0) || r.name == ANKI_ATTRIB_NAME(TEXCOORD))
			{
				a = VertexAttributeSemantic::kTexCoord;
			}
			else if(r.name == ANKI_ATTRIB_NAME(COLOR))
			{
				a = VertexAttributeSemantic::kColor;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC0) || r.name == ANKI_ATTRIB_NAME(MISC))
			{
				a = VertexAttributeSemantic::kMisc0;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC1))
			{
				a = VertexAttributeSemantic::kMisc1;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC2))
			{
				a = VertexAttributeSemantic::kMisc2;
			}
			else if(r.name == ANKI_ATTRIB_NAME(MISC3))
			{
				a = VertexAttributeSemantic::kMisc3;
			}
			else
			{
				errorStr.sprintf("Unexpected attribute name: %s", r.name.c_str());
				return Error::kUserData;
			}
#undef ANKI_ATTRIB_NAME

			refl.m_vertex.m_vertexAttributeMask.set(a);

			const U32 id = r.id;
			const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
			if(location > kMaxU8)
			{
				errorStr.sprintf("Too high location value for attribute: %s", r.name.c_str());
				return Error::kUserData;
			}

			refl.m_vertex.m_vkVertexAttributeLocations[a] = U8(location);
		}
	}

	// Discards?
	if(type == ShaderType::kFragment)
	{
		visitSpirv(ConstWeakArray<U32>(reinterpret_cast<const U32*>(&spirv[0]), spirv.getSize() / sizeof(U32)),
				   [&](U32 cmd, [[maybe_unused]] ConstWeakArray<U32> instructions) {
					   if(cmd == spv::OpKill)
					   {
						   refl.m_fragment.m_discards = true;
					   }
				   });
	}

	return Error::kNone;
}

#if ANKI_DIXL_REFLECTION

#	define ANKI_REFL_CHECK(x) \
		do \
		{ \
			HRESULT rez; \
			if((rez = (x)) < 0) [[unlikely]] \
			{ \
				errorStr.sprintf("DXC function failed (HRESULT: %d): %s", rez, #x); \
				return Error::kFunctionFailed; \
			} \
		} while(0)
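
/// Does DXIL reflection through the DXC/D3D12 reflection interfaces. Ray tracing and work graph binaries are reflected as libraries
/// (ID3D12LibraryReflection) because they may expose several entry points; everything else goes through ID3D12ShaderReflection.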
Error doReflectionDxil(ConstWeakArray<U8> dxil, ShaderType type, ShaderReflection& refl, ShaderCompilerString& errorStr)
{
	using Microsoft::WRL::ComPtr;

	// Lazily load the DXC DLL
	{
		LockGuard lock(g_dxcLibMtx);

		if(g_dxcLib == 0)
		{
			// Init DXC
			g_dxcLib = LoadLibraryA(ANKI_SOURCE_DIRECTORY "/ThirdParty/Bin/Windows64/dxcompiler.dll");
			if(g_dxcLib == 0)
			{
				ANKI_SHADER_COMPILER_LOGE("dxcompiler.dll missing or wrong architecture");
				return Error::kFunctionFailed;
			}

			g_DxcCreateInstance = reinterpret_cast<DxcCreateInstanceProc>(GetProcAddress(g_dxcLib, "DxcCreateInstance"));
			if(g_DxcCreateInstance == nullptr)
			{
				ANKI_SHADER_COMPILER_LOGE("DxcCreateInstance was not found in the dxcompiler.dll");
				return Error::kFunctionFailed;
			}
		}
	}
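
	// Ray tracing and work graph shaders are compiled as libraries with (potentially) multiple entry points, so they need library reflection
	// instead of plain shader reflection.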
	const Bool isLib = (type >= ShaderType::kFirstRayTracing && type <= ShaderType::kLastRayTracing) || type == ShaderType::kWorkGraph;

	ComPtr<IDxcUtils> utils;
	ANKI_REFL_CHECK(g_DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(&utils)));

	ComPtr<ID3D12ShaderReflection> dxRefl;
	ComPtr<ID3D12LibraryReflection> libRefl;
	ShaderCompilerDynamicArray<ID3D12FunctionReflection*> funcReflections;
	D3D12_SHADER_DESC shaderDesc = {};

	if(isLib)
	{
		const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
		ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&libRefl)));

		D3D12_LIBRARY_DESC libDesc = {};
		libRefl->GetDesc(&libDesc);

		if(libDesc.FunctionCount == 0)
		{
			errorStr.sprintf("Expecting at least 1 in D3D12_LIBRARY_DESC::FunctionCount");
			return Error::kUserData;
		}

		funcReflections.resize(libDesc.FunctionCount);
		for(U32 i = 0; i < libDesc.FunctionCount; ++i)
		{
			funcReflections[i] = libRefl->GetFunctionByIndex(i);
		}
	}
	else
	{
		const DxcBuffer buff = {dxil.getBegin(), dxil.getSizeInBytes(), 0};
		ANKI_REFL_CHECK(utils->CreateReflection(&buff, IID_PPV_ARGS(&dxRefl)));
		ANKI_REFL_CHECK(dxRefl->GetDesc(&shaderDesc));
	}

	for(U32 ifunc = 0; ifunc < ((isLib) ? funcReflections.getSize() : 1); ++ifunc)
	{
		U32 bindingCount;
		if(isLib)
		{
			D3D12_FUNCTION_DESC funcDesc;
			ANKI_REFL_CHECK(funcReflections[ifunc]->GetDesc(&funcDesc));
			bindingCount = funcDesc.BoundResources;
		}
		else
		{
			bindingCount = shaderDesc.BoundResources;
		}

		for(U32 i = 0; i < bindingCount; ++i)
		{
			D3D12_SHADER_INPUT_BIND_DESC bindDesc;
			if(isLib)
			{
				ANKI_REFL_CHECK(funcReflections[ifunc]->GetResourceBindingDesc(i, &bindDesc));
			}
			else
			{
				ANKI_REFL_CHECK(dxRefl->GetResourceBindingDesc(i, &bindDesc));
			}

			ShaderReflectionBinding akBinding;
			akBinding.m_type = DescriptorType::kCount;
			akBinding.m_flags = DescriptorFlag::kNone;
			akBinding.m_arraySize = U16(bindDesc.BindCount);
			akBinding.m_registerBindingPoint = bindDesc.BindPoint;

			if(bindDesc.Type == D3D_SIT_CBUFFER)
			{
				// ConstantBuffer
				if(bindDesc.Space == 3000 && bindDesc.BindPoint == 0)
				{
					// It's push/root constants
					ID3D12ShaderReflectionConstantBuffer* cbuffer =
						(isLib) ? funcReflections[ifunc]->GetConstantBufferByName(bindDesc.Name) : dxRefl->GetConstantBufferByName(bindDesc.Name);
					D3D12_SHADER_BUFFER_DESC desc;
					ANKI_REFL_CHECK(cbuffer->GetDesc(&desc));
					refl.m_descriptor.m_pushConstantsSize = desc.Size;
					continue;
				}

				akBinding.m_type = DescriptorType::kUniformBuffer;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
			{
				// Texture2D etc
				akBinding.m_type = DescriptorType::kTexture;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_TEXTURE && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
			{
				// Buffer
				akBinding.m_type = DescriptorType::kTexelBuffer;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_SAMPLER)
			{
				// SamplerState
				akBinding.m_type = DescriptorType::kSampler;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension == D3D_SRV_DIMENSION_BUFFER)
			{
				// RWBuffer
				akBinding.m_type = DescriptorType::kTexelBuffer;
				akBinding.m_flags = DescriptorFlag::kReadWrite;
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWTYPED && bindDesc.Dimension != D3D_SRV_DIMENSION_BUFFER)
			{
				// RWTexture2D etc
				akBinding.m_type = DescriptorType::kTexture;
				akBinding.m_flags = DescriptorFlag::kReadWrite;
			}
			else if(bindDesc.Type == D3D_SIT_BYTEADDRESS)
			{
				// ByteAddressBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kRead | DescriptorFlag::kByteAddressBuffer;
				akBinding.m_d3dStructuredBufferStride = sizeof(U32);
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWBYTEADDRESS)
			{
				// RWByteAddressBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kReadWrite | DescriptorFlag::kByteAddressBuffer;
				akBinding.m_d3dStructuredBufferStride = sizeof(U32);
			}
			else if(bindDesc.Type == D3D_SIT_RTACCELERATIONSTRUCTURE)
			{
				// RaytracingAccelerationStructure
				akBinding.m_type = DescriptorType::kAccelerationStructure;
				akBinding.m_flags = DescriptorFlag::kRead;
			}
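			// Note: for (RW)StructuredBuffer bindings D3D12 reflection reports the element stride through D3D12_SHADER_INPUT_BIND_DESC::NumSamples,
			// which is why the stride below is read from that field.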
			else if(bindDesc.Type == D3D_SIT_STRUCTURED)
			{
				// StructuredBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kRead;
				akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
			}
			else if(bindDesc.Type == D3D_SIT_UAV_RWSTRUCTURED)
			{
				// RWStructuredBuffer
				akBinding.m_type = DescriptorType::kStorageBuffer;
				akBinding.m_flags = DescriptorFlag::kReadWrite;
				akBinding.m_d3dStructuredBufferStride = U16(bindDesc.NumSamples);
			}
			else
			{
				errorStr.sprintf("Unrecognized type for binding: %s", bindDesc.Name);
				return Error::kUserData;
			}

			Bool skip = false;
			if(isLib)
			{
				// Search if the binding exists because it may repeat
				for(U32 i = 0; i < refl.m_descriptor.m_bindingCounts[bindDesc.Space]; ++i)
				{
					if(refl.m_descriptor.m_bindings[bindDesc.Space][i] == akBinding)
					{
						skip = true;
						break;
					}
				}
			}

			if(!skip)
			{
				refl.m_descriptor.m_bindings[bindDesc.Space][refl.m_descriptor.m_bindingCounts[bindDesc.Space]] = akBinding;
				++refl.m_descriptor.m_bindingCounts[bindDesc.Space];
			}
		}
	}

	for(U32 i = 0; i < kMaxDescriptorSets; ++i)
	{
		std::sort(refl.m_descriptor.m_bindings[i].getBegin(), refl.m_descriptor.m_bindings[i].getBegin() + refl.m_descriptor.m_bindingCounts[i]);
	}

	if(type == ShaderType::kVertex)
	{
		for(U32 i = 0; i < shaderDesc.InputParameters; ++i)
		{
			D3D12_SIGNATURE_PARAMETER_DESC in;
			ANKI_REFL_CHECK(dxRefl->GetInputParameterDesc(i, &in));

			VertexAttributeSemantic a = VertexAttributeSemantic::kCount;
#	define ANKI_ATTRIB_NAME(x, idx) CString(in.SemanticName) == #x && in.SemanticIndex == idx
			if(ANKI_ATTRIB_NAME(POSITION, 0))
			{
				a = VertexAttributeSemantic::kPosition;
			}
			else if(ANKI_ATTRIB_NAME(NORMAL, 0))
			{
				a = VertexAttributeSemantic::kNormal;
			}
			else if(ANKI_ATTRIB_NAME(TEXCOORD, 0))
			{
				a = VertexAttributeSemantic::kTexCoord;
			}
			else if(ANKI_ATTRIB_NAME(COLOR, 0))
			{
				a = VertexAttributeSemantic::kColor;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 0))
			{
				a = VertexAttributeSemantic::kMisc0;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 1))
			{
				a = VertexAttributeSemantic::kMisc1;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 2))
			{
				a = VertexAttributeSemantic::kMisc2;
			}
			else if(ANKI_ATTRIB_NAME(MISC, 3))
			{
				a = VertexAttributeSemantic::kMisc3;
			}
			else if(ANKI_ATTRIB_NAME(SV_VERTEXID, 0) || ANKI_ATTRIB_NAME(SV_INSTANCEID, 0))
			{
				// Ignore
				continue;
			}
			else
			{
				errorStr.sprintf("Unexpected attribute name: %s", in.SemanticName);
				return Error::kUserData;
			}
#	undef ANKI_ATTRIB_NAME

			refl.m_vertex.m_vertexAttributeMask.set(a);
			refl.m_vertex.m_vkVertexAttributeLocations[a] = U8(i); // Just set something
		}
	}

	if(type == ShaderType::kFragment)
	{
		for(U32 i = 0; i < shaderDesc.OutputParameters; ++i)
		{
			D3D12_SIGNATURE_PARAMETER_DESC desc;
			ANKI_REFL_CHECK(dxRefl->GetOutputParameterDesc(i, &desc));

			if(CString(desc.SemanticName) == "SV_TARGET")
			{
				refl.m_fragment.m_colorAttachmentWritemask.set(desc.SemanticIndex);
			}
		}
	}

	return Error::kNone;
}
#endif // #if ANKI_DIXL_REFLECTION
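
/// Schedules the compilation of a single mutation through the task manager. The task generates the HLSL source for every technique and shader
/// stage, compiles it to SPIR-V or DXIL, runs reflection and then de-duplicates code blocks (by source hash and binary hash) and variants before
/// appending them to the shared arrays under the mutex.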
static void compileVariantAsync(const ShaderParser& parser, Bool spirv, Bool debugInfo, ShaderBinaryMutation& mutation,
								ShaderCompilerDynamicArray<ShaderBinaryVariant>& variants,
								ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>& codeBlocks, ShaderCompilerDynamicArray<U64>& sourceCodeHashes,
								ShaderCompilerAsyncTaskInterface& taskManager, Mutex& mtx, Atomic<I32>& error)
{
	class Ctx
	{
	public:
		const ShaderParser* m_parser;
		ShaderBinaryMutation* m_mutation;
		ShaderCompilerDynamicArray<ShaderBinaryVariant>* m_variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock>* m_codeBlocks;
		ShaderCompilerDynamicArray<U64>* m_sourceCodeHashes;
		Mutex* m_mtx;
		Atomic<I32>* m_err;
		Bool m_spirv;
		Bool m_debugInfo;
	};

	Ctx* ctx = newInstance<Ctx>(ShaderCompilerMemoryPool::getSingleton());
	ctx->m_parser = &parser;
	ctx->m_mutation = &mutation;
	ctx->m_variants = &variants;
	ctx->m_codeBlocks = &codeBlocks;
	ctx->m_sourceCodeHashes = &sourceCodeHashes;
	ctx->m_mtx = &mtx;
	ctx->m_err = &error;
	ctx->m_spirv = spirv;
	ctx->m_debugInfo = debugInfo;

	auto callback = [](void* userData) {
		Ctx& ctx = *static_cast<Ctx*>(userData);

		class Cleanup
		{
		public:
			Ctx* m_ctx;

			~Cleanup()
			{
				deleteInstance(ShaderCompilerMemoryPool::getSingleton(), m_ctx);
			}
		} cleanup{&ctx};

		if(ctx.m_err->load() != 0)
		{
			// Don't bother
			return;
		}

		const U32 techniqueCount = ctx.m_parser->getTechniques().getSize();

		// Compile the sources
		ShaderCompilerDynamicArray<ShaderBinaryTechniqueCodeBlocks> codeBlockIndices;
		codeBlockIndices.resize(techniqueCount);
		for(auto& it : codeBlockIndices)
		{
			it.m_codeBlockIndices.fill(kMaxU32);
		}

		ShaderCompilerString compilerErrorLog;
		Error err = Error::kNone;
		U newCodeBlockCount = 0;
		for(U32 t = 0; t < techniqueCount && !err; ++t)
		{
			const ShaderParserTechnique& technique = ctx.m_parser->getTechniques()[t];
			for(ShaderType shaderType : EnumBitsIterable<ShaderType, ShaderTypeBit>(technique.m_shaderTypes))
			{
				ShaderCompilerString source;
				ctx.m_parser->generateVariant(ctx.m_mutation->m_values, technique, shaderType, source);

				// Check if the source code was found before
				const U64 sourceCodeHash = source.computeHash();

				if(technique.m_activeMutators[shaderType] != kMaxU64)
				{
					LockGuard lock(*ctx.m_mtx);

					ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

					Bool found = false;
					for(U32 i = 0; i < ctx.m_sourceCodeHashes->getSize(); ++i)
					{
						if((*ctx.m_sourceCodeHashes)[i] == sourceCodeHash)
						{
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = i;
							found = true;
							break;
						}
					}

					if(found)
					{
						continue;
					}
				}

				ShaderCompilerDynamicArray<U8> il;
				if(ctx.m_spirv)
				{
					err = compileHlslToSpirv(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo, il, compilerErrorLog);
				}
				else
				{
					err = compileHlslToDxil(source, shaderType, ctx.m_parser->compileWith16bitTypes(), ctx.m_debugInfo, il, compilerErrorLog);
				}

				if(err)
				{
					break;
				}

				const U64 newHash = computeHash(il.getBegin(), il.getSizeInBytes());

				ShaderReflection refl;
				if(ctx.m_spirv)
				{
					err = doReflectionSpirv(il, shaderType, refl, compilerErrorLog);
				}
				else
				{
#if ANKI_DIXL_REFLECTION
					err = doReflectionDxil(il, shaderType, refl, compilerErrorLog);
#else
					ANKI_SHADER_COMPILER_LOGE("Can't do DXIL reflection on non-Windows platforms");
					err = Error::kFunctionFailed;
#endif
				}

				if(err)
				{
					break;
				}

				// Add the binary if not already there
				{
					LockGuard lock(*ctx.m_mtx);

					Bool found = false;
					for(U32 j = 0; j < ctx.m_codeBlocks->getSize(); ++j)
					{
						if((*ctx.m_codeBlocks)[j].m_hash == newHash)
						{
							codeBlockIndices[t].m_codeBlockIndices[shaderType] = j;
							found = true;
							break;
						}
					}

					if(!found)
					{
						codeBlockIndices[t].m_codeBlockIndices[shaderType] = ctx.m_codeBlocks->getSize();

						auto& codeBlock = *ctx.m_codeBlocks->emplaceBack();
						il.moveAndReset(codeBlock.m_binary);
						codeBlock.m_hash = newHash;
						codeBlock.m_reflection = refl;

						ctx.m_sourceCodeHashes->emplaceBack(sourceCodeHash);
						ANKI_ASSERT(ctx.m_sourceCodeHashes->getSize() == ctx.m_codeBlocks->getSize());

						++newCodeBlockCount;
					}
				}
			}
		}

		if(err)
		{
			I32 expectedErr = 0;
			const Bool isFirstError = ctx.m_err->compareExchange(expectedErr, err._getCode());
			if(isFirstError)
			{
				ANKI_SHADER_COMPILER_LOGE("Shader compilation failed:\n%s", compilerErrorLog.cstr());
				return;
			}

			return;
		}

		// Do variant stuff
		{
			LockGuard lock(*ctx.m_mtx);

			Bool createVariant = true;
			if(newCodeBlockCount == 0)
			{
				// No new code blocks generated, search all variants to see if there is a duplicate
				for(U32 i = 0; i < ctx.m_variants->getSize(); ++i)
				{
					Bool same = true;
					for(U32 t = 0; t < techniqueCount; ++t)
					{
						const ShaderBinaryTechniqueCodeBlocks& a = (*ctx.m_variants)[i].m_techniqueCodeBlocks[t];
						const ShaderBinaryTechniqueCodeBlocks& b = codeBlockIndices[t];
						if(memcmp(&a, &b, sizeof(a)) != 0)
						{
							// Not the same
							same = false;
							break;
						}
					}

					if(same)
					{
						createVariant = false;
						ctx.m_mutation->m_variantIndex = i;
						break;
					}
				}
			}

			// Create a new variant
			if(createVariant)
			{
				ctx.m_mutation->m_variantIndex = ctx.m_variants->getSize();

				ShaderBinaryVariant* variant = ctx.m_variants->emplaceBack();
				codeBlockIndices.moveAndReset(variant->m_techniqueCodeBlocks);
			}
		}
	};

	taskManager.enqueueTask(callback, ctx);
}
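
/// Parses the source, enumerates every mutation (all mutator value combinations), kicks one compilation task per non-skipped mutation and finally
/// packs mutators, mutations, variants, code blocks, techniques and ghost structs into the ShaderBinary.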
static Error compileShaderProgramInternal(CString fname, Bool spirv, Bool debugInfo, ShaderCompilerFilesystemInterface& fsystem,
										  ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager_,
										  ConstWeakArray<ShaderCompilerDefine> defines_, ShaderBinary*& binary)
{
	ShaderCompilerMemoryPool& memPool = ShaderCompilerMemoryPool::getSingleton();

	ShaderCompilerDynamicArray<ShaderCompilerDefine> defines;
	for(const ShaderCompilerDefine& d : defines_)
	{
		defines.emplaceBack(d);
	}

	// Initialize the binary
	binary = newInstance<ShaderBinary>(memPool);
	memcpy(&binary->m_magic[0], kShaderBinaryMagic, 8);

	// Parse source
	ShaderParser parser(fname, &fsystem, defines);
	ANKI_CHECK(parser.parse());

	if(postParseCallback && postParseCallback->skipCompilation(parser.getHash()))
	{
		return Error::kNone;
	}

	// Get mutators
	U32 mutationCount = 0;
	if(parser.getMutators().getSize() > 0)
	{
		newArray(memPool, parser.getMutators().getSize(), binary->m_mutators);
		for(U32 i = 0; i < binary->m_mutators.getSize(); ++i)
		{
			ShaderBinaryMutator& out = binary->m_mutators[i];
			const ShaderParserMutator& in = parser.getMutators()[i];

			zeroMemory(out);

			newArray(memPool, in.m_values.getSize(), out.m_values);
			memcpy(out.m_values.getBegin(), in.m_values.getBegin(), in.m_values.getSizeInBytes());

			memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

			// Update the count
			mutationCount = (i == 0) ? out.m_values.getSize() : mutationCount * out.m_values.getSize();
		}
	}
	else
	{
		ANKI_ASSERT(binary->m_mutators.getSize() == 0);
	}

	// Create all variants
	Mutex mtx;
	Atomic<I32> errorAtomic(0);

	class SyncronousShaderCompilerAsyncTaskInterface : public ShaderCompilerAsyncTaskInterface
	{
	public:
		void enqueueTask(void (*callback)(void* userData), void* userData) final
		{
			callback(userData);
		}

		Error joinTasks() final
		{
			// Nothing
			return Error::kNone;
		}
	} syncTaskManager;
	ShaderCompilerAsyncTaskInterface& taskManager = (taskManager_) ? *taskManager_ : syncTaskManager;

	if(parser.getMutators().getSize() > 0)
	{
		// Initialize
		ShaderCompilerDynamicArray<MutatorValue> mutationValues;
		mutationValues.resize(parser.getMutators().getSize());
		ShaderCompilerDynamicArray<U32> dials;
		dials.resize(parser.getMutators().getSize(), 0);
		ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;
		ShaderCompilerDynamicArray<ShaderBinaryMutation> mutations;
		mutations.resize(mutationCount);
		ShaderCompilerHashMap<U64, U32> mutationHashToIdx;

		// Grow the storage of the variants array. Can't have it resize, threads will work on stale data
		variants.resizeStorage(mutationCount);

		mutationCount = 0;

		// Spin for all possible combinations of mutators and
		// - Create the spirv
		// - Populate the binary variant
		do
		{
			// Create the mutation
			for(U32 i = 0; i < parser.getMutators().getSize(); ++i)
			{
				mutationValues[i] = parser.getMutators()[i].m_values[dials[i]];
			}

			ShaderBinaryMutation& mutation = mutations[mutationCount++];
			newArray(memPool, mutationValues.getSize(), mutation.m_values);
			memcpy(mutation.m_values.getBegin(), mutationValues.getBegin(), mutationValues.getSizeInBytes());

			mutation.m_hash = computeHash(mutationValues.getBegin(), mutationValues.getSizeInBytes());
			ANKI_ASSERT(mutation.m_hash > 0);

			if(parser.skipMutation(mutationValues))
			{
				mutation.m_variantIndex = kMaxU32;
			}
			else
			{
				// New and unique mutation and thus variant, add it
				compileVariantAsync(parser, spirv, debugInfo, mutation, variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

				ANKI_ASSERT(mutationHashToIdx.find(mutation.m_hash) == mutationHashToIdx.getEnd());
				mutationHashToIdx.emplace(mutation.m_hash, mutationCount - 1);
			}
		} while(!spinDials(dials, parser.getMutators()));

		ANKI_ASSERT(mutationCount == mutations.getSize());

		// Done, wait the threads
		ANKI_CHECK(taskManager.joinTasks());

		// Now error out
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		// Store temp containers to binary
		codeBlocks.moveAndReset(binary->m_codeBlocks);
		mutations.moveAndReset(binary->m_mutations);
		variants.moveAndReset(binary->m_variants);
	}
	else
	{
		newArray(memPool, 1, binary->m_mutations);

		ShaderCompilerDynamicArray<ShaderBinaryVariant> variants;
		ShaderCompilerDynamicArray<ShaderBinaryCodeBlock> codeBlocks;
		ShaderCompilerDynamicArray<U64> sourceCodeHashes;

		compileVariantAsync(parser, spirv, debugInfo, binary->m_mutations[0], variants, codeBlocks, sourceCodeHashes, taskManager, mtx, errorAtomic);

		ANKI_CHECK(taskManager.joinTasks());
		ANKI_CHECK(Error(errorAtomic.getNonAtomically()));

		ANKI_ASSERT(codeBlocks.getSize() >= parser.getTechniques().getSize());
		ANKI_ASSERT(binary->m_mutations[0].m_variantIndex == 0);
		ANKI_ASSERT(variants.getSize() == 1);
		binary->m_mutations[0].m_hash = 1;

		codeBlocks.moveAndReset(binary->m_codeBlocks);
		variants.moveAndReset(binary->m_variants);
	}

	// Sort the mutations
	std::sort(binary->m_mutations.getBegin(), binary->m_mutations.getEnd(), [](const ShaderBinaryMutation& a, const ShaderBinaryMutation& b) {
		return a.m_hash < b.m_hash;
	});

	// Techniques
	newArray(memPool, parser.getTechniques().getSize(), binary->m_techniques);
	for(U32 i = 0; i < parser.getTechniques().getSize(); ++i)
	{
		zeroMemory(binary->m_techniques[i].m_name);
		memcpy(binary->m_techniques[i].m_name.getBegin(), parser.getTechniques()[i].m_name.cstr(), parser.getTechniques()[i].m_name.getLength() + 1);
		binary->m_techniques[i].m_shaderTypes = parser.getTechniques()[i].m_shaderTypes;
		binary->m_shaderTypes |= parser.getTechniques()[i].m_shaderTypes;
	}

	// Structs
	if(parser.getGhostStructs().getSize())
	{
		newArray(memPool, parser.getGhostStructs().getSize(), binary->m_structs);
	}

	for(U32 i = 0; i < parser.getGhostStructs().getSize(); ++i)
	{
		const ShaderParserGhostStruct& in = parser.getGhostStructs()[i];
		ShaderBinaryStruct& out = binary->m_structs[i];
		zeroMemory(out);

		memcpy(out.m_name.getBegin(), in.m_name.cstr(), in.m_name.getLength() + 1);

		ANKI_ASSERT(in.m_members.getSize());
		newArray(memPool, in.m_members.getSize(), out.m_members);

		for(U32 j = 0; j < in.m_members.getSize(); ++j)
		{
			const ShaderParserGhostStructMember& inm = in.m_members[j];
			ShaderBinaryStructMember& outm = out.m_members[j];

			zeroMemory(outm.m_name);
			memcpy(outm.m_name.getBegin(), inm.m_name.cstr(), inm.m_name.getLength() + 1);
			outm.m_offset = inm.m_offset;
			outm.m_type = inm.m_type;
		}

		out.m_size = in.m_members.getBack().m_offset + getShaderVariableDataTypeInfo(in.m_members.getBack().m_type).m_size;
	}

	return Error::kNone;
}
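
/// Public entry point: same as compileShaderProgramInternal() but on failure it logs the filename and frees the partially built binary, leaving
/// @a binary null.
///
/// Rough usage sketch (the filesystem implementation and the filename are illustrative assumptions, not part of this file):
/// @code
/// MyFilesystem fs; // hypothetical, implements ShaderCompilerFilesystemInterface
/// ShaderBinary* bin = nullptr;
/// const Error err = compileShaderProgram("Shaders/MyShader.ankiprog", true /*spirv*/, false /*debugInfo*/, fs, nullptr, nullptr, {}, bin);
/// if(!err)
/// {
/// 	// ... use bin ...
/// 	freeShaderBinary(bin);
/// }
/// @endcode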
Error compileShaderProgram(CString fname, Bool spirv, Bool debugInfo, ShaderCompilerFilesystemInterface& fsystem,
						   ShaderCompilerPostParseInterface* postParseCallback, ShaderCompilerAsyncTaskInterface* taskManager,
						   ConstWeakArray<ShaderCompilerDefine> defines, ShaderBinary*& binary)
{
	const Error err = compileShaderProgramInternal(fname, spirv, debugInfo, fsystem, postParseCallback, taskManager, defines, binary);
	if(err)
	{
		ANKI_SHADER_COMPILER_LOGE("Failed to compile: %s", fname.cstr());
		freeShaderBinary(binary);
	}

	return err;
}

} // end namespace anki