Browse Source

metal: implement storage texture support for shaders.

Fixes #1783
slime 2 years ago
parent
commit
99ad576668

+ 35 - 0
src/modules/graphics/Shader.cpp

@@ -1098,6 +1098,41 @@ bool Shader::validateBuffer(const UniformInfo *info, Buffer *buffer, bool intern
 	return true;
 }
 
+bool Shader::fillUniformReflectionData(UniformInfo &u)
+{
+	const auto &r = validationReflection;
+
+	if (u.baseType == UNIFORM_STORAGETEXTURE)
+	{
+		const auto reflectionit = r.storageTextures.find(u.name);
+		if (reflectionit != r.storageTextures.end())
+		{
+			u.storageTextureFormat = reflectionit->second.format;
+			u.access = reflectionit->second.access;
+			return true;
+		}
+
+		// No reflection info - maybe glslang was better at detecting dead code
+		// than the driver's compiler?
+		return false;
+	}
+	else if (u.baseType == UNIFORM_STORAGEBUFFER)
+	{
+		const auto reflectionit = r.storageBuffers.find(u.name);
+		if (reflectionit != r.storageBuffers.end())
+		{
+			u.bufferStride = reflectionit->second.stride;
+			u.bufferMemberCount = reflectionit->second.memberCount;
+			u.access = reflectionit->second.access;
+			return true;
+		}
+
+		return false;
+	}
+
+	return true;
+}
+
 bool Shader::initialize()
 {
 	return glslang::InitializeProcess();

+ 2 - 0
src/modules/graphics/Shader.h

@@ -289,6 +289,8 @@ protected:
 		bool usesPointSize;
 	};
 
+	bool fillUniformReflectionData(UniformInfo &u);
+
 	static bool validateInternal(StrongRef<ShaderStage> stages[], std::string& err, ValidationReflection &reflection);
 	static DataBaseType getDataBaseType(PixelFormat format);
 	static bool isResourceBaseTypeCompatible(DataBaseType a, DataBaseType b);

+ 1 - 1
src/modules/graphics/metal/Graphics.h

@@ -211,7 +211,7 @@ private:
 
 	id<MTLDepthStencilState> getCachedDepthStencilState(const DepthState &depth, const StencilState &stencil);
 	void applyRenderState(id<MTLRenderCommandEncoder> renderEncoder, const VertexAttributes &attributes);
-	void applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::graphics::Shader *shader);
+	bool applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::graphics::Shader *shader);
 	void applyShaderUniforms(id<MTLRenderCommandEncoder> renderEncoder, love::graphics::Shader *shader, Texture *maintex);
 
 	id<MTLCommandQueue> commandQueue;

+ 16 - 2
src/modules/graphics/metal/Graphics.mm

@@ -947,7 +947,7 @@ void Graphics::applyRenderState(id<MTLRenderCommandEncoder> encoder, const Verte
 	dirtyRenderState = 0;
 }
 
