Просмотр исходного кода

implement compute shader support

Rinthel 6 лет назад
Родитель
Commit
2fa32d855d
5 измененных файлов с 436 добавлено и 120 удалено
  1. 1 1
      src/bgfx_compute.sh
  2. 1 1
      src/bgfx_p.h
  3. 309 86
      src/renderer_vk.cpp
  4. 12 3
      src/renderer_vk.h
  5. 113 29
      tools/shaderc/shaderc_spirv.cpp

+ 1 - 1
src/bgfx_compute.sh

@@ -138,7 +138,7 @@
 #define IMAGE3D_WR( _name, _format, _reg) IMAGE3D_RW(_name, _format, _reg)
 #define UIMAGE3D_WR(_name, _format, _reg) IMAGE3D_RW(_name, _format, _reg)
 
-#if BGFX_SHADER_LANGUAGE_METAL
+#if BGFX_SHADER_LANGUAGE_METAL || BGFX_SHADER_LANGUAGE_SPIRV
 #define BUFFER_RO(_name, _struct, _reg) StructuredBuffer<_struct>   _name : REGISTER(t, _reg)
 #define BUFFER_RW(_name, _struct, _reg) RWStructuredBuffer <_struct> _name : REGISTER(u, _reg)
 #define BUFFER_WR(_name, _struct, _reg) BUFFER_RW(_name, _struct, _reg)

+ 1 - 1
src/bgfx_p.h

@@ -3836,7 +3836,7 @@ constexpr uint64_t kSortKeyComputeProgramMask  = uint64_t(BGFX_CONFIG_MAX_PROGRA
 				bx::read(&reader, regCount, &err);
 
 				PredefinedUniform::Enum predefined = nameToPredefinedUniformEnum(name);
-				if (PredefinedUniform::Count == predefined)
+				if (PredefinedUniform::Count == predefined && UniformType::End != UniformType::Enum(type))
 				{
 					uniforms[sr.m_num] = createUniform(name, UniformType::Enum(type), regCount);
 					sr.m_num++;

+ 309 - 86
src/renderer_vk.cpp

@@ -1286,6 +1286,11 @@ VK_IMPORT_INSTANCE
 				}
 			}
 
+			if (m_qfiCompute != UINT32_MAX)
+			{
+				g_caps.supported |= BGFX_CAPS_COMPUTE;
+			}
+
 			{
 				const char* enabledLayerNames[] =
 				{
@@ -2007,7 +2012,8 @@ VK_IMPORT_DEVICE
 					{ VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE,          (10 * BGFX_CONFIG_MAX_TEXTURE_SAMPLERS) << 10 },
 					{ VK_DESCRIPTOR_TYPE_SAMPLER,                (10 * BGFX_CONFIG_MAX_TEXTURE_SAMPLERS) << 10 },
 					{ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,         10<<10                           },
-//					{ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,         BGFX_CONFIG_MAX_TEXTURE_SAMPLERS },
+					{ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,         BGFX_CONFIG_MAX_TEXTURE_SAMPLERS << 10 },
+					{ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,         BGFX_CONFIG_MAX_TEXTURE_SAMPLERS << 10 },
 				};
 
 // 				VkDescriptorSetLayoutBinding dslb[] =
@@ -2635,7 +2641,7 @@ VK_IMPORT_DEVICE
 			wds[1].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
 			wds[1].pNext = NULL;
 			wds[1].dstSet = scratchBuffer.m_descriptorSet[scratchBuffer.m_currentDs];
-			wds[1].dstBinding = program.m_fsh->m_sampler[0].imageBinding;
+			wds[1].dstBinding = program.m_fsh->m_bindInfo[0].binding;
 			wds[1].dstArrayElement = 0;
 			wds[1].descriptorCount = 1;
 			wds[1].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
@@ -2646,7 +2652,7 @@ VK_IMPORT_DEVICE
 			wds[2].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
 			wds[2].pNext = NULL;
 			wds[2].dstSet = scratchBuffer.m_descriptorSet[scratchBuffer.m_currentDs];
-			wds[2].dstBinding = program.m_fsh->m_sampler[0].samplerBinding;
+			wds[2].dstBinding = program.m_fsh->m_bindInfo[0].samplerBinding;
 			wds[2].dstArrayElement = 0;
 			wds[2].descriptorCount = 1;
 			wds[2].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLER;
@@ -3297,9 +3303,42 @@ VK_IMPORT_DEVICE
 
 		VkPipeline getPipeline(ProgramHandle _program) // Compute overload: returns the compute pipeline for _program, creating and caching it on first use.
 		{
-			BX_UNUSED(_program);
-			// vkCreateComputePipelines
-			return VK_NULL_HANDLE;
+			ProgramVK& program = m_program[_program.idx];
+
+			bx::HashMurmur2A murmur;
+			murmur.begin();
+			murmur.add(program.m_vsh->m_hash); // the cache key is derived from the shader hash only
+			const uint32_t hash = murmur.end();
+
+			VkPipeline pipeline = m_pipelineStateCache.find(hash);
+
+			if (VK_NULL_HANDLE != pipeline) // any non-null lookup result is returned as-is
+			{
+				return pipeline;
+			}
+
+			VkComputePipelineCreateInfo cpci;
+			cpci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+			cpci.pNext = NULL;
+			cpci.flags = 0;
+
+			cpci.stage.sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
+			cpci.stage.pNext  = NULL;
+			cpci.stage.flags  = 0;
+			cpci.stage.stage  = VK_SHADER_STAGE_COMPUTE_BIT;
+			cpci.stage.module = program.m_vsh->m_module; // the compute shader module is taken from the program's m_vsh slot
+			cpci.stage.pName  = "main"; // shader entry point name
+			cpci.stage.pSpecializationInfo = NULL;
+
+			cpci.layout             = program.m_pipelineLayout;
+			cpci.basePipelineHandle = VK_NULL_HANDLE; // no base pipeline; flags is 0, so no derivative is requested
+			cpci.basePipelineIndex  = 0;
+
+			VK_CHECK( vkCreateComputePipelines(m_device, m_pipelineCache, 1, &cpci, m_allocatorCb, &pipeline) );
+
+			m_pipelineStateCache.add(hash, pipeline); // stored under the same key that the lookup above uses
+
+			return pipeline;
 		}
 
 		VkPipeline getPipeline(uint64_t _state, uint64_t _stencil, uint8_t _numStreams, const VertexLayout** _vertexDecls, ProgramHandle _program, uint8_t _numInstanceData)
@@ -4085,14 +4124,16 @@ VK_DESTROY
 		m_flags   = _flags;
 		m_dynamic = NULL == _data;
 
+		bool compute = m_flags   & BGFX_BUFFER_COMPUTE_READ_WRITE;
 		VkBufferCreateInfo bci;
 		bci.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
 		bci.pNext = NULL;
 		bci.flags = 0;
 		bci.size  = _size;
 		bci.usage = 0
-//			| (m_dynamic ? VK_BUFFER_USAGE_TRANSFER_DST_BIT  : 0)
-			| (_vertex   ? VK_BUFFER_USAGE_VERTEX_BUFFER_BIT : VK_BUFFER_USAGE_INDEX_BUFFER_BIT)
+//			| (m_dynamic ? VK_BUFFER_USAGE_TRANSFER_DST_BIT   : 0)
+			| (_vertex   ? VK_BUFFER_USAGE_VERTEX_BUFFER_BIT  : VK_BUFFER_USAGE_INDEX_BUFFER_BIT)
+			| (compute   ? VK_BUFFER_USAGE_STORAGE_BUFFER_BIT : 0)
 			| VK_BUFFER_USAGE_TRANSFER_DST_BIT
 			;
 		bci.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
@@ -4335,9 +4376,10 @@ VK_DESTROY
 
 		for (uint32_t ii = 0; ii < BGFX_CONFIG_MAX_TEXTURE_SAMPLERS; ++ii)
 		{
-			m_sampler[ii].uniformHandle = {kInvalidHandle};
-			m_sampler[ii].imageBinding = 0;
-			m_sampler[ii].samplerBinding = 0;
+			m_bindInfo[ii].uniformHandle = {kInvalidHandle};
+			m_bindInfo[ii].type = UNKNOWN;
+			m_bindInfo[ii].binding = 0;
+			m_bindInfo[ii].samplerBinding = 0;
 		}
 
 		if (0 < count)
@@ -4374,7 +4416,28 @@ VK_DESTROY
 					m_predefined[m_numPredefined].m_type  = uint8_t(predefined|fragmentBit);
 					m_numPredefined++;
 				}
