Browse Source

metal: support constant initializer values for local uniforms.

Alex Szpakowski 3 years ago
parent
commit
e602f08c24

+ 17 - 7
src/libraries/glslang/glslang/MachineIndependent/reflection.cpp

@@ -113,7 +113,7 @@ public:
             // Use a degenerate (empty) set of dereferences to immediately put as at the end of
             // the dereference change expected by blowUpActiveAggregate.
             blowUpActiveAggregate(base.getType(), baseName, derefs, derefs.end(), offset, blockIndex, 0, -1, 0,
-                                    base.getQualifier().storage, updateStageMasks);
+                                    base.getQualifier().storage, updateStageMasks, &base.getConstArray());
         }
     }
 
@@ -250,7 +250,8 @@ public:
     // A value of 0 for arraySize will mean to use the full array's size.
     void blowUpActiveAggregate(const TType& baseType, const TString& baseName, const TList<TIntermBinary*>& derefs,
                                TList<TIntermBinary*>::const_iterator deref, int offset, int blockIndex, int arraySize,
-                               int topLevelArraySize, int topLevelArrayStride, TStorageQualifier baseStorage, bool active)
+                               int topLevelArraySize, int topLevelArrayStride, TStorageQualifier baseStorage, bool active,
+							   const TConstUnionArray* constArray = nullptr)
     {
         // when strictArraySuffix is enabled, we closely follow the rules from ARB_program_interface_query.
         // Broadly:
@@ -265,9 +266,18 @@ public:
         // process the part of the dereference chain that was explicit in the shader
         TString name = baseName;
         const TType* terminalType = &baseType;
+		const TConstUnionArray* terminalConstArray = constArray;
         for (; deref != derefs.end(); ++deref) {
             TIntermBinary* visitNode = *deref;
             terminalType = &visitNode->getType();
+			if (visitNode->getAsSymbolNode())
+				terminalConstArray = &visitNode->getAsSymbolNode()->getConstArray();
+			else if (visitNode->getAsConstantUnion())
+				terminalConstArray = &visitNode->getAsConstantUnion()->getConstArray();
+			else if (visitNode->getLeft() != nullptr && visitNode->getLeft()->getAsSymbolNode())
+				terminalConstArray = &visitNode->getLeft()->getAsSymbolNode()->getConstArray();
+			else
+				terminalConstArray = nullptr;
             int index;
             switch (visitNode->getOp()) {
             case EOpIndexIndirect: {
@@ -450,7 +460,7 @@ public:
             int uniformIndex = (int)variables.size();
             reflection.nameToIndex[name.c_str()] = uniformIndex;
             variables.push_back(TObjectReflection(name.c_str(), *terminalType, offset, mapToGlType(*terminalType),
-                                                  arraySize, blockIndex));
+                                                  arraySize, blockIndex, terminalConstArray));
             if (terminalType->isArray()) {
                 variables.back().arrayStride = getArrayStride(baseType, *terminalType);
                 if (topLevelArrayStride == 0)
@@ -602,7 +612,7 @@ public:
                 // otherwise - if we're not using strict array suffix rules, or this isn't a block so we are
                 // expanding root arrays anyway, just start the iteration from the base block type.
                 blowUpActiveAggregate(base->getType(), baseName, derefs, derefs.end(), 0, blockIndex, 0, -1, 0,
-                                          base->getQualifier().storage, false);
+                                          base->getQualifier().storage, false, &base->getConstArray());
             }
         }
 
@@ -634,7 +644,7 @@ public:
                 baseName = base->getName();
         }
         blowUpActiveAggregate(base->getType(), baseName, derefs, derefs.begin(), offset, blockIndex, arraySize, -1, 0,
-                              base->getQualifier().storage, true);
+                              base->getQualifier().storage, true, &base->getConstArray());
     }
 
     int addBlockName(const TString& name, const TType& type, int size)
@@ -1078,9 +1088,9 @@ void TReflectionTraverser::visitSymbol(TIntermSymbol* base)
 //
 
 TObjectReflection::TObjectReflection(const std::string &pName, const TType &pType, int pOffset, int pGLDefineType,
-                                     int pSize, int pIndex)
+                                     int pSize, int pIndex, const TConstUnionArray* pConstArray)
     : name(pName), offset(pOffset), glDefineType(pGLDefineType), size(pSize), index(pIndex), counterIndex(-1),
-      numMembers(-1), arrayStride(0), topLevelArrayStride(0), stages(EShLanguageMask(0)), type(pType.clone())
+      numMembers(-1), arrayStride(0), topLevelArrayStride(0), stages(EShLanguageMask(0)), type(pType.clone()), constArray(pConstArray)
 {
 }
 

+ 5 - 2
src/libraries/glslang/glslang/Public/ShaderLang.h

