Browse Source

love.graphics.newShader: add a compile options table parameter.

The only field currently read from the table is 'defines', which can contain either an array of define names, or name-value pairs.
For example: newShader(file, {defines={"MYFEATURE_ENABLED", MYSETTING=1}})

Fixes #1577
Alex Szpakowski 3 years ago
parent
commit
0fa00e0bdc

+ 23 - 12
src/modules/graphics/Graphics.cpp

@@ -307,12 +307,18 @@ love::graphics::ParticleSystem *Graphics::newParticleSystem(Texture *texture, in
 	return new ParticleSystem(texture, size);
 	return new ParticleSystem(texture, size);
 }
 }
 
 
-ShaderStage *Graphics::newShaderStage(ShaderStageType stage, const std::string &source, const Shader::SourceInfo &info)
+ShaderStage *Graphics::newShaderStage(ShaderStageType stage, const std::string &source, const Shader::CompileOptions &options, const Shader::SourceInfo &info, bool cache)
 {
 {
 	ShaderStage *s = nullptr;
 	ShaderStage *s = nullptr;
 	std::string cachekey;
 	std::string cachekey;
 
 
-	if (!source.empty())
+	// Never cache if there are custom defines set... because hashing would get
+	// more complicated/expensive, and there shouldn't be a lot of duplicate
+	// shader stages with custom defines anyway.
+	if (!options.defines.empty())
+		cache = false;
+
+	if (cache && !source.empty())
 	{
 	{
 		data::HashFunction::Value hashvalue;
 		data::HashFunction::Value hashvalue;
 		data::hash(data::HashFunction::FUNCTION_SHA1, source.c_str(), source.size(), hashvalue);
 		data::hash(data::HashFunction::FUNCTION_SHA1, source.c_str(), source.size(), hashvalue);
@@ -330,16 +336,16 @@ ShaderStage *Graphics::newShaderStage(ShaderStageType stage, const std::string &
 	if (s == nullptr)
 	if (s == nullptr)
 	{
 	{
 		bool glsles = usesGLSLES();
 		bool glsles = usesGLSLES();
-		std::string glsl = Shader::createShaderStageCode(this, stage, source, info, glsles, true);
+		std::string glsl = Shader::createShaderStageCode(this, stage, source, options, info, glsles, true);
 		s = newShaderStageInternal(stage, cachekey, glsl, glsles);
 		s = newShaderStageInternal(stage, cachekey, glsl, glsles);
-		if (!cachekey.empty())
+		if (cache && !cachekey.empty())
 			cachedShaderStages[stage][cachekey] = s;
 			cachedShaderStages[stage][cachekey] = s;
 	}
 	}
 
 
 	return s;
 	return s;
 }
 }
 
 
-Shader *Graphics::newShader(const std::vector<std::string> &stagessource)
+Shader *Graphics::newShader(const std::vector<std::string> &stagessource, const Shader::CompileOptions &options)
 {
 {
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM] = {};
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM] = {};
 
 
@@ -360,12 +366,12 @@ Shader *Graphics::newShader(const std::vector<std::string> &stagessource)
 			if (info.stages[i] != Shader::ENTRYPOINT_NONE)
 			if (info.stages[i] != Shader::ENTRYPOINT_NONE)
 			{
 			{
 				isanystage = true;
 				isanystage = true;
-				stages[i].set(newShaderStage((ShaderStageType) i, source, info), Acquire::NORETAIN);
+				stages[i].set(newShaderStage((ShaderStageType) i, source, options, info, true), Acquire::NORETAIN);
 			}
 			}
 		}
 		}
 
 
 		if (!isanystage)
 		if (!isanystage)
-			throw love::Exception("Could not parse shader code (missing 'position' or 'effect' function?)");
+			throw love::Exception("Could not parse shader code (missing shader entry point function such as 'position' or 'effect')");
 	}
 	}
 
 
 	for (int i = 0; i < SHADERSTAGE_MAX_ENUM; i++)
 	for (int i = 0; i < SHADERSTAGE_MAX_ENUM; i++)
