Browse Source

Add compute ShaderStage;

bjorn 4 years ago
parent
commit
085ba2cb6e

+ 3 - 2
src/modules/graphics/Graphics.cpp

@@ -291,7 +291,7 @@ Shader *Graphics::newShader(const std::vector<std::string> &stagessource)
 
 
 	}
 	}
 
 
-	return newShaderInternal(stages[SHADERSTAGE_VERTEX], stages[SHADERSTAGE_PIXEL]);
+	return newShaderInternal(stages);
 }
 }
 
 
 Buffer *Graphics::newBuffer(const Buffer::Settings &settings, DataFormat format, const void *data, size_t size, size_t arraylength)
 Buffer *Graphics::newBuffer(const Buffer::Settings &settings, DataFormat format, const void *data, size_t size, size_t arraylength)
@@ -332,6 +332,7 @@ bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagess
 	bool validstages[SHADERSTAGE_MAX_ENUM] = {};
 	bool validstages[SHADERSTAGE_MAX_ENUM] = {};
 	validstages[SHADERSTAGE_VERTEX] = true;
 	validstages[SHADERSTAGE_VERTEX] = true;
 	validstages[SHADERSTAGE_PIXEL] = true;
 	validstages[SHADERSTAGE_PIXEL] = true;
+	validstages[SHADERSTAGE_COMPUTE] = true;
 
 
 	// Don't use cached shader stages, since the gles flag may not match the
 	// Don't use cached shader stages, since the gles flag may not match the
 	// current renderer.
 	// current renderer.
@@ -362,7 +363,7 @@ bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagess
 		}
 		}
 	}
 	}
 
 
-	return Shader::validate(stages[SHADERSTAGE_VERTEX], stages[SHADERSTAGE_PIXEL], err);
+	return Shader::validate(stages, err);
 }
 }
 
 
 int Graphics::getWidth() const
 int Graphics::getWidth() const

+ 1 - 1
src/modules/graphics/Graphics.h

@@ -932,7 +932,7 @@ protected:
 
 
 	ShaderStage *newShaderStage(ShaderStageType stage, const std::string &source, const Shader::SourceInfo &info);
 	ShaderStage *newShaderStage(ShaderStageType stage, const std::string &source, const Shader::SourceInfo &info);
 	virtual ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) = 0;
 	virtual ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) = 0;
-	virtual Shader *newShaderInternal(ShaderStage *vertex, ShaderStage *pixel) = 0;
+	virtual Shader *newShaderInternal(StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) = 0;
 	virtual StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) = 0;
 	virtual StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) = 0;
 
 
 	virtual void setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int pixelw, int pixelh, bool hasSRGBtexture) = 0;
 	virtual void setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int pixelw, int pixelh, bool hasSRGBtexture) = 0;

+ 46 - 12
src/modules/graphics/Shader.cpp

@@ -389,6 +389,24 @@ void main() {
 }
 }
 )";
 )";
 
 
+static const char compute_header[] = R"(
+#define love_NumWorkGroups gl_NumWorkGroups
+#define love_WorkGroupID gl_WorkGroupID
+#define love_LocalInvocationID gl_LocalInvocationID
+#define love_GlobalInvocationID gl_GlobalInvocationID
+#define love_LocalInvocationIndex gl_LocalInvocationIndex
+)";
+
+static const char compute_functions[] = R"()";
+
+static const char compute_main[] = R"(
+void computemain();
+
+void main() {
+	computemain();
+}
+)";
+
 struct StageInfo
 struct StageInfo
 {
 {
 	const char *name;
 	const char *name;
@@ -403,6 +421,7 @@ static const StageInfo stageInfo[] =
 {
 {
 	{ "VERTEX", vertex_header, vertex_functions, vertex_main, vertex_main, vertex_main_raw },
 	{ "VERTEX", vertex_header, vertex_functions, vertex_main, vertex_main, vertex_main_raw },
 	{ "PIXEL", pixel_header, pixel_functions, pixel_main, pixel_main_custom, pixel_main_raw },
 	{ "PIXEL", pixel_header, pixel_functions, pixel_main, pixel_main_custom, pixel_main_raw },
+	{ "COMPUTE", compute_header, compute_functions, compute_main, compute_main, compute_main },
 };
 };
 
 
 static_assert((sizeof(stageInfo) / sizeof(StageInfo)) == SHADERSTAGE_MAX_ENUM, "Stages array size must match ShaderStage enum.");
 static_assert((sizeof(stageInfo) / sizeof(StageInfo)) == SHADERSTAGE_MAX_ENUM, "Stages array size must match ShaderStage enum.");
@@ -465,6 +484,15 @@ static Shader::EntryPoint getPixelEntryPoint(const std::string &src, bool &mrt)
 	return Shader::ENTRYPOINT_NONE;
 	return Shader::ENTRYPOINT_NONE;
 }
 }
 
 