@@ -140,6 +140,7 @@ typedef enum : unsigned {
 namespace glslang {
 
 class TType;
+class TConstUnionArray;
 
 typedef enum {
     EShSourceNone,
@@ -723,9 +724,10 @@ private:
 // Data needed for just a single object at the granularity exchanged by the reflection API
 class TObjectReflection {
 public:
-    GLSLANG_EXPORT TObjectReflection(const std::string& pName, const TType& pType, int pOffset, int pGLDefineType, int pSize, int pIndex);
+    GLSLANG_EXPORT TObjectReflection(const std::string& pName, const TType& pType, int pOffset, int pGLDefineType, int pSize, int pIndex, const TConstUnionArray* pConstArray = nullptr);
 
     const TType* getType() const { return type; }
+	const TConstUnionArray* getConstArray() const { return constArray; }
     GLSLANG_EXPORT int getBinding() const;
     GLSLANG_EXPORT void dump() const;
     static TObjectReflection badReflection() { return TObjectReflection(); }
@@ -745,11 +747,12 @@ public:
 protected:
     TObjectReflection()
         : offset(-1), glDefineType(-1), size(-1), index(-1), counterIndex(-1), numMembers(-1), arrayStride(0),
-          topLevelArrayStride(0), stages(EShLanguageMask(0)), type(nullptr)
+          topLevelArrayStride(0), stages(EShLanguageMask(0)), type(nullptr), constArray(nullptr)
     {
     }
 
     const TType* type;
+	const TConstUnionArray* constArray;
 };
 
 class  TReflection;

+ 62 - 0
src/modules/graphics/Shader.cpp

@@ -792,6 +792,24 @@ static PixelFormat getPixelFormat(glslang::TLayoutFormat format)
 	}
 }
 
+template <typename T>
+static T convertData(const glslang::TConstUnion &data)
+{
+	switch (data.getType())
+	{
+		case glslang::EbtInt: return (T) data.getIConst();
+		case glslang::EbtUint: return (T) data.getUConst();
+		case glslang::EbtDouble: return (T) data.getDConst();
+		case glslang::EbtInt8: return (T) data.getI8Const();
+		case glslang::EbtInt16: return (T) data.getI16Const();
+		case glslang::EbtInt64: return (T) data.getI64Const();
+		case glslang::EbtUint8: return (T) data.getU8Const();
+		case glslang::EbtUint16: return (T) data.getU16Const();
+		case glslang::EbtUint64: return (T) data.getU64Const();
+		default: return 0;
+	}
+}
+
 bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err, ValidationReflection &reflection)
 {
 	glslang::TProgram program;
@@ -871,6 +889,50 @@ bool Shader::validateInternal(StrongRef<ShaderStage> stages[], std::string &err,
 
 			reflection.storageTextures[info.name] = texreflection;
 		}
+		else if (!type->isOpaque())
+		{
+			LocalUniform u = {};
+			auto &values = u.initializerValues;
+			const glslang::TConstUnionArray *constarray = info.getConstArray();
+
+			// Store initializer values for local uniforms. Some love graphics
+			// backends strip these out of the shader so we need to be able to
+			// access them (to re-send them) by getting them here.
+			switch (type->getBasicType())
+			{
+			case glslang::EbtFloat:
+				u.dataType = DATA_BASETYPE_FLOAT;
+				if (constarray != nullptr)
+				{
+					values.resize(constarray->size());
+					for (int i = 0; i < constarray->size(); i++)
+						values[i].f = convertData<float>((*constarray)[i]);
+				}
+				break;
+			case glslang::EbtUint:
+				u.dataType = DATA_BASETYPE_UINT;
+				if (constarray != nullptr)
+				{
+					values.resize(constarray->size());
+					for (int i = 0; i < constarray->size(); i++)
+						values[i].u = convertData<uint32>((*constarray)[i]);
+				}
+				break;
+			case glslang::EbtInt:
+			case glslang::EbtBool:
+			default:
+				u.dataType = DATA_BASETYPE_INT;
+				if (constarray != nullptr)
+				{
+					values.resize(constarray->size());
+					for (int i = 0; i < constarray->size(); i++)
+						values[i].i = convertData<int32>((*constarray)[i]);
+				}
+				break;
+			}
+
+			reflection.localUniforms[info.name] = u;
+		}
 	}
 
 	for (int i = 0; i < program.getNumBufferBlocks(); i++)

+ 14 - 0
src/modules/graphics/Shader.h

@@ -158,6 +158,13 @@ public:
 		};
 	};
 
+	union LocalUniformValue
+	{
+		float f;
+		int32 i;
+		uint32 u;
+	};
+
 	// The members in here must respect uniform buffer alignment/padding rules.
  	struct BuiltinUniformData
  	{
@@ -259,10 +266,17 @@ protected:
 		Access access;
 	};
 
+	struct LocalUniform
+	{
+		DataBaseType dataType;
+		std::vector<LocalUniformValue> initializerValues;
+	};
+
 	struct ValidationReflection
 	{
 		std::map<std::string, BufferReflection> storageBuffers;
 		std::map<std::string, StorageTextureReflection> storageTextures;
+		std::map<std::string, LocalUniform> localUniforms;
 		int localThreadgroupSize[3];
 		bool usesPointSize;
 	};

+ 12 - 0
src/modules/graphics/metal/Shader.mm

@@ -565,6 +565,16 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 								u.matrix.rows = membertype.vecsize;
 								u.matrix.columns = membertype.columns;
 							}
+							if (validationReflection.localUniforms.find(u.name) != validationReflection.localUniforms.end())
+							{
+								const auto &ru = validationReflection.localUniforms.find(u.name);
+								const auto &values = ru->second.initializerValues;
+								if (!values.empty())
+								{
+									memcpy(u.data, values.data(), std::min(u.dataSize, values.size() * sizeof(LocalUniformValue)));
+								}
+							}
+							updateUniform(&u, u.count);
 							break;
 						case SPIRType::Struct:
 							// TODO
@@ -582,6 +592,8 @@ void Shader::compileFromGLSLang(id<MTLDevice> device, const glslang::TProgram &p
 								builtinUniformDataOffset = offset;
 							builtinUniformInfo[builtin] = &uniforms[u.name];
 						}
+
+
 					}
 				}
 				else