Parcourir la source

Merge pull request #1715 from bjornbytes/compute-shaders

Compute Shaders
Alex Szpakowski il y a 4 ans
Parent
commit
80f051b97d

+ 38 - 2
src/modules/graphics/Graphics.cpp

@@ -291,7 +291,19 @@ Shader *Graphics::newShader(const std::vector<std::string> &stagessource)
 
 	}
 
-	return newShaderInternal(stages[SHADERSTAGE_VERTEX], stages[SHADERSTAGE_PIXEL]);
+	return newShaderInternal(stages);
+}
+
+Shader *Graphics::newComputeShader(const std::string &source)
+{
+	Shader::SourceInfo info = Shader::getSourceInfo(source);
+
+	if (info.stages[SHADERSTAGE_COMPUTE] == Shader::ENTRYPOINT_NONE)
+		throw love::Exception("Could not parse compute shader code (missing 'computemain' function?)");
+
+	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM];
+	stages[SHADERSTAGE_COMPUTE].set(newShaderStage(SHADERSTAGE_COMPUTE, source, info));
+	return newShaderInternal(stages);
 }
 
 Buffer *Graphics::newBuffer(const Buffer::Settings &settings, DataFormat format, const void *data, size_t size, size_t arraylength)
@@ -332,6 +344,7 @@ bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagess
 	bool validstages[SHADERSTAGE_MAX_ENUM] = {};
 	validstages[SHADERSTAGE_VERTEX] = true;
 	validstages[SHADERSTAGE_PIXEL] = true;
+	validstages[SHADERSTAGE_COMPUTE] = true;
 
 	// Don't use cached shader stages, since the gles flag may not match the
 	// current renderer.
@@ -362,7 +375,7 @@ bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagess
 		}
 	}
 
-	return Shader::validate(stages[SHADERSTAGE_VERTEX], stages[SHADERSTAGE_PIXEL], err);
+	return Shader::validate(stages, err);
 }
 
 int Graphics::getWidth() const
@@ -1065,6 +1078,26 @@ void Graphics::copyBuffer(Buffer *source, Buffer *dest, size_t sourceoffset, siz
 	source->copyTo(dest, sourceoffset, destoffset, size);
 }
 