@@ -375,7 +381,8 @@ Shader *Graphics::newShader(const std::vector<std::string> &stagessource)
 		{
 		{
 			const std::string &source = Shader::getDefaultCode(Shader::STANDARD_DEFAULT, stype);
 			const std::string &source = Shader::getDefaultCode(Shader::STANDARD_DEFAULT, stype);
 			Shader::SourceInfo info = Shader::getSourceInfo(source);
 			Shader::SourceInfo info = Shader::getSourceInfo(source);
-			stages[i].set(newShaderStage(stype, source, info), Acquire::NORETAIN);
+			Shader::CompileOptions opts;
+			stages[i].set(newShaderStage(stype, source, opts, info, true), Acquire::NORETAIN);
 		}
 		}
 
 
 	}
 	}
@@ -383,7 +390,7 @@ Shader *Graphics::newShader(const std::vector<std::string> &stagessource)
 	return newShaderInternal(stages);
 	return newShaderInternal(stages);
 }
 }
 
 
-Shader *Graphics::newComputeShader(const std::string &source)
+Shader *Graphics::newComputeShader(const std::string &source, const Shader::CompileOptions &options)
 {
 {
 	Shader::SourceInfo info = Shader::getSourceInfo(source);
 	Shader::SourceInfo info = Shader::getSourceInfo(source);
 
 
@@ -391,7 +398,11 @@ Shader *Graphics::newComputeShader(const std::string &source)
 		throw love::Exception("Could not parse compute shader code (missing 'computemain' function?)");
 		throw love::Exception("Could not parse compute shader code (missing 'computemain' function?)");
 
 
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM];
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM];
-	stages[SHADERSTAGE_COMPUTE].set(newShaderStage(SHADERSTAGE_COMPUTE, source, info));
+
+	// Don't bother caching compute shader intermediate source, since there
+	// shouldn't be much reuse.
+	stages[SHADERSTAGE_COMPUTE].set(newShaderStage(SHADERSTAGE_COMPUTE, source, options, info, false));
+
 	return newShaderInternal(stages);
 	return newShaderInternal(stages);
 }
 }
 
 
@@ -426,7 +437,7 @@ void Graphics::cleanupCachedShaderStage(ShaderStageType type, const std::string
 	cachedShaderStages[type].erase(hashkey);
 	cachedShaderStages[type].erase(hashkey);
 }
 }
 
 
-bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagessource, std::string &err)
+bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagessource, const Shader::CompileOptions &options, std::string &err)
 {
 {
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM] = {};
 	StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM] = {};
 
 
@@ -452,7 +463,7 @@ bool Graphics::validateShader(bool gles, const std::vector<std::string> &stagess
 			if (info.stages[i] != Shader::ENTRYPOINT_NONE)
 			if (info.stages[i] != Shader::ENTRYPOINT_NONE)
 			{
 			{
 				isanystage = true;
 				isanystage = true;
-				std::string glsl = Shader::createShaderStageCode(this, stype, source, info, gles, false);
+				std::string glsl = Shader::createShaderStageCode(this, stype, source, options, info, gles, false);
 				stages[i].set(new ShaderStageForValidation(this, stype, glsl, gles), Acquire::NORETAIN);
 				stages[i].set(new ShaderStageForValidation(this, stype, glsl, gles), Acquire::NORETAIN);
 			}
 			}
 		}
 		}

+ 4 - 4
src/modules/graphics/Graphics.h

@@ -448,8 +448,8 @@ public:
 	SpriteBatch *newSpriteBatch(Texture *texture, int size, BufferDataUsage usage);
 	SpriteBatch *newSpriteBatch(Texture *texture, int size, BufferDataUsage usage);
 	ParticleSystem *newParticleSystem(Texture *texture, int size);
 	ParticleSystem *newParticleSystem(Texture *texture, int size);
 
 
