Browse Source

metal: initial love.graphics.dispatchThreadgroups implementation

Alex Szpakowski 3 years ago
parent
commit
c6e4cfdc64

+ 1 - 0
src/modules/graphics/metal/Graphics.h

@@ -203,6 +203,7 @@ private:
 
 	id<MTLDepthStencilState> getCachedDepthStencilState(const DepthState &depth, const StencilState &stencil);
 	void applyRenderState(id<MTLRenderCommandEncoder> renderEncoder, const VertexAttributes &attributes);
+	void applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, Shader *shader);
 	void applyShaderUniforms(id<MTLRenderCommandEncoder> renderEncoder, Shader *shader, Texture *maintex);
 
 	id<MTLCommandQueue> commandQueue;

+ 118 - 2
src/modules/graphics/metal/Graphics.mm

@@ -168,6 +168,23 @@ static inline void setBuffer(id<MTLRenderCommandEncoder> encoder, Graphics::Rend
 	}
 }
 
+static inline void setBuffer(id<MTLComputeCommandEncoder> encoder, Graphics::RenderEncoderBindings &bindings, int index, id<MTLBuffer> buffer, size_t offset)
+{
+	void *b = (__bridge void *)buffer;
+	auto &binding = bindings.buffers[index][SHADERSTAGE_COMPUTE];
+	if (binding.buffer != b)
+	{
+		binding.buffer = b;
+		binding.offset = offset;
+		[encoder setBuffer:buffer offset:offset atIndex:index];
+	}
+	else if (binding.offset != offset)
+	{
+		binding.offset = offset;
+		[encoder setBufferOffset:offset atIndex:index];
+	}
+}
+
 static inline void setTexture(id<MTLRenderCommandEncoder> encoder, Graphics::RenderEncoderBindings &bindings, ShaderStageType stage, int index, id<MTLTexture> texture)
 {
 	void *t = (__bridge void *)texture;
@@ -182,6 +199,17 @@ static inline void setTexture(id<MTLRenderCommandEncoder> encoder, Graphics::Ren
 	}
 }
 
+static inline void setTexture(id<MTLComputeCommandEncoder> encoder, Graphics::RenderEncoderBindings &bindings, int index, id<MTLTexture> texture)
+{
+	void *t = (__bridge void *)texture;
+	auto &binding = bindings.textures[index][SHADERSTAGE_COMPUTE];
+	if (binding != t)
+	{
+		binding = t;
+		[encoder setTexture:texture atIndex:index];
+	}
+}
+
 static inline void setSampler(id<MTLRenderCommandEncoder> encoder, Graphics::RenderEncoderBindings &bindings, ShaderStageType stage, int index, id<MTLSamplerState> sampler)
 {
 	void *s = (__bridge void *)sampler;
@@ -196,6 +224,17 @@ static inline void setSampler(id<MTLRenderCommandEncoder> encoder, Graphics::Ren
 	}
 }
 
+static inline void setSampler(id<MTLComputeCommandEncoder> encoder, Graphics::RenderEncoderBindings &bindings, int index, id<MTLSamplerState> sampler)
+{
+	void *s = (__bridge void *)sampler;
+	auto &binding = bindings.samplers[index][SHADERSTAGE_COMPUTE];
+	if (binding != s)
+	{
+		binding = s;
+		[encoder setSamplerState:sampler atIndex:index];
+	}
+}
+
 love::graphics::Graphics *createInstance()
 {
 	love::graphics::Graphics *instance = nullptr;
@@ -662,6 +701,7 @@ id<MTLComputeCommandEncoder> Graphics::useComputeEncoder()
 		submitRenderEncoder(SUBMIT_STORE);
 		submitBlitEncoder();
 		computeEncoder = [useCommandBuffer() computeCommandEncoder];
+		renderBindings = {};
 	}
 
 	return computeEncoder;
@@ -895,6 +935,63 @@ void Graphics::applyRenderState(id<MTLRenderCommandEncoder> encoder, const Verte
 	dirtyRenderState = 0;
 }
 
