Browse Source

Add Shader:getBufferFormat.

Returns a table with the same setup as the data format table used with love.graphics.newBuffer.
Sasha Szpakowski 1 year ago
parent
commit
78f8cfc4dc
3 changed files with 162 additions and 0 deletions
  1. 121 0
      src/modules/graphics/Shader.cpp
  2. 5 0
      src/modules/graphics/Shader.h
  3. 36 0
      src/modules/graphics/wrap_Shader.cpp

+ 121 - 0
src/modules/graphics/Shader.cpp

@@ -928,6 +928,14 @@ void Shader::getLocalThreadgroupSize(int *x, int *y, int *z)
 	*z = reflection.localThreadgroupSize[2];
 }
 
+const std::vector<Buffer::DataDeclaration> *Shader::getBufferFormat(const std::string &name) const
+{
+	auto it = reflection.bufferFormats.find(name);
+	if (it != reflection.bufferFormats.end())
+		return &it->second;
+	return nullptr;
+}
+
 bool Shader::validate(StrongRef<ShaderStage> stages[], std::string& err)
 {
 	Reflection reflection;
@@ -946,6 +954,71 @@ static DataBaseType getBaseType(glslang::TBasicType basictype)
 	}
 }
 
+static DataFormat getDataFormat(glslang::TBasicType basictype, int components, int rows, int columns, bool matrix)
+{
+	if (matrix)
+	{
+		if (basictype != glslang::EbtFloat)
+			return DATAFORMAT_MAX_ENUM;
+
+		if (rows == 2 && columns == 2)
+			return DATAFORMAT_FLOAT_MAT2X2;
+		else if (rows == 2 && columns == 3)
+			return DATAFORMAT_FLOAT_MAT2X3;
+		else if (rows == 2 && columns == 4)
+			return DATAFORMAT_FLOAT_MAT2X4;
+		else if (rows == 3 && columns == 2)
+			return DATAFORMAT_FLOAT_MAT3X2;
+		else if (rows == 3 && columns == 3)
+			return DATAFORMAT_FLOAT_MAT3X3;
+		else if (rows == 3 && columns == 4)
+			return DATAFORMAT_FLOAT_MAT3X4;
+		else if (rows == 4 && columns == 2)
+			return DATAFORMAT_FLOAT_MAT4X2;
+		else if (rows == 4 && columns == 3)
+			return DATAFORMAT_FLOAT_MAT4X3;
+		else if (rows == 4 && columns == 4)
+			return DATAFORMAT_FLOAT_MAT4X4;
+		else
+			return DATAFORMAT_MAX_ENUM;
+	}
+	else if (basictype == glslang::EbtFloat)
+	{
+		if (components == 1)
+			return DATAFORMAT_FLOAT;
+		else if (components == 2)
+			return DATAFORMAT_FLOAT_VEC2;
+		else if (components == 3)
+			return DATAFORMAT_FLOAT_VEC2;
+		else if (components == 4)
+			return DATAFORMAT_FLOAT_VEC2;
+	}
+	else if (basictype == glslang::EbtInt)
+	{
+		if (components == 1)
+			return DATAFORMAT_INT32;
+		else if (components == 2)
+			return DATAFORMAT_INT32_VEC2;
+		else if (components == 3)
+			return DATAFORMAT_INT32_VEC2;
+		else if (components == 4)
+			return DATAFORMAT_INT32_VEC2;
+	}
+	else if (basictype == glslang::EbtUint)
+	{
+		if (components == 1)
+			return DATAFORMAT_UINT32;
+		else if (components == 2)
+			return DATAFORMAT_UINT32_VEC2;
+		else if (components == 3)
+			return DATAFORMAT_UINT32_VEC2;
+		else if (components == 4)
+			return DATAFORMAT_UINT32_VEC2;
+	}
+
+	return DATAFORMAT_MAX_ENUM;
+}
+
 static PixelFormat getPixelFormat(glslang::TLayoutFormat format)
 {
 	using namespace glslang;
@@ -1037,6 +1110,48 @@ static T convertData(const glslang::TConstUnion &data)
 	}
 }
 