-				else if (UniformType::Sampler != (~BGFX_UNIFORM_MASK & type) )
+				else if (UniformType::End == (~BGFX_UNIFORM_MASK & type))
+				{
+					m_bindInfo[num].uniformHandle = {0};
+					m_bindInfo[num].type = STORAGE;
+					m_bindInfo[num].binding = regCount;	// regCount is used for buffer binding index
+					m_bindInfo[num].samplerBinding = regIndex;	// regIndex is used for descriptor type
+
+					kind = "storage";
+				}
+				else if (UniformType::Sampler == (~BGFX_UNIFORM_MASK & type) )
+				{
+					const UniformRegInfo* info = s_renderVK->m_uniformReg.find(name);
+					BX_CHECK(NULL != info, "User defined uniform '%s' is not found, it won't be set.", name);
+
+					m_bindInfo[num].uniformHandle = info->m_handle;
+					m_bindInfo[num].type = SAMPLER;
+					m_bindInfo[num].binding = regIndex;	// regIndex is used for image binding index
+					m_bindInfo[num].samplerBinding = regCount;	// regCount is used for sampler binding index
+
+					kind = "sampler";
+				}
+				else
 				{
 					const UniformRegInfo* info = s_renderVK->m_uniformReg.find(name);
 					BX_CHECK(NULL != info, "User defined uniform '%s' is not found, it won't be set.", name);
@@ -4390,17 +4453,7 @@ VK_DESTROY
 						m_constantBuffer->writeUniformHandle( (UniformType::Enum)(type|fragmentBit), regIndex, info->m_handle, regCount);
 					}
 				}
-				else
-				{
-					const UniformRegInfo* info = s_renderVK->m_uniformReg.find(name);
-					BX_CHECK(NULL != info, "User defined uniform '%s' is not found, it won't be set.", name);
 
-					m_sampler[num].uniformHandle = info->m_handle;
-					m_sampler[num].imageBinding = regIndex;	// regIndex is used for image binding index
-					m_sampler[num].samplerBinding = regCount;	// regCount is used for sampler binding index
-
-					kind = "sampler";
-				}
 
 				BX_TRACE("\t%s: %s (%s), num %2d, r.index %3d, r.count %2d"
 					, kind
@@ -4489,23 +4542,37 @@ VK_DESTROY
 				bidx++;
 			}
 
