Browse Source

vulkan: implement compute shaders
A lot of stuff probably isn't implemented yet, however the
following code already works:

```lua
local shader = love.graphics.newComputeShader [[
layout (local_size_x = 1, local_size_y = 1) in;

layout(r32f) uniform highp image2D out_tex;

uniform float time;

void computemain() {
imageStore(out_tex, ivec2(love_GlobalThreadID.xy), vec4(time, 1.0, 1.0, 1.0));
}
]]

local tex = love.graphics.newTexture(64, 64, {computewrite = true,format = "r32f"})

function love.draw()
shader:send("out_tex", tex)
shader:send("time", love.timer.getTime() % 1)
love.graphics.dispatchThreadgroups(shader, 64, 64, 1)

love.graphics.print("compute shader test")
love.graphics.draw(tex, 25, 25)
end
```

niki 3 years ago
parent
commit
a0ee9f7b9b

+ 84 - 8
src/modules/graphics/vulkan/Graphics.cpp

@@ -159,7 +159,8 @@ void Graphics::submitGpuCommands(bool present) {
 	imagesInFlight[imageIndex] = inFlightFences[currentFrame];
 
 	std::vector<VkCommandBuffer> submitCommandbuffers = { 
-		dataTransferCommandBuffers.at(currentFrame), 
+		dataTransferCommandBuffers.at(currentFrame),
+		computeCommandBuffers.at(currentFrame),
 		commandBuffers.at(currentFrame), 
 		readbackCommandBuffers.at(currentFrame)};
 
@@ -642,6 +643,16 @@ graphics::StreamBuffer* Graphics::newStreamBuffer(BufferUsage type, size_t size)
 	return new StreamBuffer(this, type, size);
 }
 