-	Shader *newShader(const std::vector<std::string> &stagessource);
-	Shader *newComputeShader(const std::string &source);
+	Shader *newShader(const std::vector<std::string> &stagessource, const Shader::CompileOptions &options);
+	Shader *newComputeShader(const std::string &source, const Shader::CompileOptions &options);
 
 
 	virtual Buffer *newBuffer(const Buffer::Settings &settings, const std::vector<Buffer::DataDeclaration> &format, const void *data, size_t size, size_t arraylength) = 0;
 	virtual Buffer *newBuffer(const Buffer::Settings &settings, const std::vector<Buffer::DataDeclaration> &format, const void *data, size_t size, size_t arraylength) = 0;
 	virtual Buffer *newBuffer(const Buffer::Settings &settings, DataFormat format, const void *data, size_t size, size_t arraylength);
 	virtual Buffer *newBuffer(const Buffer::Settings &settings, DataFormat format, const void *data, size_t size, size_t arraylength);
@@ -460,7 +460,7 @@ public:
 
 
 	Text *newText(Font *font, const std::vector<Font::ColoredString> &text = {});
 	Text *newText(Font *font, const std::vector<Font::ColoredString> &text = {});
 
 
-	bool validateShader(bool gles, const std::vector<std::string> &stages, std::string &err);
+	bool validateShader(bool gles, const std::vector<std::string> &stages, const Shader::CompileOptions &options, std::string &err);
 
 
 	/**
 	/**
 	 * Resets the current color, background color, line style, and so forth.
 	 * Resets the current color, background color, line style, and so forth.
@@ -965,7 +965,7 @@ protected:
 		{}
 		{}
 	};
 	};
 
 
-	ShaderStage *newShaderStage(ShaderStageType stage, const std::string &source, const Shader::SourceInfo &info);
+	ShaderStage *newShaderStage(ShaderStageType stage, const std::string &source, const Shader::CompileOptions &options, const Shader::SourceInfo &info, bool cache);
 	virtual ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) = 0;
 	virtual ShaderStage *newShaderStageInternal(ShaderStageType stage, const std::string &cachekey, const std::string &source, bool gles) = 0;
 	virtual Shader *newShaderInternal(StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) = 0;
 	virtual Shader *newShaderInternal(StrongRef<ShaderStage> stages[SHADERSTAGE_MAX_ENUM]) = 0;
 	virtual StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) = 0;
 	virtual StreamBuffer *newStreamBuffer(BufferUsage type, size_t size) = 0;

+ 5 - 1
src/modules/graphics/Shader.cpp

@@ -526,7 +526,7 @@ Shader::SourceInfo Shader::getSourceInfo(const std::string &src)
 	return info;
 	return info;
 }
 }
 
 
-std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const Shader::SourceInfo &info, bool gles, bool checksystemfeatures)
+std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const CompileOptions &options, const Shader::SourceInfo &info, bool gles, bool checksystemfeatures)
 {
 {
 	if (info.language == Shader::LANGUAGE_MAX_ENUM)
 	if (info.language == Shader::LANGUAGE_MAX_ENUM)
 		throw love::Exception("Invalid shader language");
 		throw love::Exception("Invalid shader language");
@@ -574,6 +574,10 @@ std::string Shader::createShaderStageCode(Graphics *gfx, ShaderStageType stage,
 		ss << "#define LOVE_GAMMA_CORRECT 1\n";
 		ss << "#define LOVE_GAMMA_CORRECT 1\n";
 	if (info.usesMRT)
 	if (info.usesMRT)
 		ss << "#define LOVE_MULTI_RENDER_TARGETS 1\n";
 		ss << "#define LOVE_MULTI_RENDER_TARGETS 1\n";
+
+	for (const auto &def : options.defines)
+		ss << "#define " + def.first + " " + def.second + "\n";
+
 	ss << glsl::global_syntax;
 	ss << glsl::global_syntax;
 	ss << stageinfo.header;
 	ss << stageinfo.header;
 	ss << stageinfo.uniforms;
 	ss << stageinfo.uniforms;

+ 6 - 1
src/modules/graphics/Shader.h

@@ -107,6 +107,11 @@ public:
 		ACCESS_WRITE = (1 << 1),
 		ACCESS_WRITE = (1 << 1),
 	};
 	};
 
 
+	struct CompileOptions
+	{
+		std::map<std::string, std::string> defines;
+	};
+
 	struct SourceInfo
 	struct SourceInfo
 	{
 	{
 		Language language;
 		Language language;
@@ -236,7 +241,7 @@ public:
 	void getLocalThreadgroupSize(int *x, int *y, int *z);
 	void getLocalThreadgroupSize(int *x, int *y, int *z);
 
 
 	static SourceInfo getSourceInfo(const std::string &src);
 	static SourceInfo getSourceInfo(const std::string &src);
-	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const SourceInfo &info, bool gles, bool checksystemfeatures);
+	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const CompileOptions &options, const SourceInfo &info, bool gles, bool checksystemfeatures);
 
 
 	static bool validate(StrongRef<ShaderStage> stages[], std::string &err);
 	static bool validate(StrongRef<ShaderStage> stages[], std::string &err);
 
 

+ 2 - 1
src/modules/graphics/metal/Graphics.mm

@@ -343,9 +343,10 @@ Graphics::Graphics()
 		if (!Shader::standardShaders[i])
 		if (!Shader::standardShaders[i])
 		{
 		{
 			std::vector<std::string> stages;
 			std::vector<std::string> stages;
+			Shader::CompileOptions opts;
 			stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_VERTEX));
 			stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_VERTEX));
 			stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_PIXEL));
 			stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_PIXEL));
-			Shader::standardShaders[i] = newShader(stages);
+			Shader::standardShaders[i] = newShader(stages, opts);
 		}
 		}
 	}
 	}
 
 

+ 2 - 1
src/modules/graphics/opengl/Graphics.cpp

@@ -425,9 +425,10 @@ bool Graphics::setMode(void */*context*/, int width, int height, int pixelwidth,
 			if (!Shader::standardShaders[i])
 			if (!Shader::standardShaders[i])
 			{
 			{
 				std::vector<std::string> stages;
 				std::vector<std::string> stages;
+				Shader::CompileOptions opts;
 				stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_VERTEX));
 				stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_VERTEX));
 				stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_PIXEL));
 				stages.push_back(Shader::getDefaultCode(stype, SHADERSTAGE_PIXEL));