-void Graphics::applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::graphics::Shader *shader)
+bool Graphics::applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::graphics::Shader *shader)
 {
 	Shader *s = (Shader *)shader;
 
@@ -982,6 +982,8 @@ void Graphics::applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::g
 
 	uniformBufferOffset += alignUp(size, alignment);
 
+	bool allWritableVariablesSet = true;
+
 	for (const Shader::TextureBinding &b : s->getTextureBindings())
 	{
 		id<MTLTexture> texture = b.texture;
@@ -991,7 +993,12 @@ void Graphics::applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::g
 		uint8 sampindex = b.samplerStages[SHADERSTAGE_COMPUTE];
 
 		if (texindex != LOVE_UINT8_MAX)
+		{
 			setTexture(encoder, bindings, texindex, texture);
+			if ((b.access & Shader::ACCESS_WRITE) != 0 && texture == nil)
+				allWritableVariablesSet = false;
+		}
+		
 		if (sampindex != LOVE_UINT8_MAX)
 			setSampler(encoder, bindings, sampindex, samplertex);
 	}
@@ -1000,8 +1007,14 @@ void Graphics::applyShaderUniforms(id<MTLComputeCommandEncoder> encoder, love::g
 	{
 		uint8 index = b.stages[SHADERSTAGE_COMPUTE];
 		if (index != LOVE_UINT8_MAX)
+		{
 			setBuffer(encoder, bindings, index, b.buffer, 0);
+			if ((b.access & Shader::ACCESS_WRITE) != 0 && b.buffer == nil)
+				allWritableVariablesSet = false;
+		}
 	}
+
+	return allWritableVariablesSet;
 }
 
 void Graphics::applyShaderUniforms(id<MTLRenderCommandEncoder> renderEncoder, love::graphics::Shader *shader, love::graphics::Texture *maintex)
@@ -1299,7 +1312,8 @@ bool Graphics::dispatch(int x, int y, int z)
 
 	id<MTLComputeCommandEncoder> computeEncoder = useComputeEncoder();
 
-	applyShaderUniforms(computeEncoder, shader);
+	if (!applyShaderUniforms(computeEncoder, shader))
+		return false;
 
 	// TODO: track this state?
 	[computeEncoder setComputePipelineState:pipeline];

+ 4 - 0
src/modules/graphics/metal/Shader.h

@@ -43,6 +43,7 @@ namespace spirv_cross
 {
 class CompilerMSL;
 struct SPIRType;
+struct Resource;
 }
 
 namespace love
@@ -86,6 +87,7 @@ public:
 		Texture *samplerTexture;
 
 		bool isMainTexture;
+		Access access;
 
 		uint8 textureStages[SHADERSTAGE_MAX_ENUM];
 		uint8 samplerStages[SHADERSTAGE_MAX_ENUM];
@@ -95,6 +97,7 @@ public:
 	{
 		id<MTLBuffer> buffer;
 		uint8 stages[SHADERSTAGE_MAX_ENUM];
+		Access access;
 	};
 
 	Shader(id<MTLDevice> device, StrongRef<love::graphics::ShaderStage> stages[SHADERSTAGE_MAX_ENUM]);
@@ -137,6 +140,7 @@ private:
 	};
 
 	void buildLocalUniforms(const spirv_cross::CompilerMSL &msl, const spirv_cross::SPIRType &type, size_t baseoffset, const std::string &basename);
+	void addImage(const spirv_cross::CompilerMSL &msl, const spirv_cross::Resource &resource, UniformType baseType);
 	void compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &program);
 
 	id<MTLFunction> functions[SHADERSTAGE_MAX_ENUM];

+ 118 - 95
src/modules/graphics/metal/Shader.mm

@@ -455,13 +455,105 @@ void Shader::buildLocalUniforms(const spirv_cross::CompilerMSL &msl, const spirv
 	}
 }
 