+void Graphics::applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::graphics::Shader *shader)
+{
+	Shader *s = (Shader *)shader;
+
+#ifdef LOVE_MACOS
+	size_t alignment = 256;
+#else
+	size_t alignment = 16;
+#endif
+
+	size_t size = s->getLocalUniformBufferSize();
+	uint8 *bufferdata = s->getLocalUniformBufferData();
+
+	if (uniformBuffer->getSize() < uniformBufferOffset + size)
+	{
+		size_t newsize = uniformBuffer->getSize() * 2;
+		uniformBuffer->release();
+		uniformBuffer = CreateStreamBuffer(device, BUFFERUSAGE_VERTEX, newsize);
+		uniformBufferData = {};
+		uniformBufferOffset = 0;
+	}
+
+	if (uniformBufferData.data == nullptr)
+		uniformBufferData = uniformBuffer->map(uniformBuffer->getSize());
+
+	memcpy(uniformBufferData.data + uniformBufferOffset, bufferdata, size);
+
+	id<MTLBuffer> buffer = getMTLBuffer(uniformBuffer);
+	int uniformindex = Shader::getUniformBufferBinding();
+
+	auto &bindings = renderBindings;
+	setBuffer(encoder, bindings, uniformindex, buffer, uniformBufferOffset);
+
+	uniformBufferOffset += alignUp(size, alignment);
+
+	for (const Shader::TextureBinding &b : s->getTextureBindings())
+	{
+		id<MTLTexture> texture = b.texture;
+		id<MTLSamplerState> sampler = b.sampler;
+
+		uint8 texindex = b.textureStages[SHADERSTAGE_COMPUTE];
+		uint8 sampindex = b.samplerStages[SHADERSTAGE_COMPUTE];
+
+		if (texindex != LOVE_UINT8_MAX)
+			setTexture(encoder, bindings, texindex, texture);
+		if (sampindex != LOVE_UINT8_MAX)
+			setSampler(encoder, bindings, sampindex, sampler);
+	}
+
+	for (const Shader::BufferBinding &b : s->getBufferBindings())
+	{
+		uint8 index = b.stages[SHADERSTAGE_COMPUTE];
+		if (index != LOVE_UINT8_MAX)
+			setBuffer(encoder, bindings, index, b.buffer, 0);
+	}
+}
+
 void Graphics::applyShaderUniforms(id<MTLRenderCommandEncoder> renderEncoder, love::graphics::Shader *shader, love::graphics::Texture *maintex)
 {
 	Shader *s = (Shader *)shader;
@@ -1178,8 +1275,27 @@ void Graphics::drawQuads(int start, int count, const VertexAttributes &attribute
 
 bool Graphics::dispatch(int x, int y, int z)
 { @autoreleasepool {
-	// TODO
-	return false;
+	// Set by higher level code before calling dispatch(x, y, z).
+	auto shader = (Shader *) Shader::current;
+
+	int tX, tY, tZ;
+	shader->getLocalThreadgroupSize(&tX, &tY, &tZ);
+
+	id<MTLComputePipelineState> pipeline = shader->getComputePipeline();
+	if (pipeline == nil)
+		return false;
+
+	id<MTLComputeCommandEncoder> computeEncoder = useComputeEncoder();
+
+	applyShaderUniforms(computeEncoder, shader);
+
+	// TODO: track this state?
+	[computeEncoder setComputePipelineState:pipeline];
+
+	[computeEncoder dispatchThreadgroups:MTLSizeMake(x, y, z)
+				   threadsPerThreadgroup:MTLSizeMake(tX, tY, tZ)];
+
+	return true;
 }}
 
 void Graphics::setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int /*pixelw*/, int /*pixelh*/, bool /*hasSRGBtexture*/)

+ 2 - 0
src/modules/graphics/metal/Shader.h

@@ -107,6 +107,7 @@ public:
 	void setVideoTextures(love::graphics::Texture *ytexture, love::graphics::Texture *cbtexture, love::graphics::Texture *crtexture) override;
 
 	id<MTLRenderPipelineState> getCachedRenderPipeline(const RenderPipelineKey &key);
+	id<MTLComputePipelineState> getComputePipeline() const { return computePipeline; }
 
 	static int getUniformBufferBinding();
 	const std::vector<TextureBinding> &getTextureBindings() const { return textureBindings; }
@@ -144,6 +145,7 @@ private:
 	std::vector<BufferBinding> bufferBindings;
 
 	std::unordered_map<RenderPipelineKey, const void *, RenderPipelineHasher> cachedRenderPipelines;
+	id<MTLComputePipelineState> computePipeline;
 
 }; // Metal
 

+ 21 - 0
src/modules/graphics/metal/Shader.mm

@@ -347,6 +347,27 @@ Shader::Shader(id<MTLDevice> device, StrongRef<love::graphics::ShaderStage> stag
 	}
 
 	cleanup();
+
+	if (functions[SHADERSTAGE_COMPUTE] != nil)
+	{
+		MTLComputePipelineDescriptor *desc = [MTLComputePipelineDescriptor new];
+		desc.computeFunction = functions[SHADERSTAGE_COMPUTE];
+
+		// TODO: threadGroupSizeIsMultipleOfThreadExecutionWidth
+
+		NSError *err = nil;
+		computePipeline = [device newComputePipelineStateWithDescriptor:desc
+																options:MTLPipelineOptionNone
+															 reflection:nil
+																  error:&err];
+		if (computePipeline == nil)
+		{
+			if (err != nil)
+				throw love::Exception("Error creating compute shader pipeline: %s", err.localizedDescription.UTF8String);
+			else
+				throw love::Exception("Error creating compute shader pipeline.");
+		}
+	}
 }}
 
 void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &program)