-				Shader::standardShaders[i] = newShader(stages);
+				Shader::standardShaders[i] = newShader(stages, opts);
 			}
 			}
 		}
 		}
 		catch (love::Exception &)
 		catch (love::Exception &)

+ 52 - 8
src/modules/graphics/wrap_Graphics.cpp

@@ -1347,7 +1347,7 @@ int w_newParticleSystem(lua_State *L)
 	return 1;
 	return 1;
 }
 }
 
 
-static int w_getShaderSource(lua_State *L, int startidx, std::vector<std::string> &stages)
+static int w_getShaderSource(lua_State *L, int startidx, std::vector<std::string> &stages, Shader::CompileOptions &options)
 {
 {
 	using namespace love::filesystem;
 	using namespace love::filesystem;
 
 
@@ -1369,7 +1369,6 @@ static int w_getShaderSource(lua_State *L, int startidx, std::vector<std::string
 
 
 				lua_replace(L, i);
 				lua_replace(L, i);
 			}
 			}
-
 			continue;
 			continue;
 		}
 		}
 
 
@@ -1412,18 +1411,61 @@ static int w_getShaderSource(lua_State *L, int startidx, std::vector<std::string
 	if (has_arg2)
 	if (has_arg2)
 		stages.push_back(luax_checkstring(L, startidx + 1));
 		stages.push_back(luax_checkstring(L, startidx + 1));
 
 
+	int optionsidx = has_arg2 ? startidx + 2 : startidx + 1;
+	if (!lua_isnoneornil(L, optionsidx))
+	{
+		luaL_checktype(L, optionsidx, LUA_TTABLE);
+		lua_getfield(L, optionsidx, "defines");
+		if (!lua_isnoneornil(L, -1))
+		{
+			if (!lua_istable(L, -1))
+				luaL_argerror(L, optionsidx, "expected 'defines' field to be a table");
+
+			lua_pushnil(L);
+			while (lua_next(L, -2))
+			{
+				std::string defname;
+				std::string defval;
+
+				if (lua_type(L, -2) == LUA_TNUMBER && lua_type(L, -1) == LUA_TSTRING)
+					defname = luaL_checkstring(L, -1);
+				else if (lua_type(L, -2) != LUA_TSTRING)
+					luaL_argerror(L, optionsidx, "all fields in the 'defines' table must use string keys.");
+				else
+				{
+					defname = luaL_checkstring(L, -2);
+					if (lua_type(L, -1) == LUA_TBOOLEAN)
+						defval = luax_toboolean(L, -1) ? "1" : "0";
+					else
+					{
+						const char *val = lua_tostring(L, -1);
+						if (val == nullptr)
+							luaL_argerror(L, optionsidx, "'defines' table values must be strings, numbers, or booleans.");
+						defval = val;
+					}
+				}
+
+				options.defines[defname] = defval;
+
+				lua_pop(L, 1);
+			}
+		}
+		lua_pop(L, 1);
+	}
+
 	return 0;
 	return 0;
 }
 }
 
 
 int w_newShader(lua_State *L)
 int w_newShader(lua_State *L)
 {
 {
 	std::vector<std::string> stages;
 	std::vector<std::string> stages;
-	w_getShaderSource(L, 1, stages);
+	Shader::CompileOptions options;
+	w_getShaderSource(L, 1, stages, options);
 
 
 	bool should_error = false;
 	bool should_error = false;
 	try
 	try
 	{
 	{
-		Shader *shader = instance()->newShader(stages);
+		Shader *shader = instance()->newShader(stages, options);
 		luax_pushtype(L, shader);
 		luax_pushtype(L, shader);
 		shader->release();
 		shader->release();
 	}
 	}
@@ -1446,12 +1488,13 @@ int w_newShader(lua_State *L)
 int w_newComputeShader(lua_State* L)
 int w_newComputeShader(lua_State* L)
 {
 {
 	std::vector<std::string> stages;
 	std::vector<std::string> stages;
-	w_getShaderSource(L, 1, stages);
+	Shader::CompileOptions options;
+	w_getShaderSource(L, 1, stages, options);
 
 
 	bool should_error = false;
 	bool should_error = false;
 	try
 	try
 	{
 	{
-		Shader *shader = instance()->newComputeShader(stages[0]);
+		Shader *shader = instance()->newComputeShader(stages[0], options);
 		luax_pushtype(L, shader);
 		luax_pushtype(L, shader);
 		shader->release();
 		shader->release();
 	}
 	}
@@ -1476,13 +1519,14 @@ int w_validateShader(lua_State *L)
 	bool gles = luax_checkboolean(L, 1);
 	bool gles = luax_checkboolean(L, 1);
 
 
 	std::vector<std::string> stages;
 	std::vector<std::string> stages;
-	w_getShaderSource(L, 2, stages);
+	Shader::CompileOptions options;
+	w_getShaderSource(L, 2, stages, options);
 
 
 	bool success = true;
 	bool success = true;
 	std::string err;
 	std::string err;
 	try
 	try
 	{
 	{
-		success = instance()->validateShader(gles, stages, err);
+		success = instance()->validateShader(gles, stages, options, err);
 	}
 	}
 	catch (love::Exception &e)
 	catch (love::Exception &e)
 	{
 	{