+bool Graphics::dispatch(int x, int y, int z) {
+	vkCmdBindPipeline(computeCommandBuffers.at(currentFrame), VK_PIPELINE_BIND_POINT_COMPUTE, computeShader->getComputePipeline());
+
+	computeShader->cmdPushDescriptorSets(computeCommandBuffers.at(currentFrame), currentFrame, VK_PIPELINE_BIND_POINT_COMPUTE);
+
+	vkCmdDispatch(computeCommandBuffers.at(currentFrame), static_cast<uint32_t>(x), static_cast<uint32_t>(y), static_cast<uint32_t>(z));
+
+	return true;
+}
+
 Matrix4 Graphics::computeDeviceProjection(const Matrix4& projection, bool rendertotexture) const {
 	uint32 flags = DEVICE_PROJECTION_DEFAULT;
 	return calculateDeviceProjection(projection, flags);
@@ -735,7 +746,7 @@ void Graphics::beginFrame() {
 void Graphics::startRecordingGraphicsCommands(bool newFrame) {
 	VkCommandBufferBeginInfo beginInfo{};
 	beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
-	beginInfo.flags = 0;
+	beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
 	beginInfo.pInheritanceInfo = nullptr;
 
 	if (vkBeginCommandBuffer(commandBuffers.at(currentFrame), &beginInfo) != VK_SUCCESS) {
@@ -747,6 +758,9 @@ void Graphics::startRecordingGraphicsCommands(bool newFrame) {
 	if (vkBeginCommandBuffer(readbackCommandBuffers.at(currentFrame), &beginInfo) != VK_SUCCESS) {
 		throw love::Exception("failed to begin recording readback command buffer");
 	}
+	if (vkBeginCommandBuffer(computeCommandBuffers.at(currentFrame), &beginInfo) != VK_SUCCESS) {
+		throw love::Exception("failed to begin recording compute command buffer");
+	}
 
 	initDynamicState();
 
@@ -773,6 +787,9 @@ void Graphics::endRecordingGraphicsCommands(bool present) {
 	if (vkEndCommandBuffer(readbackCommandBuffers.at(currentFrame)) != VK_SUCCESS) {
 		throw love::Exception("failed to record read back command buffer");
 	}
+	if (vkEndCommandBuffer(computeCommandBuffers.at(currentFrame)) != VK_SUCCESS) {
+		throw love::Exception("failed to record compute command buffer");
+	}
 }
 
 void Graphics::updatedBatchedDrawBuffers() {
@@ -804,12 +821,54 @@ VkCommandBuffer Graphics::getReadbackCommandBuffer() {
 	return readbackCommandBuffers.at(currentFrame);
 }
 
+void Graphics::oneTimeCommand(std::function<void(VkCommandBuffer)> cmd) {
+	VkCommandBufferAllocateInfo allocInfo{};
+	allocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+	allocInfo.commandPool = commandPool;
+	allocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+	allocInfo.commandBufferCount = 1;
+
+	VkCommandBuffer commandBuffer;
+	if (vkAllocateCommandBuffers(device, &allocInfo, &commandBuffer) != VK_SUCCESS) {
+		throw love::Exception("failed to allocate one time command buffer");
+	}
+
+	VkCommandBufferBeginInfo beginInfo{};
+	beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+	beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+
+	if (vkBeginCommandBuffer(commandBuffer, &beginInfo) != VK_SUCCESS) {
+		throw love::Exception("failed to start recording one time command buffer");
+	}
+
+	cmd(commandBuffer);
+
+	if (vkEndCommandBuffer(commandBuffer) != VK_SUCCESS) {
+		throw love::Exception("failed to end recording one time command buffer");
+	}
+
+	VkSubmitInfo submitInfo{};
+	submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+	submitInfo.commandBufferCount = 1;
+	submitInfo.pCommandBuffers = &commandBuffer;
+
+	if (vkQueueSubmit(graphicsQueue, 1, &submitInfo, VK_NULL_HANDLE) != VK_SUCCESS) {
+		throw love::Exception("failed to submit to queue");
+	}
+	
+	if (vkQueueWaitIdle(graphicsQueue) != VK_SUCCESS) {
+		throw love::Exception("failed to wait for queue idle");
+	}
+
+	vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
+}
+
 void Graphics::queueCleanUp(std::function<void()> cleanUp) {
-	cleanUpFunctions.at(currentFrame).push_back(std::move(cleanUp));
+	cleanUpFunctions.at(currentFrame).push_back(cleanUp);
 }
 
-void Graphics::addReadbackCallback(const std::function<void()>& callback) {
-	readbackCallbacks.at(currentFrame).push_back(std::move(callback));
+void Graphics::addReadbackCallback(std::function<void()> callback) {
+	readbackCallbacks.at(currentFrame).push_back(callback);
 }
 
 graphics::Shader::BuiltinUniformData Graphics::getCurrentBuiltinUniformData() {
@@ -1047,7 +1106,7 @@ QueueFamilyIndices Graphics::findQueueFamilies(VkPhysicalDevice device) {
 
 	int i = 0;
 	for (const auto& queueFamily : queueFamilies) {
-		if (queueFamily.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
+		if (queueFamily.queueFlags & VK_QUEUE_GRAPHICS_BIT && queueFamily.queueFlags & VK_QUEUE_COMPUTE_BIT) {
 			indices.graphicsFamily = i;
 		}
 
@@ -1089,7 +1148,7 @@ void Graphics::createLogicalDevice() {
 	QueueFamilyIndices indices = findQueueFamilies(physicalDevice);
 
 	std::vector<VkDeviceQueueCreateInfo> queueCreateInfos;
-	std::set<uint32_t> uniqueQueueFamilies = { indices.graphicsFamily.value(), indices.presentFamily.value() };
+	std::set<uint32_t> uniqueQueueFamilies = { indices.graphicsFamily.value(), indices.presentFamily.value()};
 
 	float queuePriority = 1.0f;
 	for (uint32_t queueFamily : uniqueQueueFamilies) {
@@ -1764,7 +1823,7 @@ void Graphics::prepareDraw(const VertexAttributes& attributes, const BufferBindi
 
 	ensureGraphicsPipelineConfiguration(configuration);
 
-	configuration.shader->cmdPushDescriptorSets(commandBuffers.at(currentFrame), static_cast<uint32_t>(currentFrame));
+	configuration.shader->cmdPushDescriptorSets(commandBuffers.at(currentFrame), static_cast<uint32_t>(currentFrame), VK_PIPELINE_BIND_POINT_GRAPHICS);
 	vkCmdBindVertexBuffers(commandBuffers.at(currentFrame), 0, static_cast<uint32_t>(bufferVector.size()), bufferVector.data(), offsets.data());
 }
 
@@ -1926,6 +1985,10 @@ VkSampler Graphics::createSampler(const SamplerState& samplerState) {
 	return sampler;
 }
 
+void Graphics::setComputeShader(Shader* shader) {
+	computeShader = shader;
+}
+
 VkSampler Graphics::getCachedSampler(const SamplerState& samplerState) {
 	auto it = samplers.find(samplerState);
 	if (it != samplers.end()) {
@@ -2266,6 +2329,7 @@ void Graphics::createCommandBuffers() {
 	commandBuffers.resize(MAX_FRAMES_IN_FLIGHT);
 	dataTransferCommandBuffers.resize(MAX_FRAMES_IN_FLIGHT);
 	readbackCommandBuffers.resize(MAX_FRAMES_IN_FLIGHT);
+	computeCommandBuffers.resize(MAX_FRAMES_IN_FLIGHT);
 
 	VkCommandBufferAllocateInfo allocInfo{};
 	allocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
@@ -2296,6 +2360,16 @@ void Graphics::createCommandBuffers() {
 	if (vkAllocateCommandBuffers(device, &readbackAllocInfo, readbackCommandBuffers.data()) != VK_SUCCESS) {
 		throw love::Exception("failed to allocate readback command buffers");
 	}
+
+	VkCommandBufferAllocateInfo commandAllocInfo{};
+	commandAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+	commandAllocInfo.commandPool = commandPool;
+	commandAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+	commandAllocInfo.commandBufferCount = static_cast<uint32_t>(MAX_FRAMES_IN_FLIGHT);
+
+	if (vkAllocateCommandBuffers(device, &commandAllocInfo, computeCommandBuffers.data()) != VK_SUCCESS) {
+		throw love::Exception("failed to allocate compute command buffers");
+	}
 }
 
 void Graphics::createSyncObjects() {
@@ -2347,6 +2421,8 @@ void Graphics::cleanup() {
 
 	vkFreeCommandBuffers(device, commandPool, MAX_FRAMES_IN_FLIGHT, commandBuffers.data());
 	vkFreeCommandBuffers(device, commandPool, MAX_FRAMES_IN_FLIGHT, dataTransferCommandBuffers.data());
+	vkFreeCommandBuffers(device, commandPool, MAX_FRAMES_IN_FLIGHT, readbackCommandBuffers.data());
+	vkFreeCommandBuffers(device, commandPool, MAX_FRAMES_IN_FLIGHT, computeCommandBuffers.data());
 
 	for (auto const& p : samplers) {
 		vkDestroySampler(device, p.second, nullptr);

+ 9 - 3
src/modules/graphics/vulkan/Graphics.h

@@ -156,7 +156,7 @@ struct QueueFamilyIndices {
 	std::optional<uint32_t> graphicsFamily;
 	std::optional<uint32_t> presentFamily;
 
-	bool isComplete() {
+	bool isComplete() const {
 		return graphicsFamily.has_value() && presentFamily.has_value();
 	}
 };
@@ -226,8 +226,10 @@ public:
 	VkCommandBuffer getDataTransferCommandBuffer();
 	VkCommandBuffer getReadbackCommandBuffer();
 
+	void oneTimeCommand(std::function<void(VkCommandBuffer)> cmd);
+
 	void queueCleanUp(std::function<void()> cleanUp);
-	void addReadbackCallback(const std::function<void()> &callback);
+	void addReadbackCallback(std::function<void()> callback);
 
 	void submitGpuCommands(bool present);
 
@@ -236,11 +238,13 @@ public:
 	graphics::Texture* getDefaultTexture() const;
 	VkSampler getCachedSampler(const SamplerState&);
 
+	void setComputeShader(Shader*);
+
 protected:
 	graphics::ShaderStage* newShaderStageInternal(ShaderStageType stage, const std::string& cachekey, const std::string& source, bool gles) override;
 	graphics::Shader* newShaderInternal(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) override;
 	graphics::StreamBuffer* newStreamBuffer(BufferUsage type, size_t size) override;
-	bool dispatch(int x, int y, int z) override { return false; }
+	bool dispatch(int x, int y, int z) override;
 	void initCapabilities() override;
 	void getAPIStats(int& shaderswitches) const override;
 	void setRenderTargetsInternal(const RenderTargets& rts, int pixelw, int pixelh, bool hasSRGBtexture) override;
@@ -333,8 +337,10 @@ private:
 	std::unordered_map<SamplerState, VkSampler, SamplerStateHasher> samplers;
 	VkCommandPool commandPool = VK_NULL_HANDLE;
 	std::vector<VkCommandBuffer> dataTransferCommandBuffers;
+	std::vector<VkCommandBuffer> computeCommandBuffers;
 	std::vector<VkCommandBuffer> commandBuffers;
 	std::vector<VkCommandBuffer> readbackCommandBuffers;
+	Shader* computeShader = nullptr;
 	std::vector<VkSemaphore> imageAvailableSemaphores;
 	std::vector<VkSemaphore> renderFinishedSemaphores;
 	std::vector<VkFence> inFlightFences;

+ 155 - 49
src/modules/graphics/vulkan/Shader.cpp

@@ -152,6 +152,8 @@ Shader::Shader(StrongRef<love::graphics::ShaderStage> stages[])
 }
 
 bool Shader::loadVolatile() {
+	computePipeline = VK_NULL_HANDLE;
+
 	for (int i = 0; i < BUILTIN_MAX_ENUM; i++) {
 		builtinUniformInfo[i] = nullptr;
 	}
@@ -176,7 +178,7 @@ void Shader::unloadVolatile() {
 	}
 
 	auto gfx = Module::getInstance<Graphics>(Module::M_GRAPHICS);
-	gfx->queueCleanUp([shaderModules = std::move(shaderModules), device = device, descriptorSetLayout = descriptorSetLayout, pipelineLayout = pipelineLayout, descriptorPools = descriptorPools](){
+	gfx->queueCleanUp([shaderModules = std::move(shaderModules), device = device, descriptorSetLayout = descriptorSetLayout, pipelineLayout = pipelineLayout, descriptorPools = descriptorPools, computePipeline = computePipeline](){
 		for (const auto pool : descriptorPools) {
 			vkDestroyDescriptorPool(device, pool, nullptr);
 		}
@@ -185,6 +187,8 @@ void Shader::unloadVolatile() {
 		}
 		vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
 		vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
+		if (computePipeline != VK_NULL_HANDLE)
+			vkDestroyPipeline(device, computePipeline, nullptr);
 	});
 	for (const auto &streamBufferVector : streamBuffers) {
 		for (const auto streamBuffer : streamBufferVector) {
@@ -206,19 +210,24 @@ const VkPipelineLayout Shader::getGraphicsPipelineLayout() const {
 	return pipelineLayout;
 }
 
-static VkDescriptorImageInfo* createDescriptorImageInfo(graphics::Texture* texture) {
+VkPipeline Shader::getComputePipeline() const {
+	return computePipeline;
+}
+
+static VkDescriptorImageInfo* createDescriptorImageInfo(graphics::Texture* texture, bool sampler) {
 	auto vkTexture = (Texture*)texture;
 
 	auto imageInfo = new VkDescriptorImageInfo();
 
-	imageInfo->imageLayout = VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
+	imageInfo->imageLayout = vkTexture->getImageLayout();
 	imageInfo->imageView = (VkImageView)vkTexture->getRenderTargetHandle();
-	imageInfo->sampler = (VkSampler)vkTexture->getSamplerHandle();
+	if (sampler)
+		imageInfo->sampler = (VkSampler)vkTexture->getSamplerHandle();
 
 	return imageInfo;
 }
 
-void Shader::cmdPushDescriptorSets(VkCommandBuffer commandBuffer, uint32_t frameIndex) {
+void Shader::cmdPushDescriptorSets(VkCommandBuffer commandBuffer, uint32_t frameIndex, VkPipelineBindPoint bindPoint) {
 	// detect whether a new frame has begun
 	if (currentFrame != frameIndex) {
 		currentFrame = frameIndex;
@@ -255,35 +264,34 @@ void Shader::cmdPushDescriptorSets(VkCommandBuffer commandBuffer, uint32_t frame
 		descriptorSetsVector.at(currentFrame).push_back(allocateDescriptorSet());
 	}
 
-	// additional data is always added onto the last stream buffer in the current frame
-	auto currentStreamBuffer = streamBuffers.at(currentFrame).back();
-
-	auto mapInfo = currentStreamBuffer->map(uniformBufferSizeAligned);
-	memcpy(mapInfo.data, localUniformStagingData.data(), localUniformStagingData.size());
-	currentStreamBuffer->unmap(uniformBufferSizeAligned);
-	currentStreamBuffer->markUsed(uniformBufferSizeAligned);
-
-	VkDescriptorBufferInfo bufferInfo{};
-	bufferInfo.buffer = (VkBuffer)currentStreamBuffer->getHandle();
-	bufferInfo.offset = currentUsedUniformStreamBuffersCount * uniformBufferSizeAligned;
-	bufferInfo.range = localUniformStagingData.size();
-	
 	VkDescriptorSet currentDescriptorSet = descriptorSetsVector.at(currentFrame).at(currentUsedDescriptorSetsCount);
-
 	std::vector<VkWriteDescriptorSet> descriptorWrite{};
 
-	// uniform buffer update always happens
-	// (are there cases without ubos at all?)
-	VkWriteDescriptorSet uniformWrite{};
-	uniformWrite.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
-	uniformWrite.dstSet = currentDescriptorSet;
-	uniformWrite.dstBinding = builtinUniformInfo[BUILTIN_UNIFORMS_PER_DRAW]->location;
-	uniformWrite.dstArrayElement = 0;
-	uniformWrite.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
-	uniformWrite.descriptorCount = 1;
-	uniformWrite.pBufferInfo = &bufferInfo;			
-
-	descriptorWrite.push_back(uniformWrite);
+	if (!localUniformStagingData.empty()) {
+		// additional data is always added onto the last stream buffer in the current frame
+		auto currentStreamBuffer = streamBuffers.at(currentFrame).back();
+
+		auto mapInfo = currentStreamBuffer->map(uniformBufferSizeAligned);
+		memcpy(mapInfo.data, localUniformStagingData.data(), localUniformStagingData.size());
+		currentStreamBuffer->unmap(uniformBufferSizeAligned);
+		currentStreamBuffer->markUsed(uniformBufferSizeAligned);
+
+		VkDescriptorBufferInfo bufferInfo{};
+		bufferInfo.buffer = (VkBuffer)currentStreamBuffer->getHandle();
+		bufferInfo.offset = currentUsedUniformStreamBuffersCount * uniformBufferSizeAligned;
+		bufferInfo.range = localUniformStagingData.size();
+
+		VkWriteDescriptorSet uniformWrite{};
+		uniformWrite.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+		uniformWrite.dstSet = currentDescriptorSet;
+		uniformWrite.dstBinding = uniformLocation;
+		uniformWrite.dstArrayElement = 0;
+		uniformWrite.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+		uniformWrite.descriptorCount = 1;
+		uniformWrite.pBufferInfo = &bufferInfo;			
+
+		descriptorWrite.push_back(uniformWrite);
+	}
 
 	std::vector<VkDescriptorImageInfo*> imageInfos;
 	
@@ -299,13 +307,28 @@ void Shader::cmdPushDescriptorSets(VkCommandBuffer commandBuffer, uint32_t frame
 			write.descriptorType = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
 			write.descriptorCount = 1;
 
-			VkDescriptorImageInfo* imageInfo = createDescriptorImageInfo(val.textures[0]);	// fixme: arrays
+			VkDescriptorImageInfo* imageInfo = createDescriptorImageInfo(val.textures[0], true);	// fixme: arrays
 			imageInfos.push_back(imageInfo);
 
 			write.pImageInfo = imageInfo;
 
 			descriptorWrite.push_back(write);
 		}
+		if (val.baseType == UNIFORM_STORAGETEXTURE) {
+			VkWriteDescriptorSet write{};
+			write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+			write.dstSet = currentDescriptorSet;
+			write.dstBinding = val.location;
+			write.dstArrayElement = 0;
+			write.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
+			write.descriptorCount = 1;
+
+			VkDescriptorImageInfo* imageInfo = createDescriptorImageInfo(val.textures[0], false);	// fixme: arrays
+			imageInfos.push_back(imageInfo);
+
+			write.pImageInfo = imageInfo;
+			descriptorWrite.push_back(write);
+		}
 	}
 
 	vkUpdateDescriptorSets(device, static_cast<uint32_t>(descriptorWrite.size()), descriptorWrite.data(), 0, nullptr);
@@ -314,7 +337,7 @@ void Shader::cmdPushDescriptorSets(VkCommandBuffer commandBuffer, uint32_t frame
 		delete imageInfo;
 	}
 
-	vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout, 0, 1, &currentDescriptorSet, 0, nullptr);
+	vkCmdBindDescriptorSets(commandBuffer, bindPoint, pipelineLayout, 0, 1, &currentDescriptorSet, 0, nullptr);
 
 	currentUsedUniformStreamBuffersCount++;
 	currentUsedDescriptorSetsCount++;
@@ -325,11 +348,15 @@ Shader::~Shader() {
 }
 
 void Shader::attach() {
-	if (Shader::current != this) {
-		Graphics::flushBatchedDrawsGlobal();
-		Shader::current = this;
-		Vulkan::shaderSwitch();
+	if (!isCompute) {
+		if (Shader::current != this) {
+			Graphics::flushBatchedDrawsGlobal();
+			Shader::current = this;
+			Vulkan::shaderSwitch();
+		}
 	}
+	else
+		((Graphics*)gfx)->setComputeShader(this);
 }
 
 int Shader::getVertexAttributeIndex(const std::string& name) {
@@ -349,7 +376,8 @@ void Shader::sendTextures(const UniformInfo* info, graphics::Texture** textures,
 		auto oldTexture = info->textures[i];
 		info->textures[i] = textures[i];
 		info->textures[i]->retain();
-		oldTexture->release();
+		if (oldTexture)
+			oldTexture->release();
 	}
 }
 
@@ -458,6 +486,10 @@ void Shader::compileShaders() {
 			continue;
 
 		auto stage = (ShaderStageType)i;
+
+		if (stage == SHADERSTAGE_COMPUTE)
+			isCompute = true;
+
 		auto glslangShaderStage = getGlslShaderType(stage);
 		auto tshader = new TShader(glslangShaderStage);
 
@@ -475,12 +507,12 @@ void Shader::compileShaders() {
 		const int sourceLength = static_cast<int>(glsl.length());
 		tshader->setStringsWithLengths(&csrc, &sourceLength, 1);
 
-		int defaultVersio = 450;
+		int defaultVersion = 450;
 		EProfile defaultProfile = ECoreProfile;
 		bool forceDefault = false;
 		bool forwardCompat = true;
 
-		if (!tshader->parse(&defaultTBuiltInResource, defaultVersio, defaultProfile, forceDefault, forwardCompat, EShMsgSuppressWarnings)) {
+		if (!tshader->parse(&defaultTBuiltInResource, defaultVersion, defaultProfile, forceDefault, forwardCompat, EShMsgSuppressWarnings)) {
 			const char* msg1 = tshader->getInfoLog();
 			const char* msg2 = tshader->getInfoDebugLog();
 
@@ -633,6 +665,57 @@ void Shader::compileShaders() {
 				builtinUniformInfo[builtin] = &uniformInfos[info.name];
 			}
 		}
+
+		for (const auto& r : shaderResources.storage_buffers) {
+			const auto& type = comp.get_type(r.type_id);
+
+			UniformInfo u{};
+			u.baseType = UNIFORM_STORAGEBUFFER;
+			u.components = 1;
+			u.name = r.name;
+			u.count = type.array.empty() ? 1 : type.array[0];
+			u.location = comp.get_decoration(r.id, spv::DecorationBinding);
+			
+			const auto reflectionit = validationReflection.storageBuffers.find(u.name);
+			if (reflectionit != validationReflection.storageBuffers.end()) {
+				u.bufferStride = reflectionit->second.stride;
+				u.bufferMemberCount = reflectionit->second.memberCount;
+				u.access = reflectionit->second.access;
+			}
+			else {
+				continue;
+			}
+
+			// todo: some stuff missing
+
+			u.buffers = new love::graphics::Buffer * [u.count];
+
+			for (int i = 0; i < u.count; i++) {
+				u.buffers[i] = nullptr;
+			}
+
+			uniformInfos[u.name] = u;
+		}
+
+		for (const auto& r : shaderResources.storage_images) {
+			const auto& type = comp.get_type(r.type_id);
+
+			UniformInfo u{};
+			u.baseType = UNIFORM_STORAGETEXTURE;
+			u.components = 1;
+			u.name = r.name;
+			u.count = type.array.empty() ? 1 : type.array[0];
+			u.textures = new love::graphics::Texture * [u.count];
+			u.location = comp.get_decoration(r.id, spv::DecorationBinding);
+
+			for (int i = 0; i < u.count; i++) {
+				u.textures[i] = nullptr;
+			}
+
+			// some stuff missing ?
+
+			uniformInfos[u.name] = u;
+		}
 	}
 
 	delete program;
@@ -650,20 +733,30 @@ void Shader::createDescriptorSetLayout() {
 			VkDescriptorSetLayoutBinding layoutBinding{};
 
 			layoutBinding.binding = val.location;
-			layoutBinding.descriptorType = val.baseType == UNIFORM_SAMPLER ? VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER : VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+			layoutBinding.descriptorType = type;
 			layoutBinding.descriptorCount = val.count;
-			layoutBinding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
+			if (isCompute) {
+				layoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+			}
+			else {
+				layoutBinding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
+			}
 
 			bindings.push_back(layoutBinding);
 		}
 	}
-	
-	VkDescriptorSetLayoutBinding uniformBinding{};
-	uniformBinding.binding = uniformLocation;
-	uniformBinding.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
-	uniformBinding.descriptorCount = 1;
-	uniformBinding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
-	bindings.push_back(uniformBinding);
+
+	if (!localUniformStagingData.empty()) {
+		VkDescriptorSetLayoutBinding uniformBinding{};
+		uniformBinding.binding = uniformLocation;
+		uniformBinding.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+		uniformBinding.descriptorCount = 1;
+		if (isCompute)
+			uniformBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+		else
+			uniformBinding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
+		bindings.push_back(uniformBinding);
+	}
 
 	VkDescriptorSetLayoutCreateInfo layoutInfo{};
 	layoutInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
@@ -685,6 +778,19 @@ void Shader::createPipelineLayout() {
 	if (vkCreatePipelineLayout(device, &pipelineLayoutInfo, nullptr, &pipelineLayout) != VK_SUCCESS) {
 		throw love::Exception("failed to create pipeline layout");
 	}
+
+	if (isCompute) {
+		assert(shaderStages.size() == 1);
+
+		VkComputePipelineCreateInfo computeInfo{};
+		computeInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+		computeInfo.stage = shaderStages.at(0);
+		computeInfo.layout = pipelineLayout;
+
+		if (vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &computeInfo, nullptr, &computePipeline) != VK_SUCCESS) {
+			throw love::Exception("failed to create compute pipeline");
+		}
+	}
 }
 
 void Shader::createStreamBuffers() {

+ 7 - 1
src/modules/graphics/vulkan/Shader.h

@@ -25,11 +25,13 @@ public:
 	bool loadVolatile() override;
 	void unloadVolatile() override;
 
+	VkPipeline getComputePipeline() const;
+
 	const std::vector<VkPipelineShaderStageCreateInfo>& getShaderStages() const;
 
 	const VkPipelineLayout getGraphicsPipelineLayout() const;
 
-	void cmdPushDescriptorSets(VkCommandBuffer, uint32_t currentFrame);
+	void cmdPushDescriptorSets(VkCommandBuffer, uint32_t currentFrame, VkPipelineBindPoint);
 
 	void attach() override;
 
@@ -74,6 +76,8 @@ private:
 
 	VkDeviceSize uniformBufferSizeAligned;
 
+	VkPipeline computePipeline;
+
 	VkDescriptorSetLayout descriptorSetLayout;
 	VkPipelineLayout pipelineLayout;
 
@@ -90,6 +94,8 @@ private:
 	Graphics* gfx;
 	VkDevice device;
 
+	bool isCompute = false;
+
 	std::unordered_map<std::string, graphics::Shader::UniformInfo> uniformInfos;
 	UniformInfo* builtinUniformInfo[BUILTIN_MAX_ENUM];
 

+ 112 - 94
src/modules/graphics/vulkan/Texture.cpp

@@ -30,7 +30,8 @@ bool Texture::loadVolatile() {
 		VK_IMAGE_USAGE_TRANSFER_SRC_BIT |
 		VK_IMAGE_USAGE_TRANSFER_DST_BIT | 
 		VK_IMAGE_USAGE_SAMPLED_BIT | 
-		VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT;
+		VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT |
+		VK_IMAGE_USAGE_STORAGE_BIT;
 
 	VkImageCreateFlags createFlags = 0;
 
@@ -70,11 +71,17 @@ bool Texture::loadVolatile() {
 
 	auto commandBuffer = vgfx->getDataTransferCommandBuffer();
 
-	// fixme: we probably should select a different default layout when the texture is not readable, instead of VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
-		VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, 
-		0, static_cast<uint32_t>(getMipmapCount()), 
-		0, static_cast<uint32_t>(layerCount));
+	if (computeWrite)
+		imageLayout = VK_IMAGE_LAYOUT_GENERAL;
+	else
+		imageLayout = VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
+
+	vgfx->oneTimeCommand([=](VkCommandBuffer commandBuffer) {
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage,
+			VK_IMAGE_LAYOUT_UNDEFINED, imageLayout,
+			0, VK_REMAINING_MIP_LEVELS,
+			0, VK_REMAINING_ARRAY_LAYERS);
+	});
 
 	bool hasdata = slices.get(0, 0) != nullptr;
 
@@ -134,6 +141,10 @@ void Texture::setSamplerState(const SamplerState &s) {
 	textureSampler = vgfx->getCachedSampler(s);
 }
 
+VkImageLayout Texture::getImageLayout() const {
+	return imageLayout;
+}
+
 void Texture::createTextureImageView() {
 	auto vulkanFormat = Vulkan::getTextureFormat(format);
 
@@ -160,7 +171,7 @@ void Texture::createTextureImageView() {
 void Texture::clear() {
 	auto commandBuffer = vgfx->getDataTransferCommandBuffer();
 
-	auto clearColor = getClearValue(false);
+	auto clearColor = getClearValue();
 
 	VkImageSubresourceRange range{};
 	range.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
@@ -169,64 +180,45 @@ void Texture::clear() {
 	range.baseArrayLayer = 0;
 	range.layerCount = VK_REMAINING_ARRAY_LAYERS;
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
-		VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 
-		0, range.levelCount, 0, range.layerCount);
+	if (imageLayout != VK_IMAGE_LAYOUT_GENERAL) {
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
+			VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 
+			0, range.levelCount, 0, range.layerCount);
 
-	vkCmdClearColorImage(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearColor, 1, &range);
+		vkCmdClearColorImage(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearColor, 1, &range);
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
-		VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, 
-		0, range.levelCount, 0, range.layerCount);
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
+			VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, 
+			0, range.levelCount, 0, range.layerCount);
+	}
+	else {
+		vkCmdClearColorImage(commandBuffer, textureImage, VK_IMAGE_LAYOUT_GENERAL, &clearColor, 1, &range);
+	}
 }
 
-VkClearColorValue Texture::getClearValue(bool white) {
+VkClearColorValue Texture::getClearValue() {
 	auto vulkanFormat = Vulkan::getTextureFormat(format);
 
 	VkClearColorValue clearColor{};
-	if (white) {
-		switch (vulkanFormat.internalFormatRepresentation) {
-		case FORMATREPRESENTATION_FLOAT:
-			clearColor.float32[0] = 1.0f;
-			clearColor.float32[1] = 1.0f;
-			clearColor.float32[2] = 1.0f;
-			clearColor.float32[3] = 1.0f;
-			break;
-		case FORMATREPRESENTATION_SINT:
-			clearColor.int32[0] = std::numeric_limits<int32_t>::max();
-			clearColor.int32[1] = std::numeric_limits<int32_t>::max();
-			clearColor.int32[2] = std::numeric_limits<int32_t>::max();
-			clearColor.int32[3] = std::numeric_limits<int32_t>::max();
-			break;
-		case FORMATREPRESENTATION_UINT:
-			clearColor.uint32[0] = std::numeric_limits<uint32_t>::max();
-			clearColor.uint32[1] = std::numeric_limits<uint32_t>::max();
-			clearColor.uint32[2] = std::numeric_limits<uint32_t>::max();
-			clearColor.uint32[3] = std::numeric_limits<uint32_t>::max();
-			break;
-		}
-	}
-	else {
-		switch (vulkanFormat.internalFormatRepresentation) {
-		case FORMATREPRESENTATION_FLOAT:
-			clearColor.float32[0] = 0.0f;
-			clearColor.float32[1] = 0.0f;
-			clearColor.float32[2] = 0.0f;
-			clearColor.float32[3] = 0.0f;
-			break;
-		case FORMATREPRESENTATION_SINT:
-			clearColor.int32[0] = 0;
-			clearColor.int32[1] = 0;
-			clearColor.int32[2] = 0;
-			clearColor.int32[3] = 0;
-			break;
-		case FORMATREPRESENTATION_UINT:
-			clearColor.uint32[0] = 0;
-			clearColor.uint32[1] = 0;
-			clearColor.uint32[2] = 0;
-			clearColor.uint32[3] = 0;
-			break;
-		}
+	switch (vulkanFormat.internalFormatRepresentation) {
+	case FORMATREPRESENTATION_FLOAT:
+		clearColor.float32[0] = 0.0f;
+		clearColor.float32[1] = 0.0f;
+		clearColor.float32[2] = 0.0f;
+		clearColor.float32[3] = 0.0f;
+		break;
+	case FORMATREPRESENTATION_SINT:
+		clearColor.int32[0] = 0;
+		clearColor.int32[1] = 0;
+		clearColor.int32[2] = 0;
+		clearColor.int32[3] = 0;
+		break;
+	case FORMATREPRESENTATION_UINT:
+		clearColor.uint32[0] = 0;
+		clearColor.uint32[1] = 0;
+		clearColor.uint32[2] = 0;
+		clearColor.uint32[3] = 0;
+		break;
 	}
 	return clearColor;
 }
@@ -234,9 +226,10 @@ VkClearColorValue Texture::getClearValue(bool white) {
 void Texture::generateMipmapsInternal() {
 	auto commandBuffer = vgfx->getDataTransferCommandBuffer();
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
-		VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 
-		0, static_cast<uint32_t>(getMipmapCount()), 0, static_cast<uint32_t>(layerCount));
+	if (imageLayout != VK_IMAGE_LAYOUT_GENERAL)
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
+			VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 
+			0, static_cast<uint32_t>(getMipmapCount()), 0, static_cast<uint32_t>(layerCount));
 
 	VkImageMemoryBarrier barrier{};
 	barrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
@@ -258,10 +251,11 @@ void Texture::generateMipmapsInternal() {
 		barrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
 		barrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT;
 
-		vkCmdPipelineBarrier(commandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0,
-			0, nullptr,
-			0, nullptr,
-			1, &barrier);
+		if (imageLayout != VK_IMAGE_LAYOUT_GENERAL)
+			vkCmdPipelineBarrier(commandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0,
+				0, nullptr,
+				0, nullptr,
+				1, &barrier);
 
 		VkImageBlit blit{};
 		blit.srcOffsets[0] = { 0, 0, 0 };
@@ -289,10 +283,11 @@ void Texture::generateMipmapsInternal() {
 		barrier.srcAccessMask = VK_ACCESS_TRANSFER_READ_BIT;
 		barrier.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
 
-		vkCmdPipelineBarrier(commandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0,
-			0, nullptr,
-			0, nullptr,
-			1, &barrier);
+		if (imageLayout != VK_IMAGE_LAYOUT_GENERAL)
+			vkCmdPipelineBarrier(commandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0,
+				0, nullptr,
+				0, nullptr,
+				1, &barrier);
 	}
 
 	barrier.subresourceRange.baseMipLevel = mipLevels - 1;
@@ -301,11 +296,12 @@ void Texture::generateMipmapsInternal() {
 	barrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
 	barrier.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
 
-	vkCmdPipelineBarrier(commandBuffer,
-		VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0,
-		0, nullptr,
-		0, nullptr,
-		1, &barrier);
+	if (imageLayout != VK_IMAGE_LAYOUT_GENERAL)
+		vkCmdPipelineBarrier(commandBuffer,
+			VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0,
+			0, nullptr,
+			0, nullptr,
+			1, &barrier);
 }
 
 void Texture::uploadByteData(PixelFormat pixelformat, const void* data, size_t size, int level, int slice, const Rect& r) {
@@ -344,22 +340,36 @@ void Texture::uploadByteData(PixelFormat pixelformat, const void* data, size_t s
 
 	auto commandBuffer = vgfx->getDataTransferCommandBuffer();
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
-		VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 
-		level, 1, slice, 1);
 
-	vkCmdCopyBufferToImage(
-		commandBuffer,
-		stagingBuffer,
-		textureImage,
-		VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
-		1,
-		&region
-	);
+	if (imageLayout != VK_IMAGE_LAYOUT_GENERAL) {
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
+			VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 
+			level, 1, slice, 1);
+
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, 
-		VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, 
-		level, 1, slice, 1);
+		vkCmdCopyBufferToImage(
+			commandBuffer,
+			stagingBuffer,
+			textureImage,
+			VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
+			1,
+			&region
+		);
+
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage,
+			VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
+			level, 1, slice, 1);
+	}
+	else {
+		vkCmdCopyBufferToImage(
+			commandBuffer,
+			stagingBuffer,
+			textureImage,
+			imageLayout,
+			1,
+			&region
+		);
+	}
 
 	vgfx->queueCleanUp([allocator = allocator, stagingBuffer, vmaAllocation]() {
 		vmaDestroyBuffer(allocator, stagingBuffer, vmaAllocation);
@@ -383,11 +393,15 @@ void Texture::copyFromBuffer(graphics::Buffer* source, size_t sourceoffset, int
 	region.imageExtent.width = static_cast<uint32_t>(rect.w);
 	region.imageExtent.height = static_cast<uint32_t>(rect.h);
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL);
+	if (imageLayout != VK_IMAGE_LAYOUT_GENERAL) {
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL);
 
-	vkCmdCopyBufferToImage(commandBuffer, (VkBuffer)source->getHandle(), textureImage, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, &region);
+		vkCmdCopyBufferToImage(commandBuffer, (VkBuffer)source->getHandle(), textureImage, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, &region);
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
+	}
+	else
+		vkCmdCopyBufferToImage(commandBuffer, (VkBuffer)source->getHandle(), textureImage, VK_IMAGE_LAYOUT_GENERAL, 1, &region);
 }
 
 void Texture::copyToBuffer(graphics::Buffer* dest, int slice, int mipmap, const Rect& rect, size_t destoffset, int destwidth, size_t size) {
@@ -408,11 +422,15 @@ void Texture::copyToBuffer(graphics::Buffer* dest, int slice, int mipmap, const
 	region.imageExtent.height = static_cast<uint32_t>(rect.h);
 	region.imageExtent.depth = 1;
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL);
+	if (imageLayout != VK_IMAGE_LAYOUT_GENERAL) {
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL);
 
-	vkCmdCopyImageToBuffer(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, (VkBuffer) dest->getHandle(), 1, &region);
+		vkCmdCopyImageToBuffer(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, (VkBuffer) dest->getHandle(), 1, &region);
 
-	Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
+		Vulkan::cmdTransitionImageLayout(commandBuffer, textureImage, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
+	}
+	else
+		vkCmdCopyImageToBuffer(commandBuffer, textureImage, VK_IMAGE_LAYOUT_GENERAL, (VkBuffer)dest->getHandle(), 1, &region);
 }
 
 } // vulkan

+ 4 - 1
src/modules/graphics/vulkan/Texture.h

@@ -22,6 +22,8 @@ public:
 	 
 	void setSamplerState(const SamplerState &s) override;
 
+	VkImageLayout getImageLayout() const;
+
 	void copyFromBuffer(graphics::Buffer* source, size_t sourceoffset, int sourcewidth, size_t size, int slice, int mipmap, const Rect& rect) override;
 	void copyToBuffer(graphics::Buffer* dest, int slice, int mipmap, const Rect& rect, size_t destoffset, int destwidth, size_t size) override;
 
@@ -39,12 +41,13 @@ private:
 	void createTextureImageView();
 	void clear();
 
-	VkClearColorValue getClearValue(bool white);
+	VkClearColorValue getClearValue();
 
 	graphics::Graphics* gfx = nullptr;
 	VkDevice device = VK_NULL_HANDLE;
 	VmaAllocator allocator = VK_NULL_HANDLE;
 	VkImage textureImage = VK_NULL_HANDLE;
+	VkImageLayout imageLayout = VK_IMAGE_LAYOUT_UNDEFINED;
 	VmaAllocation textureImageAllocation = VK_NULL_HANDLE;
 	VkImageView textureImageView = VK_NULL_HANDLE;
 	VkSampler textureSampler = VK_NULL_HANDLE;

+ 9 - 1
src/modules/graphics/vulkan/Vulkan.cpp

@@ -720,7 +720,7 @@ void Vulkan::cmdTransitionImageLayout(VkCommandBuffer commandBuffer, VkImage ima
 		barrier.srcAccessMask = 0;
 		barrier.dstAccessMask = VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT;
 
-		sourceStage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT;
+		sourceStage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
 		destinationStage = VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
 	}
 	else if (oldLayout == VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL && newLayout == VK_IMAGE_LAYOUT_PRESENT_SRC_KHR) {
@@ -730,6 +730,14 @@ void Vulkan::cmdTransitionImageLayout(VkCommandBuffer commandBuffer, VkImage ima
 		sourceStage = VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
 		destinationStage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT;
 	}
+	// we use general for images that are both sampled and compute write
+	else if (oldLayout == VK_IMAGE_LAYOUT_UNDEFINED && newLayout == VK_IMAGE_LAYOUT_GENERAL) {
+		barrier.srcAccessMask = 0;
+		barrier.dstAccessMask = VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT;
+
+		sourceStage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
+		destinationStage = VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT | VK_PIPELINE_STAGE_TRANSFER_BIT;
+	}
 	else {
 		throw std::invalid_argument("unsupported layout transition!");
 	}