+void Graphics::dispatchThreadgroups(Shader* shader, int x, int y, int z)
+{
+	if (!shader->hasStage(SHADERSTAGE_COMPUTE))
+		throw love::Exception("Only compute shaders can have threads dispatched.");
+
+	if (x <= 0 || y <= 0 || z <= 0)
+		throw love::Exception("Threadgroup dispatch size must be positive.");
+
+	if (x > capabilities.limits[LIMIT_THREADGROUPS_X]
+		|| y > capabilities.limits[LIMIT_THREADGROUPS_Y]
+		|| z > capabilities.limits[LIMIT_THREADGROUPS_Z])
+	{
+		throw love::Exception("Too many threadgroups dispatched.");
+	}
+
+	flushBatchedDraws();
+	shader->attach();
+	dispatch(x, y, z);
+}
+
 Graphics::BatchedVertexData Graphics::requestBatchedDraw(const BatchedDrawCommand &cmd)
 {
 	BatchedDrawState &state = batchedDrawState;
@@ -1878,6 +1911,9 @@ STRINGMAP_CLASS_BEGIN(Graphics, Graphics::SystemLimit, Graphics::LIMIT_MAX_ENUM,
 	{ "cubetexturesize",         Graphics::LIMIT_CUBE_TEXTURE_SIZE          },
 	{ "texelbuffersize",         Graphics::LIMIT_TEXEL_BUFFER_SIZE          },
 	{ "shaderstoragebuffersize", Graphics::LIMIT_SHADER_STORAGE_BUFFER_SIZE },
+	{ "threadgroupsx",           Graphics::LIMIT_THREADGROUPS_X             },
+	{ "threadgroupsy",           Graphics::LIMIT_THREADGROUPS_Y             },
+	{ "threadgroupsz",           Graphics::LIMIT_THREADGROUPS_Z             },
 	{ "rendertargets",           Graphics::LIMIT_RENDER_TARGETS             },
 	{ "texturemsaa",             Graphics::LIMIT_TEXTURE_MSAA               },
 	{ "anisotropy",              Graphics::LIMIT_ANISOTROPY                 },

+ 9 - 1
src/modules/graphics/Graphics.h

@@ -164,6 +164,9 @@ public:
 		LIMIT_TEXTURE_LAYERS,
 		LIMIT_TEXEL_BUFFER_SIZE,
 		LIMIT_SHADER_STORAGE_BUFFER_SIZE,
+		LIMIT_THREADGROUPS_X,
+		LIMIT_THREADGROUPS_Y,
+		LIMIT_THREADGROUPS_Z,
 		LIMIT_RENDER_TARGETS,
 		LIMIT_TEXTURE_MSAA,
 		LIMIT_ANISOTROPY,
@@ -436,6 +439,7 @@ public:
 	ParticleSystem *newParticleSystem(Texture *texture, int size);
 
 	Shader *newShader(const std::vector<std::string> &stagessource);
+	Shader *newComputeShader(const std::string &source);
 
 	virtual Buffer *newBuffer(const Buffer::Settings &settings, const std::vector<Buffer::DataDeclaration> &format, const void *data, size_t size, size_t arraylength) = 0;
 	virtual Buffer *newBuffer(const Buffer::Settings &settings, DataFormat format, const void *data, size_t size, size_t arraylength);
@@ -671,6 +675,8 @@ public:
 
 	void copyBuffer(Buffer *source, Buffer *dest, size_t sourceoffset, size_t destoffset, size_t size);
 
+	void dispatchThreadgroups(Shader* shader, int x, int y, int z);
+
 	void draw(Drawable *drawable, const Matrix4 &m);
 	void draw(Texture *texture, Quad *quad, const Matrix4 &m);
 	void drawLayer(Texture *texture, int layer, const Matrix4 &m);
@@ -929,9 +935,11 @@ protected:
 
 	ShaderStage *newShaderStage(ShaderStageType stage, const std::string &source, const Shader::SourceInfo &info);
 	virtual ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) = 0;
-	virtual Shader *newShaderInternal(ShaderStage *vertex, ShaderStage *pixel) = 0;
+	virtual Shader *newShaderInternal(StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) = 0;
 	virtual StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) = 0;
 
+	virtual void dispatch(int x, int y, int z) = 0;
+
 	virtual void setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int pixelw, int pixelh, bool hasSRGBtexture) = 0;
 
 	virtual void initCapabilities() = 0;

+ 72 - 12
src/modules/graphics/Shader.cpp

@@ -389,6 +389,24 @@ void main() {
 }
 )";
 
+static const char compute_header[] = R"(
+#define love_NumWorkGroups gl_NumWorkGroups
+#define love_WorkGroupID gl_WorkGroupID
+#define love_LocalInvocationID gl_LocalInvocationID
+#define love_GlobalInvocationID gl_GlobalInvocationID
+#define love_LocalInvocationIndex gl_LocalInvocationIndex
+)";
+
+static const char compute_functions[] = R"()";
+
+static const char compute_main[] = R"(
+void computemain();
+
+void main() {
+	computemain();
+}
+)";
+
 struct StageInfo
 {
 	const char *name;
@@ -403,6 +421,7 @@ static const StageInfo stageInfo[] =
 {
 	{ "VERTEX", vertex_header, vertex_functions, vertex_main, vertex_main, vertex_main_raw },
 	{ "PIXEL", pixel_header, pixel_functions, pixel_main, pixel_main_custom, pixel_main_raw },
+	{ "COMPUTE", compute_header, compute_functions, compute_main, compute_main, compute_main },
 };
 
 static_assert((sizeof(stageInfo) / sizeof(StageInfo)) == SHADERSTAGE_MAX_ENUM, "Stages array size must match ShaderStage enum.");
@@ -465,6 +484,15 @@ static Shader::EntryPoint getPixelEntryPoint(const std::string &src, bool &mrt)
 	return Shader::ENTRYPOINT_NONE;
 }
 
