Browse Source

Shader:getLocalThreadgroupSize;

bjorn 4 years ago
parent
commit
b28e61d00d

+ 21 - 0
src/modules/graphics/Shader.cpp

@@ -675,6 +675,13 @@ void Shader::validateDrawState(PrimitiveType primtype, Texture *maintex) const
 	}
 }
 
+void Shader::getLocalThreadgroupSize(int *x, int *y, int *z)
+{
+	*x = validationReflection.localThreadgroupSize[0];
+	*y = validationReflection.localThreadgroupSize[1];
+	*z = validationReflection.localThreadgroupSize[2];
+}
+
 bool Shader::validate(StrongRef<ShaderStage> stages[], std::string& err)
 {
 	ValidationReflection reflection;
@@ -710,6 +717,20 @@ bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err,
 		reflection.usesPointSize = vertintermediate->inIoAccessed("gl_PointSize");
 	}
 
+	if (stages[SHADERSTAGE_COMPUTE] != nullptr)
+	{
+		for (int i = 0; i < 3; i++)
+		{
+			reflection.localThreadgroupSize[i] = program.getLocalSize(i);
+
+			if (reflection.localThreadgroupSize[i] <= 0)
+			{
+				err = "Shader validation error:\nNegative local threadgroup size.";
+				return false;
+			}
+		}
+	}
+
 	for (int i = 0; i < program.getNumBufferBlocks(); i++)
 	{
 		const glslang::TObjectReflection &info = program.getBufferBlock(i);

+ 3 - 0
src/modules/graphics/Shader.h

@@ -216,6 +216,8 @@ public:
 	TextureType getMainTextureType() const;
 	void validateDrawState(PrimitiveType primtype, Texture *maintexture) const;
 
+	void getLocalThreadgroupSize(int *x, int *y, int *z);
+
 	static SourceInfo getSourceInfo(const std::string &src);
 	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const SourceInfo &info);
 
@@ -243,6 +245,7 @@ protected:
 	struct ValidationReflection
 	{
 		std::map<std::string, BufferReflection> storageBuffers;
+		int localThreadgroupSize[3];
 		bool usesPointSize;
 	};
 

+ 24 - 5
src/modules/graphics/wrap_Shader.cpp

@@ -503,13 +503,32 @@ int w_Shader_hasStage(lua_State* L)
 	return 1;
 }
 
+int w_Shader_getLocalThreadgroupSize(lua_State* L)
+{
+	Shader *shader = luax_checkshader(L, 1);
+
+	if (!shader->hasStage(SHADERSTAGE_COMPUTE))
+	{
+		lua_pushnil(L);
+		return 1;
+	}
+
+	int x, y, z;
+	shader->getLocalThreadgroupSize(&x, &y, &z);
+	lua_pushinteger(L, x);
+	lua_pushinteger(L, y);
+	lua_pushinteger(L, z);
+	return 3;
+}
+
 static const luaL_Reg w_Shader_functions[] =
 {
-	{ "getWarnings", w_Shader_getWarnings },
-	{ "send",        w_Shader_send },
-	{ "sendColor",   w_Shader_sendColors },
-	{ "hasUniform",  w_Shader_hasUniform },
-	{ "hasStage",    w_Shader_hasStage },
+	{ "getWarnings",             w_Shader_getWarnings },
+	{ "send",                    w_Shader_send },
+	{ "sendColor",               w_Shader_sendColors },
+	{ "hasUniform",              w_Shader_hasUniform },
+	{ "hasStage",                w_Shader_hasStage },
+	{ "getLocalThreadgroupSize", w_Shader_getLocalThreadgroupSize },
 	{ 0, 0 }
 };