-			for (uint32_t ii = 0; ii < BX_COUNTOF(m_sampler); ++ii)
+			for (uint32_t ii = 0; ii < BX_COUNTOF(m_bindInfo); ++ii)
 			{
-				if (m_sampler[ii].imageBinding > 0 && m_sampler[ii].samplerBinding > 0)
+				switch (m_bindInfo[ii].type)
 				{
-					m_bindings[bidx].stageFlags = VK_SHADER_STAGE_ALL;
-					m_bindings[bidx].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
-					m_bindings[bidx].binding = m_sampler[ii].imageBinding;
-					m_bindings[bidx].pImmutableSamplers = NULL;
-					m_bindings[bidx].descriptorCount = 1;
-					bidx++;
-
-					m_bindings[bidx].stageFlags = VK_SHADER_STAGE_ALL;
-					m_bindings[bidx].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLER;
-					m_bindings[bidx].binding = m_sampler[ii].samplerBinding;
-					m_bindings[bidx].pImmutableSamplers = NULL;
-					m_bindings[bidx].descriptorCount = 1;
-					bidx++;
+					case STORAGE:
+						m_bindings[bidx].stageFlags = VK_SHADER_STAGE_ALL;
+						m_bindings[bidx].descriptorType = (VkDescriptorType)m_bindInfo[ii].samplerBinding;
+						m_bindings[bidx].binding = m_bindInfo[ii].binding;
+						m_bindings[bidx].pImmutableSamplers = NULL;
+						m_bindings[bidx].descriptorCount = 1;
+						bidx++;
+						break;
+
+					case SAMPLER:
+						m_bindings[bidx].stageFlags = VK_SHADER_STAGE_ALL;
+						m_bindings[bidx].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
+						m_bindings[bidx].binding = m_bindInfo[ii].binding;
+						m_bindings[bidx].pImmutableSamplers = NULL;
+						m_bindings[bidx].descriptorCount = 1;
+						bidx++;
+
+						m_bindings[bidx].stageFlags = VK_SHADER_STAGE_ALL;
+						m_bindings[bidx].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLER;
+						m_bindings[bidx].binding = m_bindInfo[ii].samplerBinding;
+						m_bindings[bidx].pImmutableSamplers = NULL;
+						m_bindings[bidx].descriptorCount = 1;
+						bidx++;
+						break;
+						
+					default:
+						break;
 				}
 			}
 
@@ -4559,17 +4626,20 @@ VK_DESTROY
 			m_numPredefined += _fsh->m_numPredefined;
 		}
 
-
 		// create exact pipeline layout
 		VkDescriptorSetLayout dsl = VK_NULL_HANDLE;
 
-		if (m_vsh->m_numBindings + m_fsh->m_numBindings > 0)
+		uint32_t numBindings = m_vsh->m_numBindings + (m_fsh ? m_fsh->m_numBindings : 0);
+		if (0 < numBindings)
 		{
 			// generate descriptor set layout hash
 			bx::HashMurmur2A murmur;
 			murmur.begin();
 			murmur.add(m_vsh->m_bindings, sizeof(VkDescriptorSetLayoutBinding) * m_vsh->m_numBindings);
-			murmur.add(m_fsh->m_bindings, sizeof(VkDescriptorSetLayoutBinding) * m_fsh->m_numBindings);
+			if (NULL != m_fsh)
+			{
+				murmur.add(m_fsh->m_bindings, sizeof(VkDescriptorSetLayoutBinding) * m_fsh->m_numBindings);
+			}
 			m_descriptorSetLayoutHash = murmur.end();
 
 			dsl = s_renderVK->m_descriptorSetLayoutCache.find(m_descriptorSetLayoutHash);
@@ -4582,17 +4652,20 @@ VK_DESTROY
 					, m_vsh->m_bindings
 					, sizeof(VkDescriptorSetLayoutBinding) * m_vsh->m_numBindings
 					);
-				bx::memCopy(
-					  bindings + m_vsh->m_numBindings
-					, m_fsh->m_bindings
-					, sizeof(VkDescriptorSetLayoutBinding) * m_fsh->m_numBindings
-					);
+				if (NULL != m_fsh)
+				{
+					bx::memCopy(
+						  bindings + m_vsh->m_numBindings
+						, m_fsh->m_bindings
+						, sizeof(VkDescriptorSetLayoutBinding) * m_fsh->m_numBindings
+						);
+				}
 
 				VkDescriptorSetLayoutCreateInfo dslci;
 				dslci.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
 				dslci.pNext = NULL;
 				dslci.flags = 0;
-				dslci.bindingCount = m_vsh->m_numBindings + m_fsh->m_numBindings;
+				dslci.bindingCount = numBindings;
 				dslci.pBindings = bindings;
 
 				VK_CHECK(vkCreateDescriptorSetLayout(
@@ -4924,7 +4997,9 @@ VK_DESTROY
 						? VK_IMAGE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT
 						: VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT)
 					: 0
-					);
+					)
+				| (_flags & BGFX_TEXTURE_COMPUTE_WRITE ? VK_IMAGE_USAGE_STORAGE_BIT : 0)
+				;
 			ici.format = bimg::isDepth(bimg::TextureFormat::Enum(m_textureFormat) )
 				? s_textureFormat[m_textureFormat].m_fmtDsv
 				: s_textureFormat[m_textureFormat].m_fmt