+static Shader::EntryPoint getComputeEntryPoint(const std::string &src) {
+	std::smatch m;
+
+	if (std::regex_search(src, m, std::regex("void\\s+computemain\\s*\\(")))
+		return Shader::ENTRYPOINT_RAW;
+
+	return Shader::ENTRYPOINT_NONE;
+}
+
 } // glsl
 
 static_assert(sizeof(Shader::BuiltinUniformData) == sizeof(float) * 4 * 13, "Update the array in wrap_GraphicsShader.lua if this changes.");
@@ -480,6 +508,9 @@ Shader::SourceInfo Shader::getSourceInfo(const std::string &src)
 	info.language = glsl::getTargetLanguage(src);
 	info.stages[SHADERSTAGE_VERTEX] = glsl::getVertexEntryPoint(src);
 	info.stages[SHADERSTAGE_PIXEL] = glsl::getPixelEntryPoint(src, info.usesMRT);
+	info.stages[SHADERSTAGE_COMPUTE] = glsl::getComputeEntryPoint(src);
+	if (info.stages[SHADERSTAGE_COMPUTE])
+		info.language = LANGUAGE_GLSL4;
 	return info;
 }
 
@@ -494,6 +525,9 @@ std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage,
 	if (info.stages[stage] == ENTRYPOINT_RAW && info.language == LANGUAGE_GLSL1)
 		throw love::Exception("Shaders using a raw entry point (vertexmain or pixelmain) must use GLSL 3 or greater.");
 
+	if (stage == SHADERSTAGE_COMPUTE && info.language != LANGUAGE_GLSL4)
+		throw love::Exception("Compute shaders must use GLSL 4.");
+
 	const auto &features = gfx->getCapabilities().features;
 
 	if (info.language == LANGUAGE_GLSL3 && !features[Graphics::FEATURE_GLSL3])
@@ -541,15 +575,15 @@ std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage,
 	return ss.str();
 }
 
-Shader::Shader(ShaderStage *vertex, ShaderStage *pixel)
+Shader::Shader(StrongRef<ShaderStage> _stages[])
 	: stages()
 {
 	std::string err;
-	if (!validateInternal(vertex, pixel, err, validationReflection))
+	if (!validateInternal(_stages, err, validationReflection))
 		throw love::Exception("%s", err.c_str());
 
-	stages[SHADERSTAGE_VERTEX] = vertex;
-	stages[SHADERSTAGE_PIXEL] = pixel;
+	for (int i = 0; i < SHADERSTAGE_MAX_ENUM; i++)
+		stages[i] = _stages[i];
 }
 
 Shader::~Shader()
@@ -564,6 +598,11 @@ Shader::~Shader()
 		attachDefault(STANDARD_DEFAULT);
 }
 
+bool Shader::hasStage(ShaderStageType stage)
+{
+	return stages[stage] != nullptr;
+}
+
 void Shader::attachDefault(StandardShader defaultType)
 {
 	Shader *defaultshader = standardShaders[defaultType];
@@ -636,21 +675,28 @@ void Shader::validateDrawState(PrimitiveType primtype, Texture *maintex) const
 	}
 }
 
-bool Shader::validate(ShaderStage* vertex, ShaderStage* pixel, std::string& err)
+void Shader::getLocalThreadgroupSize(int *x, int *y, int *z)
+{
+	*x = validationReflection.localThreadgroupSize[0];
+	*y = validationReflection.localThreadgroupSize[1];
+	*z = validationReflection.localThreadgroupSize[2];
+}
+
+bool Shader::validate(StrongRef<ShaderStage> stages[], std::string& err)
 {
 	ValidationReflection reflection;
-	return validateInternal(vertex, pixel, err, reflection);
+	return validateInternal(stages, err, reflection);
 }
 