+static Shader::EntryPoint getComputeEntryPoint(const std::string &src) {
+	std::smatch m;
+
+	if (std::regex_search(src, m, std::regex("void\\s+computemain\\s*\\(")))
+		return Shader::ENTRYPOINT_RAW;
+
+	return Shader::ENTRYPOINT_NONE;
+}
+
 } // glsl
 } // glsl
 
 
 static_assert(sizeof(Shader::BuiltinUniformData) == sizeof(float) * 4 * 13, "Update the array in wrap_GraphicsShader.lua if this changes.");
 static_assert(sizeof(Shader::BuiltinUniformData) == sizeof(float) * 4 * 13, "Update the array in wrap_GraphicsShader.lua if this changes.");
@@ -480,6 +508,9 @@ Shader::SourceInfo Shader::getSourceInfo(const std::string &src)
 	info.language = glsl::getTargetLanguage(src);
 	info.language = glsl::getTargetLanguage(src);
 	info.stages[SHADERSTAGE_VERTEX] = glsl::getVertexEntryPoint(src);
 	info.stages[SHADERSTAGE_VERTEX] = glsl::getVertexEntryPoint(src);
 	info.stages[SHADERSTAGE_PIXEL] = glsl::getPixelEntryPoint(src, info.usesMRT);
 	info.stages[SHADERSTAGE_PIXEL] = glsl::getPixelEntryPoint(src, info.usesMRT);
+	info.stages[SHADERSTAGE_COMPUTE] = glsl::getComputeEntryPoint(src);
+	if (info.stages[SHADERSTAGE_COMPUTE])
+		info.language = LANGUAGE_GLSL4;
 	return info;
 	return info;
 }
 }
 
 
@@ -494,6 +525,9 @@ std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage,
 	if (info.stages[stage] == ENTRYPOINT_RAW && info.language == LANGUAGE_GLSL1)
 	if (info.stages[stage] == ENTRYPOINT_RAW && info.language == LANGUAGE_GLSL1)
 		throw love::Exception("Shaders using a raw entry point (vertexmain or pixelmain) must use GLSL 3 or greater.");
 		throw love::Exception("Shaders using a raw entry point (vertexmain or pixelmain) must use GLSL 3 or greater.");
 
 
+	if (stage == SHADERSTAGE_COMPUTE && info.language != LANGUAGE_GLSL4)
+		throw love::Exception("Compute shaders must use GLSL 4.");
+
 	const auto &features = gfx->getCapabilities().features;
 	const auto &features = gfx->getCapabilities().features;
 
 
 	if (info.language == LANGUAGE_GLSL3 && !features[Graphics::FEATURE_GLSL3])
 	if (info.language == LANGUAGE_GLSL3 && !features[Graphics::FEATURE_GLSL3])
@@ -541,15 +575,15 @@ std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage,
 	return ss.str();
 	return ss.str();
 }
 }
 
 
-Shader::Shader(ShaderStage *vertex, ShaderStage *pixel)
+Shader::Shader(StrongRef<ShaderStage> _stages[])
 	: stages()
 	: stages()
 {
 {
 	std::string err;
 	std::string err;
-	if (!validateInternal(vertex, pixel, err, validationReflection))
+	if (!validateInternal(_stages, err, validationReflection))
 		throw love::Exception("%s", err.c_str());
 		throw love::Exception("%s", err.c_str());
 
 
-	stages[SHADERSTAGE_VERTEX] = vertex;
-	stages[SHADERSTAGE_PIXEL] = pixel;
+	for (int i = 0; i < SHADERSTAGE_MAX_ENUM; i++)
+		stages[i] = _stages[i];
 }
 }
 
 
 Shader::~Shader()
 Shader::~Shader()
@@ -641,21 +675,21 @@ void Shader::validateDrawState(PrimitiveType primtype, Texture *maintex) const
 	}
 	}
 }
 }
 
 