@@ -4964,7 +5039,8 @@ VK_DESTROY
 			else
 			{
 				VkCommandBuffer commandBuffer = s_renderVK->beginNewCommand();
-				setImageMemoryBarrier(commandBuffer, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
+				setImageMemoryBarrier(commandBuffer
+					, (m_flags & BGFX_TEXTURE_COMPUTE_WRITE? VK_IMAGE_LAYOUT_GENERAL : VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
 				s_renderVK->submitCommandAndWait(commandBuffer);
 			}
 
@@ -4998,6 +5074,28 @@ VK_DESTROY
 				viewInfo.subresourceRange.layerCount     = m_numSides; //(m_type == VK_IMAGE_VIEW_TYPE_CUBE ? 6 : m_numLayers);
 				VK_CHECK(vkCreateImageView(device, &viewInfo, &s_allocationCb, &m_textureImageView));
 			}
+
+			// image view creation for storage if needed
+			if (m_flags & BGFX_TEXTURE_COMPUTE_WRITE)
+			{
+				VkImageViewCreateInfo viewInfo;
+				viewInfo.sType        = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO;
+				viewInfo.pNext        = NULL;
+				viewInfo.flags        = 0;
+				viewInfo.image        = m_textureImage;
+				viewInfo.viewType     = (m_type == VK_IMAGE_VIEW_TYPE_CUBE ? VK_IMAGE_VIEW_TYPE_2D_ARRAY : m_type);
+				viewInfo.format       = m_vkTextureFormat;
+				viewInfo.components.r = VK_COMPONENT_SWIZZLE_IDENTITY;
+				viewInfo.components.g = VK_COMPONENT_SWIZZLE_IDENTITY;
+				viewInfo.components.b = VK_COMPONENT_SWIZZLE_IDENTITY;
+				viewInfo.components.a = VK_COMPONENT_SWIZZLE_IDENTITY;
+				viewInfo.subresourceRange.aspectMask     = m_vkTextureAspect;
+				viewInfo.subresourceRange.baseMipLevel   = 0;
+				viewInfo.subresourceRange.levelCount     = m_numMips; //m_numMips;
+				viewInfo.subresourceRange.baseArrayLayer = 0;
+				viewInfo.subresourceRange.layerCount     = m_numSides; //(m_type == VK_IMAGE_VIEW_TYPE_CUBE ? 6 : m_numLayers);
+				VK_CHECK(vkCreateImageView(device, &viewInfo, &s_allocationCb, &m_textureImageStorageView));
+			}
 		}
 
 		return m_directAccessPtr;
@@ -5010,6 +5108,7 @@ VK_DESTROY
 			VkDevice device = s_renderVK->m_device;
 			vkFreeMemory(device, m_textureDeviceMem, &s_allocationCb);
 
+			vkDestroy(m_textureImageStorageView);
 			vkDestroy(m_textureImageView);
 			vkDestroy(m_textureImage);
 		}
@@ -5394,7 +5493,7 @@ VK_DESTROY
 				const RenderBind& renderBind = _render->m_renderItemBind[itemIdx];
 				++item;
 
-				if (viewChanged)
+				if (viewChanged || isCompute || wasCompute)
 				{
 					if (beginRenderPass)
 					{
@@ -5412,12 +5511,12 @@ VK_DESTROY
 //					m_batch.flush(m_commandList, true);
 					kick(renderWait);
 					renderWait = VK_NULL_HANDLE;
-finishAll();
+					finishAll();
 
 					view = key.m_view;
 					currentPipeline = VK_NULL_HANDLE;
 					currentSamplerStateIdx = kInvalidHandle;
-BX_UNUSED(currentSamplerStateIdx);
+					BX_UNUSED(currentSamplerStateIdx);
 					currentProgram         = BGFX_INVALID_HANDLE;
 					hasPredefined          = false;
 
@@ -5451,38 +5550,41 @@ BX_UNUSED(currentSamplerStateIdx);
 						vkCmdBeginDebugUtilsLabelEXT(m_commandBuffer, &dul);
 					}
 
-					vkCmdBeginRenderPass(m_commandBuffer, &rpbi, VK_SUBPASS_CONTENTS_INLINE);
-					beginRenderPass = true;
-
-					VkViewport vp;
-					vp.x        = rect.m_x;
-					vp.y        = rect.m_y;
-					vp.width    = rect.m_width;
-					vp.height   = rect.m_height;
-					vp.minDepth = 0.0f;
-					vp.maxDepth = 1.0f;
-					vkCmdSetViewport(m_commandBuffer, 0, 1, &vp);
-
-					VkRect2D rc;
-					rc.offset.x      = viewScissorRect.m_x;
-					rc.offset.y      = viewScissorRect.m_y;
-					rc.extent.width  = viewScissorRect.m_width;
-					rc.extent.height = viewScissorRect.m_height;
-					vkCmdSetScissor(m_commandBuffer, 0, 1, &rc);
-
-					restoreScissor = false;
-
-					Clear& clr = _render->m_view[view].m_clear;
-					if (BGFX_CLEAR_NONE != clr.m_flags)
+					if (!isCompute && !beginRenderPass)
 					{
-						Rect clearRect = rect;
-						clearRect.setIntersect(rect, viewScissorRect);
-						clearQuad(clearRect, clr, _render->m_colorPalette);
-					}
+						vkCmdBeginRenderPass(m_commandBuffer, &rpbi, VK_SUBPASS_CONTENTS_INLINE);
+						beginRenderPass = true;
+
+						VkViewport vp;
+						vp.x        = rect.m_x;
+						vp.y        = rect.m_y;
+						vp.width    = rect.m_width;
+						vp.height   = rect.m_height;
+						vp.minDepth = 0.0f;
+						vp.maxDepth = 1.0f;
+						vkCmdSetViewport(m_commandBuffer, 0, 1, &vp);
+
+						VkRect2D rc;
+						rc.offset.x      = viewScissorRect.m_x;
+						rc.offset.y      = viewScissorRect.m_y;
+						rc.extent.width  = viewScissorRect.m_width;
+						rc.extent.height = viewScissorRect.m_height;
+						vkCmdSetScissor(m_commandBuffer, 0, 1, &rc);
+
+						restoreScissor = false;
+
+						Clear& clr = _render->m_view[view].m_clear;
+						if (BGFX_CLEAR_NONE != clr.m_flags)
+						{
+							Rect clearRect = rect;
+							clearRect.setIntersect(rect, viewScissorRect);
+							clearQuad(clearRect, clr, _render->m_colorPalette);
+						}
 
-					prim = s_primInfo[Topology::Count]; // Force primitive type update.
+						prim = s_primInfo[Topology::Count]; // Force primitive type update.
 
-					submitBlit(bs, view);
+						submitBlit(bs, view);
+					}
 				}
 
 				if (isCompute)