-bool Shader::validateInternal(ShaderStage *vertex, ShaderStage *pixel, std::string &err, ValidationReflection &reflection)
+bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err, ValidationReflection &reflection)
 {
 	glslang::TProgram program;
 
-	if (vertex != nullptr)
-		program.addShader(vertex->getGLSLangShader());
-
-	if (pixel != nullptr)
-		program.addShader(pixel->getGLSLangShader());
+	for (int i = 0; i < SHADERSTAGE_MAX_ENUM; i++)
+	{
+		if (stages[i] != nullptr)
+			program.addShader(stages[i]->getGLSLangShader());
+	}
 
 	if (!program.link(EShMsgDefault))
 	{
@@ -671,6 +717,20 @@ bool Shader::validateInternal(ShaderStage *vertex, ShaderStage *pixel, std::stri
 		reflection.usesPointSize = vertintermediate->inIoAccessed("gl_PointSize");
 	}
 
+	if (stages[SHADERSTAGE_COMPUTE] != nullptr)
+	{
+		for (int i = 0; i < 3; i++)
+		{
+			reflection.localThreadgroupSize[i] = program.getLocalSize(i);
+
+			if (reflection.localThreadgroupSize[i] <= 0)
+			{
+				err = "Shader validation error:\nNegative local threadgroup size.";
+				return false;
+			}
+		}
+	}
+
 	for (int i = 0; i < program.getNumBufferBlocks(); i++)
 	{
 		const glslang::TObjectReflection &info = program.getBufferBlock(i);

+ 11 - 3
src/modules/graphics/Shader.h

@@ -164,9 +164,14 @@ public:
 	// Pointer to the default Shader.
 	static Shader *standardShaders[STANDARD_MAX_ENUM];
 
-	Shader(ShaderStage *vertex, ShaderStage *pixel);
+	Shader(StrongRef<ShaderStage> stages[]);
 	virtual ~Shader();
 
+	/**
+	 * Check whether a Shader has a stage.
+	 **/
+	bool hasStage(ShaderStageType stage);
+
 	/**
 	 * Binds this Shader's program to be used when rendering.
 	 **/
@@ -211,10 +216,12 @@ public:
 	TextureType getMainTextureType() const;
 	void validateDrawState(PrimitiveType primtype, Texture *maintexture) const;
 
+	void getLocalThreadgroupSize(int *x, int *y, int *z);
+
 	static SourceInfo getSourceInfo(const std::string &src);
 	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const SourceInfo &info);
 
-	static bool validate(ShaderStage *vertex, ShaderStage *pixel, std::string &err);
+	static bool validate(StrongRef<ShaderStage> stages[], std::string &err);
 
 	static bool initialize();
 	static void deinitialize();
@@ -238,10 +245,11 @@ protected:
 	struct ValidationReflection
 	{
 		std::map<std::string, BufferReflection> storageBuffers;
+		int localThreadgroupSize[3];
 		bool usesPointSize;
 	};
 
-	static bool validateInternal(ShaderStage* vertex, ShaderStage* pixel, std::string& err, ValidationReflection &reflection);
+	static bool validateInternal(StrongRef<ShaderStage> stages[], std::string& err, ValidationReflection &reflection);
 
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM];
 

+ 5 - 2
src/modules/graphics/ShaderStage.cpp

@@ -148,6 +148,8 @@ ShaderStage::ShaderStage(Graphics *gfx, ShaderStageType stage, const std::string
 		glslangStage = EShLangVertex;
 	else if (stage == SHADERSTAGE_PIXEL)
 		glslangStage = EShLangFragment;
+	else if (stage == SHADERSTAGE_COMPUTE)
+		glslangStage = EShLangCompute;
 	else
 		throw love::Exception("Cannot compile shader stage: unknown stage type.");
 
@@ -212,8 +214,9 @@ const char *ShaderStage::getConstant(ShaderStageType in)
 
 StringMap<ShaderStageType, SHADERSTAGE_MAX_ENUM>::Entry ShaderStage::stageNameEntries[] =
 {
-	{ "vertex", SHADERSTAGE_VERTEX },
-	{ "pixel",  SHADERSTAGE_PIXEL  },
+	{ "vertex",  SHADERSTAGE_VERTEX  },
+	{ "pixel",   SHADERSTAGE_PIXEL   },
+	{ "compute", SHADERSTAGE_COMPUTE },
 };
 
 StringMap<ShaderStageType, SHADERSTAGE_MAX_ENUM> ShaderStage::stageNames(ShaderStage::stageNameEntries, sizeof(ShaderStage::stageNameEntries));

+ 1 - 0
src/modules/graphics/ShaderStage.h

@@ -45,6 +45,7 @@ enum ShaderStageType
 {
 	SHADERSTAGE_VERTEX,
 	SHADERSTAGE_PIXEL,
+	SHADERSTAGE_COMPUTE,
 	SHADERSTAGE_MAX_ENUM
 };
 

+ 12 - 3
src/modules/graphics/opengl/Graphics.cpp

@@ -155,9 +155,9 @@ love::graphics::ShaderStage *Graphics::newShaderStageInternal(ShaderStageType st
 	return new ShaderStage(this, stage, source, gles, cachekey);
 }
 
-love::graphics::Shader *Graphics::newShaderInternal(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel)
+love::graphics::Shader *Graphics::newShaderInternal(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM])
 {
-	return new Shader(vertex, pixel);
+	return new Shader(stages);
 }
 
 love::graphics::Buffer *Graphics::newBuffer(const Buffer::Settings &settings, const std::vector<Buffer::DataDeclaration> &format, const void *data, size_t size, size_t arraylength)
@@ -469,6 +469,12 @@ void Graphics::setActive(bool enable)
 	active = enable;
 }
 