+void Shader::addImage(const spirv_cross::CompilerMSL &msl, const spirv_cross::Resource &resource, UniformType baseType)
+{
+	using namespace spirv_cross;
+
+	const SPIRType &basetype = msl.get_type(resource.base_type_id);
+	const SPIRType &type = msl.get_type(resource.type_id);
+	const SPIRType &imagetype = msl.get_type(basetype.image.type);
+
+	UniformInfo u = {};
+	u.baseType = baseType;
+	u.name = resource.name;
+	u.count = type.array.empty() ? 1 : type.array[0];
+	u.isDepthSampler = type.image.depth;
+	u.components = 1;
+
+	auto it = uniforms.find(u.name);
+	if (it != uniforms.end())
+		return;
+
+	if (!fillUniformReflectionData(u))
+		return;
+
+	switch (imagetype.basetype)
+	{
+	case SPIRType::Float:
+		u.dataBaseType = DATA_BASETYPE_FLOAT;
+		break;
+	case SPIRType::Int:
+		u.dataBaseType = DATA_BASETYPE_INT;
+		break;
+	case SPIRType::UInt:
+		u.dataBaseType = DATA_BASETYPE_UINT;
+		break;
+	default:
+		break;
+	}
+
+	switch (basetype.image.dim)
+	{
+	case spv::Dim2D:
+		u.textureType = basetype.image.arrayed ? TEXTURE_2D_ARRAY : TEXTURE_2D;
+		u.textures = new love::graphics::Texture*[u.count];
+		break;
+	case spv::Dim3D:
+		u.textureType = TEXTURE_VOLUME;
+		u.textures = new love::graphics::Texture*[u.count];
+		break;
+	case spv::DimCube:
+		if (basetype.image.arrayed)
+			throw love::Exception("Cubemap Arrays are not currently supported.");
+		u.textureType = TEXTURE_CUBE;
+		u.textures = new love::graphics::Texture*[u.count];
+		break;
+	case spv::DimBuffer:
+		u.baseType = UNIFORM_TEXELBUFFER;
+		u.buffers = new love::graphics::Buffer*[u.count];
+		break;
+	default:
+		// TODO: error? continue?
+		break;
+	}
+
+	u.dataSize = sizeof(int) * u.count;
+	u.data = malloc(u.dataSize);
+	for (int i = 0; i < u.count; i++)
+		u.ints[i] = -1; // Initialized below, after compiling.
+
+	if (u.baseType == UNIFORM_SAMPLER)
+	{
+		auto tex = Graphics::getInstance()->getDefaultTexture(u.textureType);
+		for (int i = 0; i < u.count; i++)
+		{
+			tex->retain();
+			u.textures[i] = tex;
+		}
+	}
+	else if (u.baseType == UNIFORM_TEXELBUFFER)
+	{
+		for (int i = 0; i < u.count; i++)
+			u.buffers[i] = nullptr; // TODO
+	}
+	else if (u.baseType == UNIFORM_STORAGETEXTURE)
+	{
+		for (int i = 0; i < u.count; i++)
+			u.textures[i] = nullptr;
+	}
+
+	uniforms[u.name] = u;
+
+	BuiltinUniform builtin;
+	if (getConstant(resource.name.c_str(), builtin))
+		builtinUniformInfo[builtin] = &uniforms[u.name];
+}
+
 void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &program)
 {
 	using namespace glslang;
 	using namespace spirv_cross;
 
-	auto gfx = Graphics::getInstance();
-
 	std::map<std::string, int> varyings;
 	int nextVaryingLocation = 0;
 
@@ -499,84 +591,14 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 
 			ShaderResources resources = msl.get_shader_resources();
 
-			for (const auto &resource : resources.sampled_images)
+			for (const auto &resource : resources.storage_images)
 			{
-				const SPIRType &basetype = msl.get_type(resource.base_type_id);
-				const SPIRType &type = msl.get_type(resource.type_id);
-				const SPIRType &imagetype = msl.get_type(basetype.image.type);
-
-				UniformInfo u = {};
-				u.baseType = UNIFORM_SAMPLER;
-				u.name = resource.name;
-				u.count = type.array.empty() ? 1 : type.array[0];
-				u.isDepthSampler = type.image.depth;
-				u.components = 1;
-
-				switch (imagetype.basetype)
-				{
-				case SPIRType::Float:
-					u.dataBaseType = DATA_BASETYPE_FLOAT;
-					break;
-				case SPIRType::Int:
-					u.dataBaseType = DATA_BASETYPE_INT;
-					break;
-				case SPIRType::UInt:
-					u.dataBaseType = DATA_BASETYPE_UINT;
-					break;
-				default:
-					break;
-				}
-
-				switch (basetype.image.dim)
-				{
-				case spv::Dim2D:
-					u.textureType = basetype.image.arrayed ? TEXTURE_2D_ARRAY : TEXTURE_2D;
-					u.textures = new love::graphics::Texture*[u.count];
-					break;
-				case spv::Dim3D:
-					u.textureType = TEXTURE_VOLUME;
-					u.textures = new love::graphics::Texture*[u.count];
-					break;
-				case spv::DimCube:
-					if (basetype.image.arrayed)
-						throw love::Exception("Cubemap Arrays are not currently supported.");
-					u.textureType = TEXTURE_CUBE;
-					u.textures = new love::graphics::Texture*[u.count];
-					break;
-				case spv::DimBuffer:
-					u.baseType = UNIFORM_TEXELBUFFER;
-					u.buffers = new love::graphics::Buffer*[u.count];
-					break;
-				default:
-					// TODO: error? continue?
-					break;
-				}
-
-				u.dataSize = sizeof(int) * u.count;
-				u.data = malloc(u.dataSize);
-				for (int i = 0; i < u.count; i++)
-					u.ints[i] = -1; // Initialized below, after compiling.
-
-				if (u.baseType == UNIFORM_SAMPLER)
-				{
-					auto tex = gfx->getDefaultTexture(u.textureType);
-					for (int i = 0; i < u.count; i++)
-					{
-						tex->retain();
-						u.textures[i] = tex;
-					}
-				}
-				else if (u.baseType == UNIFORM_TEXELBUFFER)
-				{
-					for (int i = 0; i < u.count; i++)
-						u.buffers[i] = nullptr; // TODO
-				}
-
-				uniforms[u.name] = u;
+				addImage(msl, resource, UNIFORM_STORAGETEXTURE);
+			}
 
-				BuiltinUniform builtin;
-				if (getConstant(resource.name.c_str(), builtin))
-					builtinUniformInfo[builtin] = &uniforms[u.name];
+			for (const auto &resource : resources.sampled_images)
+			{
+				addImage(msl, resource, UNIFORM_SAMPLER);
 			}
 
 			for (const auto &resource : resources.uniform_buffers)
@@ -639,19 +661,8 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 				u.name = resource.name;
 				u.count = type.array.empty() ? 1 : type.array[0];
 
-				const auto reflectionit = validationReflection.storageBuffers.find(u.name);
-				if (reflectionit != validationReflection.storageBuffers.end())
-				{
-					u.bufferStride = reflectionit->second.stride;
-					u.bufferMemberCount = reflectionit->second.memberCount;
-					u.access = reflectionit->second.access;
-				}
-				else
-				{
-					// No reflection info - maybe glslang was better at detecting
-					// dead code than the driver's compiler?
+				if (!fillUniformReflectionData(u))
 					continue;
-				}
 
 				u.buffers = new love::graphics::Buffer*[u.count];
 				u.dataSize = sizeof(int) * u.count;
@@ -741,11 +752,11 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 
 			functions[stageindex] = [library newFunctionWithName:library.functionNames[0]];
 
-			for (const auto &resource : resources.sampled_images)
+			auto setTextureBinding = [this](CompilerMSL &msl, int stageindex, const spirv_cross::Resource &resource) -> void
 			{
 				auto it = uniforms.find(resource.name);
 				if (it == uniforms.end())
-					continue;
+					return;
 
 				UniformInfo &u = it->second;
 
@@ -753,7 +764,7 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 				uint32 samplerbinding = msl.get_automatic_msl_resource_binding_secondary(resource.id);
 
 				if (texturebinding == (uint32)-1)
-					continue;
+					return;
 
 				for (int i = 0; i < u.count; i++)
 				{
@@ -761,6 +772,7 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 					{
 						u.ints[i] = (int)textureBindings.size();
 						TextureBinding b = {};
+						b.access = u.access;
 
 						if (u.baseType == UNIFORM_TEXELBUFFER)
 						{
@@ -788,6 +800,16 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 					b.textureStages[stageindex] = (uint8) texturebinding;
 					b.samplerStages[stageindex] = (uint8) samplerbinding;
 				}
+			};
+
+			for (const auto &resource : resources.sampled_images)
+			{
+				setTextureBinding(msl, stageindex, resource);
+			}
+
+			for (const auto &resource : resources.storage_images)
+			{
+				setTextureBinding(msl, stageindex, resource);
 			}
 
 			for (const auto &resource : resources.storage_buffers)
@@ -808,6 +830,7 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 					{
 						u.ints[i] = (int)bufferBindings.size();
 						BufferBinding b = {};
+						b.access = u.access;
 
 						for (uint8 &stagebinding : b.stages)
 							stagebinding = LOVE_UINT8_MAX;
@@ -841,7 +864,7 @@ Shader::~Shader()
 	for (const auto &it : uniforms)
 	{
 		const auto &u = it.second;
-		if (u.baseType == UNIFORM_SAMPLER)
+		if (u.baseType == UNIFORM_SAMPLER || u.baseType == UNIFORM_STORAGETEXTURE)
 		{
 			free(u.data);
 			for (int i = 0; i < u.count; i++)
@@ -940,7 +963,7 @@ void Shader::updateUniform(const UniformInfo *info, int count)
 
 void Shader::sendTextures(const UniformInfo *info, love::graphics::Texture **textures, int count)
 { @autoreleasepool {
-	if (info->baseType != UNIFORM_SAMPLER)
+	if (info->baseType != UNIFORM_SAMPLER && info->baseType != UNIFORM_STORAGETEXTURE)
 		return;
 
 	if (current == this)

+ 4 - 25
src/modules/graphics/opengl/Shader.cpp

@@ -139,6 +139,9 @@ void Shader::mapActiveUniforms()
 		if (u.location == -1)
 			continue;
 
+		if (!fillUniformReflectionData(u))
+			continue;;
+
 		if ((u.baseType == UNIFORM_SAMPLER && builtin != BUILTIN_TEXTURE_MAIN) || u.baseType == UNIFORM_TEXELBUFFER)
 		{
 			TextureUnit unit;
@@ -161,19 +164,6 @@ void Shader::mapActiveUniforms()
 		}
 		else if (u.baseType == UNIFORM_STORAGETEXTURE)
 		{
-			const auto reflectionit = validationReflection.storageTextures.find(u.name);
-			if (reflectionit != validationReflection.storageTextures.end())
-			{
-				u.storageTextureFormat = reflectionit->second.format;
-				u.access = reflectionit->second.access;
-			}
-			else
-			{
-				// No reflection info - maybe glslang was better at detecting
-				// dead code than the driver's compiler?
-				continue;
-			}
-
 			StorageTextureBinding binding = {};
 			binding.gltexture = gl.getDefaultTexture(u.textureType, u.dataBaseType);
 			binding.type = u.textureType;
@@ -373,19 +363,8 @@ void Shader::mapActiveUniforms()
 			u.name = std::string(namebuffer, namelength);
 			u.count = 1;
 
-			const auto reflectionit = validationReflection.storageBuffers.find(u.name);
-			if (reflectionit != validationReflection.storageBuffers.end())
-			{
-				u.bufferStride = reflectionit->second.stride;
-				u.bufferMemberCount = reflectionit->second.memberCount;
-				u.access = reflectionit->second.access;
-			}
-			else
-			{
-				// No reflection info - maybe glslang was better at detecting
-				// dead code than the driver's compiler?
+			if (!fillUniformReflectionData(u))
 				continue;
-			}
 
 			// Make sure previously set uniform data is preserved, and shader-
 			// initialized values are retrieved.

+ 9 - 10
src/modules/graphics/vulkan/Shader.cpp

@@ -872,18 +872,11 @@ void Shader::compileShaders()
 			u.components = 1;
 			u.name = r.name;
 			u.count = type.array.empty() ? 1 : type.array[0];
-			u.location = comp.get_decoration(r.id, spv::DecorationBinding);
-			
-			const auto reflectionit = validationReflection.storageBuffers.find(u.name);
-			if (reflectionit != validationReflection.storageBuffers.end())
-			{
-				u.bufferStride = reflectionit->second.stride;
-				u.bufferMemberCount = reflectionit->second.memberCount;
-				u.access = reflectionit->second.access;
-			}
-			else
+
+			if (!fillUniformReflectionData(u))
 				continue;
 
+			u.location = comp.get_decoration(r.id, spv::DecorationBinding);
 			u.buffers = new love::graphics::Buffer *[u.count];
 
 			for (int i = 0; i < u.count; i++)
@@ -901,6 +894,10 @@ void Shader::compileShaders()
 			u.components = 1;
 			u.name = r.name;
 			u.count = type.array.empty() ? 1 : type.array[0];
+
+			if (!fillUniformReflectionData(u))
+				continue;
+
 			u.textures = new love::graphics::Texture *[u.count];
 			u.location = comp.get_decoration(r.id, spv::DecorationBinding);
 
@@ -913,12 +910,14 @@ void Shader::compileShaders()
 		}
 
 		if (shaderStage == SHADERSTAGE_VERTEX)
+		{
 			for (const auto &r : shaderResources.stage_inputs)
 			{
 				const auto &name = r.name;
 				const int attributeLocation = static_cast<int>(comp.get_decoration(r.id, spv::DecorationLocation));
 				attributes[name] = attributeLocation;
 			}
+		}
 	}
 
 	delete program;