-bool Shader::validate(ShaderStage* vertex, ShaderStage* pixel, std::string& err)
+bool Shader::validate(StrongRef<ShaderStage> stages[], std::string& err)
 {
 {
 	ValidationReflection reflection;
 	ValidationReflection reflection;
-	return validateInternal(vertex, pixel, err, reflection);
+	return validateInternal(stages, err, reflection);
 }
 }
 
 
-bool Shader::validateInternal(ShaderStage *vertex, ShaderStage *pixel, std::string &err, ValidationReflection &reflection)
+bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err, ValidationReflection &reflection)
 {
 {
 	glslang::TProgram program;
 	glslang::TProgram program;
 
 
-	if (vertex != nullptr)
-		program.addShader(vertex->getGLSLangShader());
-
-	if (pixel != nullptr)
-		program.addShader(pixel->getGLSLangShader());
+	for (int i = 0; i < SHADERSTAGE_MAX_ENUM; i++)
+	{
+		if (stages[i] != nullptr)
+			program.addShader(stages[i]->getGLSLangShader());
+	}
 
 
 	if (!program.link(EShMsgDefault))
 	if (!program.link(EShMsgDefault))
 	{
 	{

+ 3 - 3
src/modules/graphics/Shader.h

@@ -164,7 +164,7 @@ public:
 	// Pointer to the default Shader.
 	// Pointer to the default Shader.
 	static Shader *standardShaders[STANDARD_MAX_ENUM];
 	static Shader *standardShaders[STANDARD_MAX_ENUM];
 
 
-	Shader(ShaderStage *vertex, ShaderStage *pixel);
+	Shader(StrongRef<ShaderStage> stages[]);
 	virtual ~Shader();
 	virtual ~Shader();
 
 
 	/**
 	/**
@@ -219,7 +219,7 @@ public:
 	static SourceInfo getSourceInfo(const std::string &src);
 	static SourceInfo getSourceInfo(const std::string &src);
 	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const SourceInfo &info);
 	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const SourceInfo &info);
 
 
-	static bool validate(ShaderStage *vertex, ShaderStage *pixel, std::string &err);
+	static bool validate(StrongRef<ShaderStage> stages[], std::string &err);
 
 
 	static bool initialize();
 	static bool initialize();
 	static void deinitialize();
 	static void deinitialize();
@@ -246,7 +246,7 @@ protected:
 		bool usesPointSize;
 		bool usesPointSize;
 	};
 	};
 
 
-	static bool validateInternal(ShaderStage* vertex, ShaderStage* pixel, std::string& err, ValidationReflection &reflection);
+	static bool validateInternal(StrongRef<ShaderStage> stages[], std::string& err, ValidationReflection &reflection);
 
 
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM];
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM];
 
 

+ 5 - 2
src/modules/graphics/ShaderStage.cpp

@@ -148,6 +148,8 @@ ShaderStage::ShaderStage(Graphics *gfx, ShaderStageType stage, const std::string
 		glslangStage = EShLangVertex;
 		glslangStage = EShLangVertex;
 	else if (stage == SHADERSTAGE_PIXEL)
 	else if (stage == SHADERSTAGE_PIXEL)
 		glslangStage = EShLangFragment;
 		glslangStage = EShLangFragment;
+	else if (stage == SHADERSTAGE_COMPUTE)
+		glslangStage = EShLangCompute;
 	else
 	else
 		throw love::Exception("Cannot compile shader stage: unknown stage type.");
 		throw love::Exception("Cannot compile shader stage: unknown stage type.");
 
 
@@ -212,8 +214,9 @@ const char *ShaderStage::getConstant(ShaderStageType in)
 
 
 StringMap<ShaderStageType, SHADERSTAGE_MAX_ENUM>::Entry ShaderStage::stageNameEntries[] =
 StringMap<ShaderStageType, SHADERSTAGE_MAX_ENUM>::Entry ShaderStage::stageNameEntries[] =
 {
 {
-	{ "vertex", SHADERSTAGE_VERTEX },
-	{ "pixel",  SHADERSTAGE_PIXEL  },
+	{ "vertex",  SHADERSTAGE_VERTEX  },
+	{ "pixel",   SHADERSTAGE_PIXEL   },
+	{ "compute", SHADERSTAGE_COMPUTE },
 };
 };
 
 
 StringMap<ShaderStageType, SHADERSTAGE_MAX_ENUM> ShaderStage::stageNames(ShaderStage::stageNameEntries, sizeof(ShaderStage::stageNameEntries));
 StringMap<ShaderStageType, SHADERSTAGE_MAX_ENUM> ShaderStage::stageNames(ShaderStage::stageNameEntries, sizeof(ShaderStage::stageNameEntries));

+ 1 - 0
src/modules/graphics/ShaderStage.h

@@ -45,6 +45,7 @@ enum ShaderStageType
 {
 {
 	SHADERSTAGE_VERTEX,
 	SHADERSTAGE_VERTEX,
 	SHADERSTAGE_PIXEL,
 	SHADERSTAGE_PIXEL,
+	SHADERSTAGE_COMPUTE,
 	SHADERSTAGE_MAX_ENUM
 	SHADERSTAGE_MAX_ENUM
 };
 };
 
 

+ 2 - 2
src/modules/graphics/opengl/Graphics.cpp

@@ -155,9 +155,9 @@ love::graphics::ShaderStage *Graphics::newShaderStageInternal(ShaderStageType st
 	return new ShaderStage(this, stage, source, gles, cachekey);
 	return new ShaderStage(this, stage, source, gles, cachekey);
 }
 }
 
 
-love::graphics::Shader *Graphics::newShaderInternal(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel)
+love::graphics::Shader *Graphics::newShaderInternal(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM])
 {
 {
-	return new Shader(vertex, pixel);
+	return new Shader(stages);
 }
 }
 
 
 love::graphics::Buffer *Graphics::newBuffer(const Buffer::Settings &settings, const std::vector<Buffer::DataDeclaration> &format, const void *data, size_t size, size_t arraylength)
 love::graphics::Buffer *Graphics::newBuffer(const Buffer::Settings &settings, const std::vector<Buffer::DataDeclaration> &format, const void *data, size_t size, size_t arraylength)

+ 1 - 1
src/modules/graphics/opengl/Graphics.h

@@ -137,7 +137,7 @@ private:
 	};
 	};
 
 
 	love::graphics::ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) override;
 	love::graphics::ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) override;
-	love::graphics::Shader *newShaderInternal(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel) override;
+	love::graphics::Shader *newShaderInternal(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) override;
 	love::graphics::StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) override;
 	love::graphics::StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) override;
 	void setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int pixelw, int pixelh, bool hasSRGBtexture) override;
 	void setRenderTargetsInternal(const RenderTargets &rts, int w, int h, int pixelw, int pixelh, bool hasSRGBtexture) override;
 	void initCapabilities() override;
 	void initCapabilities() override;

+ 2 - 2
src/modules/graphics/opengl/Shader.cpp

@@ -42,8 +42,8 @@ static bool isBuffer(Shader::UniformType utype)
 	return utype == Shader::UNIFORM_TEXELBUFFER || utype == Shader::UNIFORM_STORAGEBUFFER;
 	return utype == Shader::UNIFORM_TEXELBUFFER || utype == Shader::UNIFORM_STORAGEBUFFER;
 }
 }
 
 
-Shader::Shader(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel)
-	: love::graphics::Shader(vertex, pixel)
+Shader::Shader(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM])
+	: love::graphics::Shader(stages)
 	, program(0)
 	, program(0)
 	, builtinUniforms()
 	, builtinUniforms()
 	, builtinUniformInfo()
 	, builtinUniformInfo()

+ 1 - 1
src/modules/graphics/opengl/Shader.h

@@ -43,7 +43,7 @@ class Shader final : public love::graphics::Shader, public Volatile
 {
 {
 public:
 public:
 
 
-	Shader(love::graphics::ShaderStage *vertex, love::graphics::ShaderStage *pixel);
+	Shader(StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM]);
 	virtual ~Shader();
 	virtual ~Shader();
 
 
 	// Implements Volatile
 	// Implements Volatile

+ 2 - 0
src/modules/graphics/opengl/ShaderStage.cpp

@@ -53,6 +53,8 @@ bool ShaderStage::loadVolatile()
 		glstage = GL_VERTEX_SHADER;
 		glstage = GL_VERTEX_SHADER;
 	else if (stage == SHADERSTAGE_PIXEL)
 	else if (stage == SHADERSTAGE_PIXEL)
 		glstage = GL_FRAGMENT_SHADER;
 		glstage = GL_FRAGMENT_SHADER;
+	else if (stage == SHADERSTAGE_COMPUTE)
+		glstage = GL_COMPUTE_SHADER;
 	else
 	else
 		throw love::Exception("%s shader stage is not handled in OpenGL backend code.", typestr);
 		throw love::Exception("%s shader stage is not handled in OpenGL backend code.", typestr);