ShaderImpl.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. // Copyright (C) 2009-2021, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Gr/Vulkan/ShaderImpl.h>
  6. #include <AnKi/Gr/Vulkan/GrManagerImpl.h>
  7. #include <AnKi/Gr/Utils/Functions.h>
  8. #include <SprivCross/spirv_cross.hpp>
  9. #define ANKI_DUMP_SHADERS 0
  10. #if ANKI_DUMP_SHADERS
  11. # include <AnKi/Util/File.h>
  12. # include <AnKi/Gr/GrManager.h>
  13. #endif
  14. namespace anki
  15. {
  16. class ShaderImpl::SpecConstsVector
  17. {
  18. public:
  19. spirv_cross::SmallVector<spirv_cross::SpecializationConstant> m_vec;
  20. };
  21. ShaderImpl::~ShaderImpl()
  22. {
  23. for(auto& x : m_bindings)
  24. {
  25. x.destroy(getAllocator());
  26. }
  27. if(m_handle)
  28. {
  29. vkDestroyShaderModule(getDevice(), m_handle, nullptr);
  30. }
  31. if(m_specConstInfo.pMapEntries)
  32. {
  33. getAllocator().deleteArray(const_cast<VkSpecializationMapEntry*>(m_specConstInfo.pMapEntries),
  34. m_specConstInfo.mapEntryCount);
  35. }
  36. if(m_specConstInfo.pData)
  37. {
  38. getAllocator().deleteArray(static_cast<I32*>(const_cast<void*>(m_specConstInfo.pData)),
  39. m_specConstInfo.dataSize / sizeof(I32));
  40. }
  41. }
  42. Error ShaderImpl::init(const ShaderInitInfo& inf)
  43. {
  44. ANKI_ASSERT(inf.m_binary.getSize() > 0);
  45. ANKI_ASSERT(m_handle == VK_NULL_HANDLE);
  46. m_shaderType = inf.m_shaderType;
  47. #if ANKI_DUMP_SHADERS
  48. {
  49. StringAuto fnameSpirv(getAllocator());
  50. fnameSpirv.sprintf("%s/%05u.spv", getManager().getCacheDirectory().cstr(), getUuid());
  51. File fileSpirv;
  52. ANKI_CHECK(
  53. fileSpirv.open(fnameSpirv.toCString(), FileOpenFlag::BINARY | FileOpenFlag::WRITE | FileOpenFlag::SPECIAL));
  54. ANKI_CHECK(fileSpirv.write(&inf.m_binary[0], inf.m_binary.getSize()));
  55. }
  56. #endif
  57. VkShaderModuleCreateInfo ci = {VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, nullptr, 0, inf.m_binary.getSize(),
  58. reinterpret_cast<const uint32_t*>(&inf.m_binary[0])};
  59. ANKI_VK_CHECK(vkCreateShaderModule(getDevice(), &ci, nullptr, &m_handle));
  60. // Get reflection info
  61. SpecConstsVector specConstIds;
  62. doReflection(inf.m_binary, specConstIds);
  63. // Set spec info
  64. if(specConstIds.m_vec.size())
  65. {
  66. const U32 constCount = U32(specConstIds.m_vec.size());
  67. m_specConstInfo.mapEntryCount = constCount;
  68. m_specConstInfo.pMapEntries = getAllocator().newArray<VkSpecializationMapEntry>(constCount);
  69. m_specConstInfo.dataSize = constCount * sizeof(U32);
  70. m_specConstInfo.pData = getAllocator().newArray<U32>(constCount);
  71. U32 count = 0;
  72. for(const spirv_cross::SpecializationConstant& sconst : specConstIds.m_vec)
  73. {
  74. // Set the entry
  75. VkSpecializationMapEntry& entry = const_cast<VkSpecializationMapEntry&>(m_specConstInfo.pMapEntries[count]);
  76. entry.constantID = sconst.constant_id;
  77. entry.offset = count * sizeof(U32);
  78. entry.size = sizeof(U32);
  79. // Find the value
  80. const ShaderSpecializationConstValue* val = nullptr;
  81. for(const ShaderSpecializationConstValue& v : inf.m_constValues)
  82. {
  83. if(v.m_constantId == entry.constantID)
  84. {
  85. val = &v;
  86. break;
  87. }
  88. }
  89. ANKI_ASSERT(val && "Contant ID wasn't found in the init info");
  90. // Copy the data
  91. U8* data = static_cast<U8*>(const_cast<void*>(m_specConstInfo.pData));
  92. data += entry.offset;
  93. *reinterpret_cast<U32*>(data) = val->m_uint;
  94. ++count;
  95. }
  96. }
  97. return Error::NONE;
  98. }
  99. void ShaderImpl::doReflection(ConstWeakArray<U8> spirv, SpecConstsVector& specConstIds)
  100. {
  101. spirv_cross::Compiler spvc(reinterpret_cast<const uint32_t*>(&spirv[0]), spirv.getSize() / sizeof(unsigned int));
  102. spirv_cross::ShaderResources rsrc = spvc.get_shader_resources();
  103. spirv_cross::ShaderResources rsrcActive = spvc.get_shader_resources(spvc.get_active_interface_variables());
  104. Array<U32, MAX_DESCRIPTOR_SETS> counts = {};
  105. Array2d<DescriptorBinding, MAX_DESCRIPTOR_SETS, MAX_BINDINGS_PER_DESCRIPTOR_SET> descriptors;
  106. auto func = [&](const spirv_cross::SmallVector<spirv_cross::Resource>& resources, DescriptorType type) -> void {
  107. for(const spirv_cross::Resource& r : resources)
  108. {
  109. const U32 id = r.id;
  110. const U32 set = spvc.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
  111. ANKI_ASSERT(set < MAX_DESCRIPTOR_SETS);
  112. const U32 binding = spvc.get_decoration(id, spv::Decoration::DecorationBinding);
  113. ANKI_ASSERT(binding < MAX_BINDINGS_PER_DESCRIPTOR_SET);
  114. const spirv_cross::SPIRType& typeInfo = spvc.get_type(r.type_id);
  115. U32 arraySize = 1;
  116. if(typeInfo.array.size() != 0)
  117. {
  118. ANKI_ASSERT(typeInfo.array.size() == 1 && "Only 1D arrays are supported");
  119. arraySize = typeInfo.array[0];
  120. ANKI_ASSERT(arraySize > 0 && (arraySize - 1) <= MAX_U8);
  121. }
  122. m_descriptorSetMask.set(set);
  123. m_activeBindingMask[set].set(set);
  124. // Check that there are no other descriptors with the same binding
  125. U32 foundIdx = MAX_U32;
  126. for(U32 i = 0; i < counts[set]; ++i)
  127. {
  128. if(descriptors[set][i].m_binding == binding)
  129. {
  130. foundIdx = i;
  131. break;
  132. }
  133. }
  134. if(foundIdx == MAX_U32)
  135. {
  136. // New binding, init it
  137. DescriptorBinding& descriptor = descriptors[set][counts[set]++];
  138. descriptor.m_binding = U8(binding);
  139. descriptor.m_type = type;
  140. descriptor.m_stageMask = ShaderTypeBit(1 << m_shaderType);
  141. descriptor.m_arraySizeMinusOne = U8(arraySize - 1);
  142. }
  143. else
  144. {
  145. // Same binding, make sure the type is compatible
  146. ANKI_ASSERT(type == descriptors[set][foundIdx].m_type && "Same binding different type");
  147. ANKI_ASSERT(arraySize - 1 == descriptors[set][foundIdx].m_arraySizeMinusOne
  148. && "Same binding different array size");
  149. }
  150. }
  151. };
  152. func(rsrc.uniform_buffers, DescriptorType::UNIFORM_BUFFER);
  153. func(rsrc.sampled_images, DescriptorType::COMBINED_TEXTURE_SAMPLER);
  154. func(rsrc.separate_images, DescriptorType::TEXTURE);
  155. func(rsrc.separate_samplers, DescriptorType::SAMPLER);
  156. func(rsrc.storage_buffers, DescriptorType::STORAGE_BUFFER);
  157. func(rsrc.storage_images, DescriptorType::IMAGE);
  158. func(rsrc.acceleration_structures, DescriptorType::ACCELERATION_STRUCTURE);
  159. for(U32 set = 0; set < MAX_DESCRIPTOR_SETS; ++set)
  160. {
  161. if(counts[set])
  162. {
  163. m_bindings[set].create(getAllocator(), counts[set]);
  164. memcpy(&m_bindings[set][0], &descriptors[set][0], counts[set] * sizeof(DescriptorBinding));
  165. }
  166. }
  167. // Color attachments
  168. if(m_shaderType == ShaderType::FRAGMENT)
  169. {
  170. for(const spirv_cross::Resource& r : rsrc.stage_outputs)
  171. {
  172. const U32 id = r.id;
  173. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  174. m_colorAttachmentWritemask.set(location);
  175. }
  176. }
  177. // Attribs
  178. if(m_shaderType == ShaderType::VERTEX)
  179. {
  180. for(const spirv_cross::Resource& r : rsrcActive.stage_inputs)
  181. {
  182. const U32 id = r.id;
  183. const U32 location = spvc.get_decoration(id, spv::Decoration::DecorationLocation);
  184. m_attributeMask.set(location);
  185. }
  186. }
  187. // Spec consts
  188. specConstIds.m_vec = spvc.get_specialization_constants();
  189. // Push consts
  190. if(rsrc.push_constant_buffers.size() == 1)
  191. {
  192. const U32 blockSize =
  193. U32(spvc.get_declared_struct_size(spvc.get_type(rsrc.push_constant_buffers[0].base_type_id)));
  194. ANKI_ASSERT(blockSize > 0);
  195. ANKI_ASSERT(blockSize % 16 == 0 && "Should be aligned");
  196. ANKI_ASSERT(blockSize <= getGrManagerImpl().getDeviceCapabilities().m_pushConstantsSize);
  197. m_pushConstantsSize = blockSize;
  198. }
  199. }
  200. } // end namespace anki