+static bool AddFieldsToFormat(std::vector<Buffer::DataDeclaration> &format, int level, const glslang::TType *type, int arraylength, const std::string &basename, std::string &err)
+{
+	if (type->isStruct())
+	{
+		auto fields = type->getStruct();
+
+		for (int i = 0; i < std::max(arraylength, 1); i++)
+		{
+			std::string name = basename;
+			if (level > 0)
+			{
+				name += type->getFieldName().c_str();
+				if (arraylength > 0)
+					name += "[" + std::to_string(i) + "]";
+				name += ".";
+			}
+			for (size_t fieldi = 0; fieldi < fields->size(); fieldi++)
+			{
+				const glslang::TType *fieldtype = (*fields)[fieldi].type;
+				int fieldlength = fieldtype->isSizedArray() ? fieldtype->getCumulativeArraySize() : 0;
+
+				if (!AddFieldsToFormat(format, level + 1, fieldtype, fieldlength, name, err))
+					return false;
+			}
+		}
+	}
+	else
+	{
+		DataFormat dataformat = getDataFormat(type->getBasicType(), type->getVectorSize(), type->getMatrixRows(), type->getMatrixCols(), type->isMatrix());
+		if (dataformat == DATAFORMAT_MAX_ENUM)
+		{
+			err = "Shader validation error:\n";
+			return false;
+		}
+
+		std::string name = basename.empty() ? type->getFieldName().c_str() : basename + type->getFieldName().c_str();
+		format.emplace_back(name.c_str(), dataformat, arraylength);
+	}
+
+	return true;
+}
+
 bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err, Reflection &reflection)
 {
 	glslang::TProgram program;
@@ -1295,6 +1410,12 @@ bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err,
 				u.access = (Access)(ACCESS_READ | ACCESS_WRITE);
 
 			reflection.storageBuffers[u.name] = u;
+
+			std::vector<Buffer::DataDeclaration> format;
+			if (!AddFieldsToFormat(format, 0, elementtype, 0, "", err))
+				return false;
+
+			reflection.bufferFormats[u.name] = format;
 		}
 		else
 		{

+ 5 - 0
src/modules/graphics/Shader.h

@@ -26,6 +26,7 @@
 #include "Texture.h"
 #include "ShaderStage.h"
 #include "Resource.h"
+#include "Buffer.h"
 
 // STL
 #include <string>
@@ -266,6 +267,8 @@ public:
 
 	void getLocalThreadgroupSize(int *x, int *y, int *z);
 
+	const std::vector<Buffer::DataDeclaration> *getBufferFormat(const std::string &name) const;
+
 	static SourceInfo getSourceInfo(const std::string &src);
 	static std::string createShaderStageCode(Graphics *gfx, ShaderStageType stage, const std::string &code, const CompileOptions &options, const SourceInfo &info, bool gles, bool checksystemfeatures);
 
@@ -296,6 +299,8 @@ protected:
 
 		std::map<std::string, std::vector<LocalUniformValue>> localUniformInitializerValues;
 
+		std::map<std::string, std::vector<Buffer::DataDeclaration>> bufferFormats;
+
 		int textureCount;
 		int bufferCount;
 

+ 36 - 0
src/modules/graphics/wrap_Shader.cpp

@@ -514,6 +514,41 @@ int w_Shader_getLocalThreadgroupSize(lua_State* L)
 	return 3;
 }
 
+int w_Shader_getBufferFormat(lua_State *L)
+{
+	Shader *shader = luax_checkshader(L, 1);
+	const char *name = luaL_checkstring(L, 2);
+	const std::vector<Buffer::DataDeclaration> *format = shader->getBufferFormat(name);
+	if (name != nullptr)
+	{
+		lua_createtable(L, (int)format->size(), 0);
+
+		for (size_t i = 0; i < format->size(); i++)
+		{
+			const Buffer::DataDeclaration &member = (*format)[i];
+
+			lua_createtable(L, 0, 3);
+
+			lua_pushstring(L, member.name.c_str());
+			lua_setfield(L, -2, "name");
+
+			const char* formatstr = "unknown";
+			getConstant(member.format, formatstr);
+			lua_pushstring(L, formatstr);
+			lua_setfield(L, -2, "format");
+
+			lua_pushinteger(L, member.arrayLength);
+			lua_setfield(L, -2, "arraylength");
+
+			lua_rawseti(L, -2, i + 1);
+		}
+
+		return 1;
+	}
+
+	return luaL_error(L, "Buffer '%s' does not exist in the Shader.", name);
+}
+
 int w_Shader_getDebugName(lua_State *L)
 {
 	Shader *shader = luax_checkshader(L, 1);
@@ -533,6 +568,7 @@ static const luaL_Reg w_Shader_functions[] =
 	{ "hasUniform",              w_Shader_hasUniform },
 	{ "hasStage",                w_Shader_hasStage },
 	{ "getLocalThreadgroupSize", w_Shader_getLocalThreadgroupSize },
+	{ "getBufferFormat",         w_Shader_getBufferFormat },
 	{ "getDebugName",            w_Shader_getDebugName },
 	{ 0, 0 }
 };