+void Graphics::dispatch(int x, int y, int z)
+{
+	glDispatchCompute(x, y, z);
+	glMemoryBarrier(GL_ALL_BARRIER_BITS); // TODO: Improve synchronization
+}
+
 void Graphics::draw(const DrawCommand &cmd)
 {
 	gl.prepareDraw(this);
@@ -1543,10 +1549,13 @@ void Graphics::initCapabilities()
 	capabilities.limits[LIMIT_CUBE_TEXTURE_SIZE] = gl.getMaxCubeTextureSize();
 	capabilities.limits[LIMIT_TEXEL_BUFFER_SIZE] = gl.getMaxTexelBufferSize();
 	capabilities.limits[LIMIT_SHADER_STORAGE_BUFFER_SIZE] = gl.getMaxShaderStorageBufferSize();
+	capabilities.limits[LIMIT_THREADGROUPS_X] = gl.getMaxComputeWorkGroupsX();
+	capabilities.limits[LIMIT_THREADGROUPS_Y] = gl.getMaxComputeWorkGroupsY();
+	capabilities.limits[LIMIT_THREADGROUPS_Z] = gl.getMaxComputeWorkGroupsZ();
 	capabilities.limits[LIMIT_RENDER_TARGETS] = gl.getMaxRenderTargets();
 	capabilities.limits[LIMIT_TEXTURE_MSAA] = gl.getMaxSamples();
 	capabilities.limits[LIMIT_ANISOTROPY] = gl.getMaxAnisotropy();
-	static_assert(LIMIT_MAX_ENUM == 10, "Graphics::initCapabilities must be updated when adding a new system limit!");
+	static_assert(LIMIT_MAX_ENUM == 13, "Graphics::initCapabilities must be updated when adding a new system limit!");
 
 	for (int i = 0; i < TEXTURE_MAX_ENUM; i++)
 		capabilities.textureTypes[i] = gl.isTextureTypeSupported((TextureType) i);

+ 3 - 1
src/modules/graphics/opengl/Graphics.h

@@ -68,6 +68,8 @@ public:
 
 	void setActive(bool active) override;
 
+	void dispatch(int x, int y, int z) override;
+
 	void draw(const DrawCommand &cmd) override;
 	void draw(const DrawIndexedCommand &cmd) override;
 	void drawQuads(int start, int count, const VertexAttributes &attributes, const BufferBindings &buffers, love::graphics::Texture *texture) override;
@@ -137,7 +139,7 @@ private:
 	};
 
 	love::graphics::ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) override;
-	love::graphics::Shader *newShaderInternal(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel) override;
+	love::graphics::Shader *newShaderInternal(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) override;
 	love::graphics::StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) override;
 	void setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int pixelw, int pixelh, bool hasSRGBtexture) override;
 	void initCapabilities() override;

+ 31 - 0
src/modules/graphics/opengl/OpenGL.cpp

@@ -103,6 +103,9 @@ OpenGL::OpenGL()
 	, maxTextureArrayLayers(0)
 	, maxTexelBufferSize(0)
 	, maxShaderStorageBufferSize(0)
+	, maxComputeWorkGroupsX(0)
+	, maxComputeWorkGroupsY(0)
+	, maxComputeWorkGroupsZ(0)
 	, maxRenderTargets(1)
 	, maxSamples(1)
 	, maxTextureUnits(1)