@@ -5625,6 +5727,128 @@ BX_UNUSED(currentSamplerStateIdx);
 //						m_commandList->SetComputeRootConstantBufferView(Rdt::CBV, gpuAddress);
 					}
 
+					{
+						// Allocate, fill and bind the descriptor set used by this compute
+						// dispatch: one write per bound storage buffer / storage image,
+						// plus one uniform buffer write when the shader has uniform data.
+						ProgramVK& program = m_program[currentProgram.idx];
+						ScratchBufferVK& sb = m_scratchBuffer[m_backBufferColorIdx];
+
+						VkDescriptorSetLayout dsl = m_descriptorSetLayoutCache.find(program.m_descriptorSetLayoutHash);
+						VkDescriptorSetAllocateInfo dsai;
+						dsai.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
+						dsai.pNext = NULL;
+						dsai.descriptorPool = m_descriptorPool;
+						dsai.descriptorSetCount = 1;
+						dsai.pSetLayouts = &dsl;
+						// NOTE(review): the result of this allocation is not checked here.
+						vkAllocateDescriptorSets(m_device, &dsai, &sb.m_descriptorSet[sb.m_currentDs]);
+
+						// bufferInfo and wds get one extra slot: the uniform buffer write
+						// below is emitted in addition to up to BGFX_MAX_COMPUTE_BINDINGS
+						// resource writes.
+						VkDescriptorImageInfo imageInfo[BGFX_MAX_COMPUTE_BINDINGS];
+						VkDescriptorBufferInfo bufferInfo[BGFX_MAX_COMPUTE_BINDINGS + 1];
+						VkWriteDescriptorSet wds[BGFX_MAX_COMPUTE_BINDINGS + 1];
+						bx::memSet(wds, 0, sizeof(wds) );
+						uint32_t wdsCount = 0;
+						uint32_t imageCount = 0;
+						uint32_t bufferCount = 0;
+						for (uint32_t stage = 0; stage < BGFX_MAX_COMPUTE_BINDINGS; ++stage)
+						{
+							const Binding& bind = renderBind.m_bind[stage];
+							if (kInvalidHandle != bind.m_idx)
+							{
+								// For storage entries samplerBinding carries the descriptor type.
+								VkDescriptorType descriptorType = (VkDescriptorType)program.m_vsh->m_bindInfo[stage].samplerBinding;
+								if (descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER)
+								{
+									VertexBufferVK& vb = m_vertexBuffers[bind.m_idx];
+									bufferInfo[bufferCount].buffer = vb.m_buffer;
+									bufferInfo[bufferCount].offset = 0;
+									bufferInfo[bufferCount].range = vb.m_size;
+
+									wds[wdsCount].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+									wds[wdsCount].pNext = NULL;
+									wds[wdsCount].dstSet = sb.m_descriptorSet[sb.m_currentDs];
+									wds[wdsCount].dstBinding = program.m_vsh->m_bindInfo[stage].binding;
+									wds[wdsCount].dstArrayElement = 0;
+									wds[wdsCount].descriptorCount = 1;
+									wds[wdsCount].descriptorType = descriptorType;
+									wds[wdsCount].pImageInfo = NULL;
+									wds[wdsCount].pBufferInfo = &bufferInfo[bufferCount];
+									wds[wdsCount].pTexelBufferView = NULL;
+									wdsCount++;
+									bufferCount++;
+								}
+								else if (descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
+								{
+									TextureVK& texture = m_textures[bind.m_idx];
+									VkSampler sampler = getSampler(
+										(0 == (BGFX_SAMPLER_INTERNAL_DEFAULT & bind.m_samplerFlags)
+											? bind.m_samplerFlags
+											: (uint32_t)texture.m_flags
+										) & (BGFX_SAMPLER_BITS_MASK | BGFX_SAMPLER_BORDER_COLOR_MASK)
+										, (uint32_t)texture.m_numMips);
+
+									// Fill the same slot the write below points at, and advance
+									// it afterwards, so each image write references its own
+									// initialized VkDescriptorImageInfo.
+									imageInfo[imageCount].imageLayout = texture.m_currentImageLayout;
+									imageInfo[imageCount].imageView   = texture.m_textureImageStorageView ? texture.m_textureImageStorageView : texture.m_textureImageView;
+									imageInfo[imageCount].sampler     = sampler;
+
+									wds[wdsCount].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+									wds[wdsCount].pNext = NULL;
+									wds[wdsCount].dstSet = sb.m_descriptorSet[sb.m_currentDs];
+									wds[wdsCount].dstBinding = program.m_vsh->m_bindInfo[stage].binding;
+									wds[wdsCount].dstArrayElement = 0;
+									wds[wdsCount].descriptorCount = 1;
+									wds[wdsCount].descriptorType = descriptorType;
+									wds[wdsCount].pImageInfo = &imageInfo[imageCount];
+									wds[wdsCount].pBufferInfo = NULL;
+									wds[wdsCount].pTexelBufferView = NULL;
+									wdsCount++;
+									imageCount++;
+								}
+							}
+						}
+
+						// Uniform data is placed in the scratch buffer at an offset aligned
+						// to the device's minimum uniform buffer offset alignment.
+						const uint32_t align = uint32_t(m_deviceProperties.limits.minUniformBufferOffsetAlignment);
+						const uint32_t vsize = bx::strideAlign(program.m_vsh->m_size, align);
+
+						if (vsize > 0)
+						{
+							bufferInfo[bufferCount].buffer = sb.m_buffer;
+							bufferInfo[bufferCount].offset = sb.m_pos;
+							bufferInfo[bufferCount].range = vsize;
+
+							wds[wdsCount].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+							wds[wdsCount].pNext = NULL;
+							wds[wdsCount].dstSet = sb.m_descriptorSet[sb.m_currentDs];
+							wds[wdsCount].dstBinding = program.m_vsh->m_uniformBinding;
+							wds[wdsCount].dstArrayElement = 0;
+							wds[wdsCount].descriptorCount = 1;
+							wds[wdsCount].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+							wds[wdsCount].pImageInfo = NULL;
+							wds[wdsCount].pBufferInfo = &bufferInfo[bufferCount];
+							wds[wdsCount].pTexelBufferView = NULL;
+							wdsCount++;
+							bufferCount++;
+
+							bx::memCopy(&sb.m_data[sb.m_pos], m_vsScratch, program.m_vsh->m_size);
+						}
+
+						sb.m_pos += vsize;
+
+						m_vsChanges = 0;
+						m_fsChanges = 0;
+
+						vkUpdateDescriptorSets(m_device, wdsCount, wds, 0, NULL);
+						vkCmdBindDescriptorSets(
+								m_commandBuffer
+								, VK_PIPELINE_BIND_POINT_COMPUTE
+								, program.m_pipelineLayout
+								, 0
+								, 1
+								, &sb.m_descriptorSet[sb.m_currentDs]
+								, 0
+								, NULL
+						);
+
+						sb.m_currentDs++;
+					}
+
 					if (isValid(compute.m_indirectBuffer) )
 					{
 						const VertexBufferVK& vb = m_vertexBuffers[compute.m_indirectBuffer.idx];
@@ -5637,13 +5861,13 @@ BX_UNUSED(currentSamplerStateIdx);
 						uint32_t args = compute.m_startIndirect * BGFX_CONFIG_DRAW_INDIRECT_STRIDE;
 						for (uint32_t ii = 0; ii < numDrawIndirect; ++ii)
 						{
-//							m_commandList->ExecuteIndirect(ptr, args);
+							vkCmdDispatchIndirect(m_commandBuffer, vb.m_buffer, args);
 							args += BGFX_CONFIG_DRAW_INDIRECT_STRIDE;
 						}
 					}
 					else
 					{
-//						m_commandList->Dispatch(compute.m_numX, compute.m_numY, compute.m_numZ);
+						vkCmdDispatch(m_commandBuffer, compute.m_numX, compute.m_numY, compute.m_numZ);
 					}
 
 					continue;
@@ -5963,7 +6187,6 @@ BX_UNUSED(currentSamplerStateIdx);
 						dsai.pSetLayouts = &dsl;
 						vkAllocateDescriptorSets(m_device, &dsai, &sb.m_descriptorSet[sb.m_currentDs]);
 
-
 						VkDescriptorImageInfo imageInfo[BGFX_CONFIG_MAX_TEXTURE_SAMPLERS];
 						VkDescriptorBufferInfo bufferInfo[16];
 						VkWriteDescriptorSet wds[BGFX_CONFIG_MAX_TEXTURE_SAMPLERS];
@@ -5974,7 +6197,7 @@ BX_UNUSED(currentSamplerStateIdx);
 						{
 							const Binding& bind = renderBind.m_bind[stage];
 							if (kInvalidHandle != bind.m_idx &&
-								isValid(program.m_fsh->m_sampler[stage].uniformHandle))
+								isValid(program.m_fsh->m_bindInfo[stage].uniformHandle))
 							{
 								TextureVK& texture = m_textures[bind.m_idx];
 								VkSampler sampler = getSampler(
@@ -5991,7 +6214,7 @@ BX_UNUSED(currentSamplerStateIdx);
 								wds[wdsCount].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
 								wds[wdsCount].pNext = NULL;
 								wds[wdsCount].dstSet = sb.m_descriptorSet[sb.m_currentDs];
-								wds[wdsCount].dstBinding = program.m_fsh->m_sampler[stage].imageBinding;
+								wds[wdsCount].dstBinding = program.m_fsh->m_bindInfo[stage].binding;
 								wds[wdsCount].dstArrayElement = 0;
 								wds[wdsCount].descriptorCount = 1;
 								wds[wdsCount].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
@@ -6003,7 +6226,7 @@ BX_UNUSED(currentSamplerStateIdx);
 								wds[wdsCount].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
 								wds[wdsCount].pNext = NULL;
 								wds[wdsCount].dstSet = sb.m_descriptorSet[sb.m_currentDs];
-								wds[wdsCount].dstBinding = program.m_fsh->m_sampler[stage].samplerBinding;
+								wds[wdsCount].dstBinding = program.m_fsh->m_bindInfo[stage].samplerBinding;
 								wds[wdsCount].dstArrayElement = 0;
 								wds[wdsCount].descriptorCount = 1;
 								wds[wdsCount].descriptorType = VK_DESCRIPTOR_TYPE_SAMPLER;

+ 12 - 3
src/renderer_vk.h

@@ -414,13 +414,20 @@ VK_DESTROY
 		uint8_t m_numPredefined;
 		uint8_t m_numAttrs;
 
-		struct SamplerInfo
+		enum BindType // kind of resource described by a BindInfo slot
+		{
+			UNKNOWN, // value every slot is reset to before the shader's uniforms are parsed
+			STORAGE, // storage resource; assigned to uniforms whose type is UniformType::End
+			SAMPLER, // texture: contributes a sampled-image binding and a sampler binding
+		};
+		struct BindInfo
 		{
 			UniformHandle uniformHandle;
+			BindType type;
+			uint32_t binding; // SAMPLER: image binding index; STORAGE: resource binding index
 			uint32_t samplerBinding; // SAMPLER: sampler binding index; STORAGE: holds the VkDescriptorType value instead
-			uint32_t imageBinding;
 		};
-		SamplerInfo m_sampler[BGFX_CONFIG_MAX_TEXTURE_SAMPLERS];
+		BindInfo m_bindInfo[BGFX_CONFIG_MAX_TEXTURE_SAMPLERS];
 		uint32_t m_uniformBinding;
 		uint16_t m_numBindings;
 		VkDescriptorSetLayoutBinding m_bindings[32];
@@ -456,6 +463,7 @@ VK_DESTROY
 			, m_textureImage(VK_NULL_HANDLE)
 			, m_textureDeviceMem(VK_NULL_HANDLE)
 			, m_textureImageView(VK_NULL_HANDLE)
+			, m_textureImageStorageView(VK_NULL_HANDLE)
 			, m_currentImageLayout(VK_IMAGE_LAYOUT_UNDEFINED)
 		{
 		}
@@ -484,6 +492,7 @@ VK_DESTROY
 		VkImage m_textureImage;
 		VkDeviceMemory m_textureDeviceMem;
 		VkImageView m_textureImageView;
+		VkImageView m_textureImageStorageView;
 		VkImageLayout m_currentImageLayout;
 	};
 

+ 113 - 29
tools/shaderc/shaderc_spirv.cpp

@@ -53,6 +53,7 @@ namespace bgfx
 namespace stl = tinystl;
 
 #include "../../src/shader_spirv.h"
+#include "../../3rdparty/khronos/vulkan/vulkan.h"
 
 namespace bgfx { namespace spirv
 {
@@ -566,6 +567,32 @@ namespace bgfx { namespace spirv
 	};
 	BX_STATIC_ASSERT(bgfx::Attrib::Count == BX_COUNTOF(s_attribName) );
 
+	// Extracts the register number from a shader source line that contains a
+	// `register(...)` annotation, e.g. `... : register(t3);`.
+	//
+	// The span from the first to the last decimal digit found between the
+	// `register` identifier and the closing ')' is parsed as the number.
+	//
+	// _strLine: single line of shader source to scan.
+	// Returns the parsed register number, or -1 when `register` is missing or
+	// no digit follows it.
+	int32_t extractStageNumber(const std::string& _strLine)
+	{
+		const char* term = _strLine.c_str() + _strLine.size();
+
+		bx::StringView found = bx::findIdentifierMatch(_strLine.c_str(), "register");
+		BX_CHECK(!found.isEmpty(), "cannot find register number");
+		if (found.isEmpty() )
+		{
+			// Without a match there is no position inside _strLine to scan from.
+			return -1;
+		}
+
+		const char* ptr = found.getPtr() + found.getLength();
+		const char* start = NULL;
+		const char* end = NULL;
+
+		// The bounds check comes first so a character is only read while ptr is
+		// still inside the line.
+		while (ptr < term && *ptr != ')')
+		{
+			if (*ptr >= '0' && *ptr <= '9')
+			{
+				if (start == NULL)
+					start = ptr;
+				end = ptr;
+			}
+			ptr++;
+		}
+		BX_CHECK(start != NULL && end != NULL, "cannot find register number");
+		if (NULL == start)
+		{
+			// BX_CHECK may not stop execution (presumably a no-op in some build
+			// configurations), and building a StringView from NULL pointers must
+			// be avoided.
+			return -1;
+		}
+
+		bx::StringView numberString(start, int32_t(end - start + 1) );
+		int32_t regNumber = -1;
+		bx::fromString(&regNumber, numberString);
+		BX_CHECK(regNumber >= 0, "register number should be semi-positive integer");
+
+		return regNumber;
+	}
+
 	bgfx::Attrib::Enum toAttribEnum(const bx::StringView& _name)
 	{
 		for (uint8_t ii = 0; ii < Attrib::Count; ++ii)
@@ -642,6 +669,8 @@ namespace bgfx { namespace spirv
 		shader->setShiftBinding(glslang::EResUbo, bindingOffset);
 		shader->setShiftBinding(glslang::EResTexture, bindingOffset + 16);
 		shader->setShiftBinding(glslang::EResSampler, bindingOffset + 32);
+		shader->setShiftBinding(glslang::EResSsbo, bindingOffset + 16);
+		shader->setShiftBinding(glslang::EResImage, bindingOffset + 32);
 
 		const char* shaderStrings[] = { _code.c_str() };
 		shader->setStrings(
@@ -711,7 +740,7 @@ namespace bgfx { namespace spirv
 			{
 				program->buildReflection();
 
-				std::map<std::string, uint32_t> samplerStageMap;
+				std::map<std::string, uint32_t> stageMap;
 				if (_firstPass)
 				{
 					// first time through, we just find unused uniforms and get rid of them
@@ -784,34 +813,14 @@ namespace bgfx { namespace spirv
 								if (!bx::findIdentifierMatch(strLine.c_str(), "SamplerState").isEmpty() ||
 									!bx::findIdentifierMatch(strLine.c_str(), "SamplerComparisonState").isEmpty())
 								{
-									bx::StringView found = bx::findIdentifierMatch(strLine.c_str(), "register");
-									const char* ptr = found.getPtr() + found.getLength();
-									const char* start = NULL;
-									const char* end = NULL;
-									while (*ptr != ')' && ptr < strLine.c_str() + strLine.size())
-									{
-										if (*ptr >= '0' && *ptr <= '9')
-										{
-											if (start == NULL)
-												start = ptr;
-											end = ptr;
-										}
-										ptr++;
-									}
-									BX_CHECK(start != NULL && end != NULL, "SamplerState should have register number");
-
-									bx::StringView numberString(start, end - start + 1);
-									int32_t regNumber = -1;
-									bx::fromString(&regNumber, numberString);
-									BX_CHECK(regNumber >= 0, "register number should be semi-positive integer");
-
-									found = bx::findIdentifierMatch(strLine.c_str(), "SamplerState");
+									int32_t regNumber = extractStageNumber(strLine);
+									bx::StringView found = bx::findIdentifierMatch(strLine.c_str(), "SamplerState");
 									if (found.isEmpty())
 										found = bx::findIdentifierMatch(strLine.c_str(), "SamplerComparisonState");
 
-									ptr = found.getPtr() + found.getLength();
-									start = NULL;
-									end = NULL;
+									const char* ptr = found.getPtr() + found.getLength();
+									const char* start = NULL;
+									const char* end = NULL;
 									while (ptr < strLine.c_str() + strLine.size())
 									{
 										if (*ptr != ' ')
@@ -829,9 +838,42 @@ namespace bgfx { namespace spirv
 									BX_CHECK(start != NULL && end != NULL, "sampler name cannot be found");
 
 									std::string samplerName(start, end - start + 1);
-									samplerStageMap[samplerName] = regNumber;
+									stageMap[samplerName] = regNumber;
 								}
 							}
+							else if (!bx::findIdentifierMatch(strLine.c_str(), "StructuredBuffer").isEmpty() ||
+								!bx::findIdentifierMatch(strLine.c_str(), "RWStructuredBuffer").isEmpty())
+							{
+								int32_t regNumber = extractStageNumber(strLine);
+
+								const char* ptr = strLine.c_str();
+								const char* start = NULL;
+								const char* end = NULL;
+								while (ptr < strLine.c_str() + strLine.size())
+								{
+									if (*ptr == '>')
+									{
+										start = ptr + 1;
+										while (*start == ' ')
+											start++;
+									}
+									if (*ptr == ':')
+									{
+										end = ptr - 1;
+										while (*end == ' ')
+											end--;
+									}
+									if (start != NULL && end != NULL)
+									{
+										break;
+									}
+									ptr++;
+								}
+								BX_CHECK(start != NULL && end != NULL, "sampler name cannot be found");
+
+								std::string bufferName(start, end - start + 1);
+								stageMap[bufferName] = regNumber;
+							}
 						}
 					}
 				}
@@ -939,20 +981,62 @@ namespace bgfx { namespace spirv
 								sampler_name = refl.get_name(sampler_resource.id);
 								if (sampler_name.size() > 7 &&
 									!bx::strFind(sampler_name.c_str(), uniform_name.c_str()).isEmpty() &&
-									0 == bx::strCmp(sampler_name.c_str() + name.length() - 7, "Sampler"))
+									(0 == bx::strCmp(sampler_name.c_str() + name.length() - 7, "Sampler") ||
+									0 == bx::strCmp(sampler_name.c_str() + name.length() - 7, "SamplerComparison")))
 								{
 									sampler_binding_index = refl.get_decoration(sampler_resource.id, spv::Decoration::DecorationBinding);
 									break;
 								}
 							}
 
-							un.num = samplerStageMap[sampler_name];	// want to write stage index
+							un.num = stageMap[sampler_name];	// want to write stage index
 							un.regIndex = texture_binding_index;	// for sampled image binding index
 							un.regCount = sampler_binding_index;	// for sampler binding index
 
 							uniforms.push_back(un);
 						}
 					}
+
+					// Loop through the storage images, and extract the uniform names:
+					for (auto &resource : resourcesrefl.storage_images)
+					{
+						std::string name = refl.get_name(resource.id);
+						if (name.size() > 7 && 0 == bx::strCmp(name.c_str() + name.length() - 7, "Texture") )
+						{
+							auto uniform_name = name.substr(0, name.length() - 7);
+							uint32_t binding_index = refl.get_decoration(resource.id, spv::Decoration::DecorationBinding);
+							std::string sampler_name = uniform_name + "Sampler";
+
+							Uniform un;
+							un.name = uniform_name;
+							un.type = UniformType::End;
+							un.num = stageMap[sampler_name];	// want to write stage index
+							un.regIndex = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;	// for descriptor type
+							un.regCount = binding_index; // for image binding index
+
+							uniforms.push_back(un);
+						}
+					}
+
+					// Loop through the storage buffer, and extract the uniform names:
+					for (auto& resource : resourcesrefl.storage_buffers)
+					{
+						std::string name = refl.get_name(resource.id);
+						for (auto& uniform : uniforms)
+						{
+							if (!bx::strFind(uniform.name.c_str(), name.c_str()).isEmpty())
+							{
+								uint32_t binding_index = refl.get_decoration(resource.id, spv::Decoration::DecorationBinding);
+								uniform.name = name;
+								uniform.type = UniformType::End;
+								uniform.num = stageMap[name];
+								uniform.regIndex = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+								uniform.regCount = binding_index;
+								break;
+							}
+						}
+					}
+
 					uint16_t size = writeUniformArray( _writer, uniforms, _options.shaderType == 'f');
 
 					if (_version == BX_MAKEFOURCC('M', 'T', 'L', 0))