@@ -511,6 +514,19 @@ void OpenGL::initMaxValues()
 		maxShaderStorageBufferBindings = 0;
 	}
 
+	if (GLAD_ES_VERSION_3_1 || GLAD_VERSION_4_3)
+	{
+		glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 0, &maxComputeWorkGroupsX);
+		glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 1, &maxComputeWorkGroupsY);
+		glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 2, &maxComputeWorkGroupsZ);
+	}
+	else
+	{
+		maxComputeWorkGroupsX = 0;
+		maxComputeWorkGroupsY = 0;
+		maxComputeWorkGroupsZ = 0;
+	}
+
 	int maxattachments = 1;
 	int maxdrawbuffers = 1;
 
@@ -1521,6 +1537,21 @@ int OpenGL::getMaxShaderStorageBufferSize() const
 	return maxShaderStorageBufferSize;
 }
 
+int OpenGL::getMaxComputeWorkGroupsX() const
+{
+	return maxComputeWorkGroupsX;
+}
+
+int OpenGL::getMaxComputeWorkGroupsY() const
+{
+	return maxComputeWorkGroupsY;
+}
+
+int OpenGL::getMaxComputeWorkGroupsZ() const
+{
+	return maxComputeWorkGroupsZ;
+}
+
 int OpenGL::getMaxRenderTargets() const
 {
 	return std::min(maxRenderTargets, MAX_COLOR_RENDER_TARGETS);

+ 11 - 0
src/modules/graphics/opengl/OpenGL.h

@@ -390,6 +390,14 @@ public:
 	 **/
 	int getMaxShaderStorageBufferSize() const;
 
+	/**
+	 * Returns the maximum number of compute work groups that can be
+	 * dispatched in a given dimension.
+	 */
+	int getMaxComputeWorkGroupsX() const;
+	int getMaxComputeWorkGroupsY() const;
+	int getMaxComputeWorkGroupsZ() const;
+
 	/**
 	 * Returns the maximum supported number of simultaneous render targets.
 	 **/
@@ -474,6 +482,9 @@ private:
 	int maxTextureArrayLayers;
 	int maxTexelBufferSize;
 	int maxShaderStorageBufferSize;
+	int maxComputeWorkGroupsX;
+	int maxComputeWorkGroupsY;
+	int maxComputeWorkGroupsZ;
 	int maxRenderTargets;
 	int maxSamples;
 	int maxTextureUnits;

+ 2 - 2
src/modules/graphics/opengl/Shader.cpp

@@ -42,8 +42,8 @@ static bool isBuffer(Shader::UniformType utype)
 	return utype == Shader::UNIFORM_TEXELBUFFER || utype == Shader::UNIFORM_STORAGEBUFFER;
 }
 
-Shader::Shader(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel)
-	: love::graphics::Shader(vertex, pixel)
+Shader::Shader(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM])
+	: love::graphics::Shader(stages)
 	, program(0)
 	, builtinUniforms()
 	, builtinUniformInfo()

+ 1 - 1
src/modules/graphics/opengl/Shader.h

@@ -43,7 +43,7 @@ class Shader final : public love::graphics::Shader, public Volatile
 {
 public:
 
-	Shader(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel);
+	Shader(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM]);
 	virtual ~Shader();
 
 	// Implements Volatile

+ 2 - 0
src/modules/graphics/opengl/ShaderStage.cpp

@@ -53,6 +53,8 @@ bool ShaderStage::loadVolatile()
 		glstage = GL_VERTEX_SHADER;
 	else if (stage == SHADERSTAGE_PIXEL)
 		glstage = GL_FRAGMENT_SHADER;
+	else if (stage == SHADERSTAGE_COMPUTE)
+		glstage = GL_COMPUTE_SHADER;
 	else
 		throw love::Exception("%s shader stage is not handled in OpenGL backend code.", typestr);
 

+ 41 - 0
src/modules/graphics/wrap_Graphics.cpp

@@ -1458,6 +1458,34 @@ int w_newShader(lua_State *L)
 	return 1;
 }
 
+int w_newComputeShader(lua_State* L)
+{
+	std::vector<std::string> stages;
+	w_getShaderSource(L, 1, stages);
+
+	bool should_error = false;
+	try
+	{
+		Shader *shader = instance()->newComputeShader(stages[0]);
+		luax_pushtype(L, shader);
+		shader->release();
+	}
+	catch (love::Exception &e)
+	{
+		luax_getfunction(L, "graphics", "_transformGLSLErrorMessages");
+		lua_pushstring(L, e.what());
+
+		// Function pushes the new error string onto the stack.
+		lua_pcall(L, 1, 1, 0);
+		should_error = true;
+	}
+
+	if (should_error)
+		return lua_error(L);
+
+	return 1;
+}
+
 int w_validateShader(lua_State *L)
 {
 	bool gles = luax_checkboolean(L, 1);
@@ -3255,6 +3283,16 @@ int w_polygon(lua_State *L)
 	return 0;
 }
 
+int w_dispatchThreadgroups(lua_State* L)
+{
+	Shader *shader = luax_checkshader(L, 1);
+	int x = (int) luaL_checkinteger(L, 2);
+	int y = (int) luaL_optinteger(L, 3, 1);
+	int z = (int) luaL_optinteger(L, 4, 1);
+	luax_catchexcept(L, [&](){ instance()->dispatchThreadgroups(shader, x, y, z); });
+	return 0;
+}
+
 int w_copyBuffer(lua_State *L)
 {
 	Buffer *source = luax_checkbuffer(L, 1);
@@ -3405,6 +3443,7 @@ static const luaL_Reg functions[] =
 	{ "newSpriteBatch", w_newSpriteBatch },
 	{ "newParticleSystem", w_newParticleSystem },
 	{ "newShader", w_newShader },
+	{ "newComputeShader", w_newComputeShader },
 	{ "newBuffer", w_newBuffer },
 	{ "newVertexBuffer", w_newVertexBuffer },
 	{ "newIndexBuffer", w_newIndexBuffer },
@@ -3472,6 +3511,8 @@ static const luaL_Reg functions[] =
 	{ "print", w_print },
 	{ "printf", w_printf },
 
+	{ "dispatchThreadgroups", w_dispatchThreadgroups },
+
 	{ "copyBuffer", w_copyBuffer },
 
 	{ "isCreated", w_isCreated },

+ 36 - 4
src/modules/graphics/wrap_Shader.cpp

@@ -491,12 +491,44 @@ int w_Shader_hasUniform(lua_State *L)
 	return 1;
 }
 
+int w_Shader_hasStage(lua_State* L)
+{
+	Shader *shader = luax_checkshader(L, 1);
+	ShaderStageType stage;
+	const char *str = luaL_checkstring(L, 2);
+	if (!ShaderStage::getConstant(str, stage))
+		return luax_enumerror(L, "shader stage", str);
+
+	luax_pushboolean(L, shader->hasStage(stage));
+	return 1;
+}
+
+int w_Shader_getLocalThreadgroupSize(lua_State* L)
+{
+	Shader *shader = luax_checkshader(L, 1);
+
+	if (!shader->hasStage(SHADERSTAGE_COMPUTE))
+	{
+		lua_pushnil(L);
+		return 1;
+	}
+
+	int x, y, z;
+	shader->getLocalThreadgroupSize(&x, &y, &z);
+	lua_pushinteger(L, x);
+	lua_pushinteger(L, y);
+	lua_pushinteger(L, z);
+	return 3;
+}
+
 static const luaL_Reg w_Shader_functions[] =
 {
-	{ "getWarnings", w_Shader_getWarnings },
-	{ "send",        w_Shader_send },
-	{ "sendColor",   w_Shader_sendColors },
-	{ "hasUniform",  w_Shader_hasUniform },
+	{ "getWarnings",             w_Shader_getWarnings },
+	{ "send",                    w_Shader_send },
+	{ "sendColor",               w_Shader_sendColors },
+	{ "hasUniform",              w_Shader_hasUniform },
+	{ "hasStage",                w_Shader_hasStage },
+	{ "getLocalThreadgroupSize", w_Shader_getLocalThreadgroupSize },
 	{ 0, 0 }
 };