Selaa lähdekoodia

Preliminary work for supporting custom functions.

Eric Mellino 8 vuotta sitten
vanhempi
sitoutus
3ea080ac20

+ 20 - 36
src/ShaderGen.Tests/ShaderGeneratorTests.cs

@@ -7,19 +7,25 @@ namespace ShaderGen.Tests
 {
     public class ShaderGeneratorTests
     {
+        private static IEnumerable<object[]> ShaderSets()
+        {
+            yield return new object[] { "TestShaders.TestVertexShader.VS", null };
+            yield return new object[] { null, "TestShaders.TestFragmentShader.FS" };
+            yield return new object[] { "TestShaders.TestVertexShader.VS", "TestShaders.TestFragmentShader.FS" };
+            yield return new object[] { null, "TestShaders.TextureSamplerFragment.FS" };
+            yield return new object[] { "TestShaders.VertexAndFragment.VS", "TestShaders.VertexAndFragment.FS" };
+            yield return new object[] { null, "TestShaders.ComplexExpression.FS" };
+            yield return new object[] { "TestShaders.PartialVertex.VertexShaderFunc", null };
+            yield return new object[] { "TestShaders.VeldridShaders.ForwardMtlCombined.VS", "TestShaders.VeldridShaders.ForwardMtlCombined.FS" };
+            yield return new object[] { "TestShaders.VeldridShaders.ForwardMtlCombined.VS", null };
+            yield return new object[] { null, "TestShaders.VeldridShaders.ForwardMtlCombined.FS" };
+            yield return new object[] { "TestShaders.CustomStructResource.VS", null };
+            yield return new object[] { "TestShaders.Swizzles.VS", null };
+            yield return new object[] { "TestShaders.CustomMethodCalls.VS", null };
+        }
+
         [Theory]
-        [InlineData("TestShaders.TestVertexShader.VS", null)]
-        [InlineData(null, "TestShaders.TestFragmentShader.FS")]
-        [InlineData("TestShaders.TestVertexShader.VS", "TestShaders.TestFragmentShader.FS")]
-        [InlineData(null, "TestShaders.TextureSamplerFragment.FS")]
-        [InlineData("TestShaders.VertexAndFragment.VS", "TestShaders.VertexAndFragment.FS")]
-        [InlineData(null, "TestShaders.ComplexExpression.FS")]
-        [InlineData("TestShaders.PartialVertex.VertexShaderFunc", null)]
-        [InlineData("TestShaders.VeldridShaders.ForwardMtlCombined.VS", "TestShaders.VeldridShaders.ForwardMtlCombined.FS")]
-        [InlineData("TestShaders.VeldridShaders.ForwardMtlCombined.VS", null)]
-        [InlineData(null, "TestShaders.VeldridShaders.ForwardMtlCombined.FS")]
-        [InlineData("TestShaders.CustomStructResource.VS", null)]
-        [InlineData("TestShaders.Swizzles.VS", null)]
+        [MemberData(nameof(ShaderSets))]
         public void HlslEndToEnd(string vsName, string fsName)
         {
             Compilation compilation = TestUtil.GetTestProjectCompilation();
@@ -51,18 +57,7 @@ namespace ShaderGen.Tests
         }
 
         [Theory]
-        [InlineData("TestShaders.TestVertexShader.VS", null)]
-        [InlineData(null, "TestShaders.TestFragmentShader.FS")]
-        [InlineData("TestShaders.TestVertexShader.VS", "TestShaders.TestFragmentShader.FS")]
-        [InlineData(null, "TestShaders.TextureSamplerFragment.FS")]
-        [InlineData("TestShaders.VertexAndFragment.VS", "TestShaders.VertexAndFragment.FS")]
-        [InlineData(null, "TestShaders.ComplexExpression.FS")]
-        [InlineData("TestShaders.PartialVertex.VertexShaderFunc", null)]
-        [InlineData("TestShaders.VeldridShaders.ForwardMtlCombined.VS", "TestShaders.VeldridShaders.ForwardMtlCombined.FS")]
-        [InlineData("TestShaders.VeldridShaders.ForwardMtlCombined.VS", null)]
-        [InlineData(null, "TestShaders.VeldridShaders.ForwardMtlCombined.FS")]
-        [InlineData("TestShaders.CustomStructResource.VS", null)]
-        [InlineData("TestShaders.Swizzles.VS", null)]
+        [MemberData(nameof(ShaderSets))]
         public void Glsl330EndToEnd(string vsName, string fsName)
         {
             Compilation compilation = TestUtil.GetTestProjectCompilation();
@@ -94,18 +89,7 @@ namespace ShaderGen.Tests
         }
 
         [Theory]
-        [InlineData("TestShaders.TestVertexShader.VS", null)]
-        [InlineData(null, "TestShaders.TestFragmentShader.FS")]
-        [InlineData("TestShaders.TestVertexShader.VS", "TestShaders.TestFragmentShader.FS")]
-        [InlineData(null, "TestShaders.TextureSamplerFragment.FS")]
-        [InlineData("TestShaders.VertexAndFragment.VS", "TestShaders.VertexAndFragment.FS")]
-        [InlineData(null, "TestShaders.ComplexExpression.FS")]
-        [InlineData("TestShaders.PartialVertex.VertexShaderFunc", null)]
-        [InlineData("TestShaders.VeldridShaders.ForwardMtlCombined.VS", "TestShaders.VeldridShaders.ForwardMtlCombined.FS")]
-        [InlineData("TestShaders.VeldridShaders.ForwardMtlCombined.VS", null)]
-        [InlineData(null, "TestShaders.VeldridShaders.ForwardMtlCombined.FS")]
-        [InlineData("TestShaders.CustomStructResource.VS", null)]
-        [InlineData("TestShaders.Swizzles.VS", null)]
+        [MemberData(nameof(ShaderSets))]
         public void Glsl450EndToEnd(string vsName, string fsName)
         {
             Compilation compilation = TestUtil.GetTestProjectCompilation();

+ 33 - 0
src/ShaderGen.Tests/TestAssets/CustomMethodCalls.cs

@@ -0,0 +1,33 @@
+using ShaderGen;
+using System.Numerics;
+
+namespace TestShaders
+{
+    public class CustomMethodCalls
+    {
+        [VertexShader]
+        Position4 VS(Position4 input)
+        {
+            Position4 reversed = Reverse(input);
+            return ShufflePosition4(reversed);
+        }
+
+        private Position4 Reverse(Position4 vert)
+        {
+            vert.Position = vert.Position.WZYX();
+            return vert;
+        }
+
+        private Position4 ShufflePosition4(Position4 vert)
+        {
+            vert.Position = ShuffleVector4(vert.Position);
+            return vert;
+        }
+
+        private Vector4 ShuffleVector4(Vector4 v)
+        {
+            Vector4 result = v.XZYW();
+            return result;
+        }
+    }
+}

+ 164 - 0
src/ShaderGen/FunctionCallGraphDiscoverer.cs

@@ -0,0 +1,164 @@
+using Microsoft.CodeAnalysis.CSharp;
+using System.Collections.Generic;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System;
+using Microsoft.CodeAnalysis;
+using System.Linq;
+using System.Diagnostics;
+
+namespace ShaderGen
+{
+    internal class FunctionCallGraphDiscoverer
+    {
+        public Compilation Compilation { get; }
+        private CallGraphNode _rootNode;
+        private Dictionary<TypeAndMethodName, CallGraphNode> _nodesByName = new Dictionary<TypeAndMethodName, CallGraphNode>();
+
+        public FunctionCallGraphDiscoverer(Compilation compilation, TypeAndMethodName rootMethod)
+        {
+            Compilation = compilation;
+            _rootNode = new CallGraphNode() { Name = rootMethod };
+            bool foundDecl = GetDeclaration(rootMethod, out _rootNode.Declaration);
+            Debug.Assert(foundDecl);
+            _nodesByName.Add(rootMethod, _rootNode);
+        }
+
+        public TypeAndMethodName[] GetOrderedCallList()
+        {
+            List<TypeAndMethodName> result = new List<TypeAndMethodName>();
+
+            TraverseNode(result, _rootNode);
+
+
+            return result.ToArray();
+        }
+
+        private void TraverseNode(List<TypeAndMethodName> result, CallGraphNode node)
+        {
+            foreach (TypeAndMethodName existing in result)
+            {
+                if (node.Parents.Any(cgn => cgn.Name.Equals(existing)))
+                {
+                    throw new ShaderGenerationException("There was a cyclical call graph involving " + existing + " and " + node.Name);
+                }
+            }
+
+            foreach (CallGraphNode child in node.Children)
+            {
+                TraverseNode(result, child);
+            }
+
+            result.Add(node.Name);
+        }
+
+        public void GenerateFullGraph()
+        {
+            ExploreCallNode(_rootNode);
+        }
+
+        private void ExploreCallNode(CallGraphNode node)
+        {
+            Debug.Assert(node.Declaration != null);
+            MethodWalker walker = new MethodWalker(this);
+            walker.Visit(node.Declaration);
+            TypeAndMethodName[] childrenNames = walker.GetChildren();
+            foreach (TypeAndMethodName childName in childrenNames)
+            {
+                CallGraphNode childNode = GetNode(childName);
+                if (childNode.Declaration != null)
+                {
+                    childNode.Parents.Add(node);
+                    node.Children.Add(childNode);
+                    ExploreCallNode(childNode);
+                }
+            }
+        }
+
+        private CallGraphNode GetNode(TypeAndMethodName name)
+        {
+            if (!_nodesByName.TryGetValue(name, out CallGraphNode node))
+            {
+                node = new CallGraphNode() { Name = name };
+                GetDeclaration(name, out node.Declaration);
+                _nodesByName.Add(name, node);
+            }
+
+            return node;
+        }
+
+        private bool GetDeclaration(TypeAndMethodName name, out MethodDeclarationSyntax decl)
+        {
+            INamedTypeSymbol symb = Compilation.GetTypeByMetadataName(name.TypeName);
+            foreach (SyntaxReference synRef in symb.DeclaringSyntaxReferences)
+            {
+                SyntaxNode node = synRef.GetSyntax();
+                foreach (SyntaxNode child in node.ChildNodes())
+                {
+                    if (child is MethodDeclarationSyntax mds)
+                    {
+                        if (mds.Identifier.ToFullString() == name.MethodName)
+                        {
+                            decl = mds;
+                            return true;
+                        }
+                    }
+                }
+            }
+
+            decl = null;
+            return false;
+        }
+
+        private class MethodWalker : CSharpSyntaxWalker
+        {
+            private readonly FunctionCallGraphDiscoverer _discoverer;
+            private readonly HashSet<TypeAndMethodName> _children = new HashSet<TypeAndMethodName>();
+
+            public MethodWalker(FunctionCallGraphDiscoverer discoverer)
+            {
+                _discoverer = discoverer;
+            }
+
+            public override void VisitInvocationExpression(InvocationExpressionSyntax node)
+            {
+                if (node.Expression is IdentifierNameSyntax ins)
+                {
+                    SymbolInfo symbolInfo = _discoverer.Compilation.GetSemanticModel(node.SyntaxTree).GetSymbolInfo(ins);
+                    string containingType = symbolInfo.Symbol.ContainingType.ToDisplayString();
+                    string methodName = symbolInfo.Symbol.Name;
+                    _children.Add(new TypeAndMethodName() { TypeName = containingType, MethodName = methodName });
+                    return;
+                }
+                else if (node.Expression is MemberAccessExpressionSyntax maes)
+                {
+                    SymbolInfo methodSymbol = _discoverer.Compilation.GetSemanticModel(maes.SyntaxTree).GetSymbolInfo(maes);
+                    if (methodSymbol.Symbol is IMethodSymbol ims)
+                    {
+                        string containingType = Utilities.GetFullMetadataName(ims.ContainingType);
+                        string methodName = ims.MetadataName;
+                        _children.Add(new TypeAndMethodName() { TypeName = containingType, MethodName = methodName });
+                        return;
+                    }
+                }
+
+                throw new NotImplementedException();
+            }
+
+            public TypeAndMethodName[] GetChildren() => _children.ToArray();
+        }
+    }
+
+    internal class CallGraphNode
+    {
+        public TypeAndMethodName Name;
+        /// <summary>
+        /// May be null.
+        /// </summary>
+        public MethodDeclarationSyntax Declaration;
+        /// <summary>
+        /// Functions called by this function.
+        /// </summary>
+        public HashSet<CallGraphNode> Children = new HashSet<CallGraphNode>();
+        public HashSet<CallGraphNode> Parents = new HashSet<CallGraphNode>();
+    }
+}

+ 1 - 1
src/ShaderGen/Glsl330Backend.cs

@@ -41,7 +41,7 @@ namespace ShaderGen
             sb.AppendLine();
         }
 
-        protected override string FormatInvocationCore(string type, string method, InvocationParameterInfo[] parameterInfos)
+        protected override string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos)
         {
             return Glsl330KnownFunctions.TranslateInvocation(type, method, parameterInfos);
         }

+ 1 - 1
src/ShaderGen/Glsl450Backend.cs

@@ -77,7 +77,7 @@ namespace ShaderGen
 
         }
 
-        protected override string FormatInvocationCore(string type, string method, InvocationParameterInfo[] parameterInfos)
+        protected override string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos)
         {
             return Glsl450KnownFunctions.TranslateInvocation(type, method, parameterInfos);
         }

+ 1 - 1
src/ShaderGen/GlslBackendBase.cs

@@ -90,7 +90,7 @@ namespace ShaderGen
                 }
             }
 
-            string result = new ShaderMethodVisitor(Compilation, entryPoint.Function, this)
+            string result = new ShaderMethodVisitor(Compilation, setName, entryPoint.Function, this)
                 .VisitFunction(entryPoint.Block);
             sb.AppendLine(result);
 

+ 41 - 11
src/ShaderGen/HlslBackend.cs

@@ -5,6 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 using System.IO;
+using System.Diagnostics;
 
 namespace ShaderGen
 {
@@ -130,9 +131,11 @@ namespace ShaderGen
 
         protected override string GenerateFullTextCore(string setName, ShaderFunction function)
         {
-            StringBuilder sb = new StringBuilder();
+            Debug.Assert(function.IsEntryPoint);
 
-            ShaderFunctionAndBlockSyntax entryPoint = GetContext(setName).Functions.SingleOrDefault(
+            StringBuilder sb = new StringBuilder();
+            BackendContext setContext = GetContext(setName);
+            ShaderFunctionAndBlockSyntax entryPoint = setContext.Functions.SingleOrDefault(
                 sfabs => sfabs.Function.Name == function.Name);
             if (entryPoint == null)
             {
@@ -147,21 +150,21 @@ namespace ShaderGen
             {
                 // HLSL vertex outputs needs to have semantics applied to the structure fields.
                 StructureDefinition output = CreateOutputStructure(GetRequiredStructureType(setName, entryPoint.Function.ReturnType));
-                GetContext(setName).Functions.Remove(entryPoint);
+                setContext.Functions.Remove(entryPoint);
                 entryPoint = entryPoint.WithReturnType(new TypeReference(output.Name));
-                GetContext(setName).Functions.Add(entryPoint);
+                setContext.Functions.Add(entryPoint);
             }
 
             if (function.Type == ShaderFunctionType.FragmentEntryPoint)
             {
                 // HLSL pixel shader inputs also need these semantics.
                 StructureDefinition modifiedInput = CreateOutputStructure(input);
-                GetContext(setName).Functions.Remove(entryPoint);
+                setContext.Functions.Remove(entryPoint);
                 entryPoint = entryPoint.WithParameter(0, new TypeReference(modifiedInput.Name));
-                GetContext(setName).Functions.Add(entryPoint);
+                setContext.Functions.Add(entryPoint);
             }
 
-            foreach (StructureDefinition sd in GetContext(setName).Structures)
+            foreach (StructureDefinition sd in setContext.Structures)
             {
                 WriteStructure(sb, sd);
             }
@@ -170,8 +173,24 @@ namespace ShaderGen
                 WriteStructure(sb, sd);
             }
 
+            FunctionCallGraphDiscoverer fcgd = new FunctionCallGraphDiscoverer(
+                Compilation,
+                new TypeAndMethodName { TypeName = function.DeclaringType, MethodName = function.Name });
+            fcgd.GenerateFullGraph();
+            TypeAndMethodName[] orderedFunctionList = fcgd.GetOrderedCallList();
+
+            foreach (TypeAndMethodName name in orderedFunctionList)
+            {
+                ShaderFunctionAndBlockSyntax f = setContext.Functions.Single(
+                    sfabs => sfabs.Function.DeclaringType == name.TypeName && sfabs.Function.Name == name.MethodName);
+                if (!f.Function.IsEntryPoint)
+                {
+                    sb.AppendLine(new HlslMethodVisitor(Compilation, setName, f.Function, this).VisitFunction(f.Block));
+                }
+            }
+
             int uniformBinding = 0, textureBinding = 0, samplerBinding = 0;
-            foreach (ResourceDefinition rd in GetContext(setName).Resources)
+            foreach (ResourceDefinition rd in setContext.Resources)
             {
                 switch (rd.ResourceKind)
                 {
@@ -191,7 +210,7 @@ namespace ShaderGen
                 }
             }
 
-            string result = new HlslMethodVisitor(Compilation, entryPoint.Function, this)
+            string result = new HlslMethodVisitor(Compilation, setName, entryPoint.Function, this)
                 .VisitFunction(entryPoint.Block);
             sb.AppendLine(result);
 
@@ -235,9 +254,20 @@ namespace ShaderGen
             return clone;
         }
 
-        protected override string FormatInvocationCore(string type, string method, InvocationParameterInfo[] parameterInfos)
+        protected override string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos)
         {
-            return HlslKnownFunctions.TranslateInvocation(type, method, parameterInfos);
+            ShaderFunctionAndBlockSyntax function = GetContext(setName).Functions
+                .SingleOrDefault(sfabs => sfabs.Function.DeclaringType == type && sfabs.Function.Name == method);
+            if (function != null)
+            {
+                string invocationList = string.Join(", ", parameterInfos.Select(ipi => CSharpToIdentifierNameCore(ipi.FullTypeName, ipi.Identifier)));
+                string fullMethodName = CSharpToShaderType(function.Function.DeclaringType) + "_" + function.Function.Name;
+                return $"{fullMethodName}({invocationList})";
+            }
+            else
+            {
+                return HlslKnownFunctions.TranslateInvocation(type, method, parameterInfos);
+            }
         }
 
         protected override string CSharpToIdentifierNameCore(string typeName, string identifier)

+ 7 - 3
src/ShaderGen/HlslMethodVisitor.cs

@@ -4,8 +4,8 @@ namespace ShaderGen
 {
     internal class HlslMethodVisitor : ShaderMethodVisitor
     {
-        public HlslMethodVisitor(Compilation compilation, ShaderFunction shaderFunction, LanguageBackend backend)
-            : base(compilation, shaderFunction, backend)
+        public HlslMethodVisitor(Compilation compilation, string setName, ShaderFunction shaderFunction, LanguageBackend backend)
+            : base(compilation, setName, shaderFunction, backend)
         {
         }
 
@@ -13,7 +13,11 @@ namespace ShaderGen
         {
             string returnType = _backend.CSharpToShaderType(_shaderFunction.ReturnType.Name);
             string suffix = _shaderFunction.Type == ShaderFunctionType.FragmentEntryPoint ? " : SV_Target" : string.Empty;
-            string functionDeclStr = $"{returnType} {_shaderFunction.Name}({GetParameterDeclList()}){suffix}";
+            string fullDeclType = _backend.CSharpToShaderType(_shaderFunction.DeclaringType);
+            string funcName = _shaderFunction.IsEntryPoint
+                ? _shaderFunction.Name
+                : fullDeclType + "_" + _shaderFunction.Name;
+            string functionDeclStr = $"{returnType} {funcName}({GetParameterDeclList()}){suffix}";
             return functionDeclStr;
         }
     }

+ 0 - 1
src/ShaderGen/InvocationParameterInfo.cs

@@ -11,6 +11,5 @@ namespace ShaderGen
         {
             return string.Join(", ", parameterInfos.Select(pi => pi.Identifier));
         }
-
     }
 }

+ 7 - 8
src/ShaderGen/LanguageBackend.cs

@@ -67,7 +67,10 @@ namespace ShaderGen
             // HACK: Discover all method input structures.
             foreach (ShaderFunctionAndBlockSyntax sf in context.Functions.ToArray())
             {
-                GetCode(setName, sf.Function);
+                if (sf.Function.IsEntryPoint)
+                {
+                    GetCode(setName, sf.Function);
+                }
             }
 
             return new ShaderModel(
@@ -99,10 +102,6 @@ namespace ShaderGen
             {
                 throw new ArgumentNullException(nameof(function));
             }
-            if (!function.IsEntryPoint)
-            {
-                throw new ArgumentException($"IsEntryPoint must be true for parameter {nameof(function)}");
-            }
 
             if (!_fullTextShaders.TryGetValue(function, out string result))
             {
@@ -165,13 +164,13 @@ namespace ShaderGen
             return CorrectIdentifier(CSharpToIdentifierNameCore(typeName, identifier));
         }
 
-        internal string FormatInvocation(string type, string method, InvocationParameterInfo[] parameterInfos)
+        internal string FormatInvocation(string setName, string type, string method, InvocationParameterInfo[] parameterInfos)
         {
             Debug.Assert(type != null);
             Debug.Assert(method != null);
             Debug.Assert(parameterInfos != null);
 
-            return FormatInvocationCore(type, method, parameterInfos);
+            return FormatInvocationCore(setName, type, method, parameterInfos);
         }
 
         protected void ValidateRequiredSemantics(string setName, ShaderFunction function, ShaderFunctionType type)
@@ -243,7 +242,7 @@ namespace ShaderGen
         protected abstract string CSharpToShaderTypeCore(string fullType);
         protected abstract string CSharpToIdentifierNameCore(string typeName, string identifier);
         protected abstract string GenerateFullTextCore(string setName, ShaderFunction function);
-        protected abstract string FormatInvocationCore(string type, string method, InvocationParameterInfo[] parameterInfos);
+        protected abstract string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos);
 
         internal string CorrectLiteral(string literal)
         {

+ 11 - 1
src/ShaderGen/ParameterDefinition.cs

@@ -1,4 +1,7 @@
-namespace ShaderGen
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace ShaderGen
 {
     public class ParameterDefinition
     {
@@ -10,5 +13,12 @@
             Name = name;
             Type = type;
         }
+
+        public static ParameterDefinition GetParameterDefinition(Compilation compilation, ParameterSyntax ps)
+        {
+            string fullType = compilation.GetSemanticModel(ps.SyntaxTree).GetFullTypeName(ps.Type);
+            string name = ps.Identifier.ToFullString();
+            return new ParameterDefinition(name, new TypeReference(fullType));
+        }
     }
 }

+ 34 - 2
src/ShaderGen/ShaderFunction.cs

@@ -1,10 +1,14 @@
 using System;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis;
+using System.Collections.Generic;
+using System.Linq;
 
 namespace ShaderGen
 {
     public class ShaderFunction
     {
+        public string DeclaringType { get; }
         public string Name { get; }
         public TypeReference ReturnType { get; }
         public ParameterDefinition[] Parameters { get; }
@@ -12,11 +16,13 @@ namespace ShaderGen
         public ShaderFunctionType Type { get; }
 
         public ShaderFunction(
+            string declaringType,
             string name,
             TypeReference returnType,
             ParameterDefinition[] parameters,
             ShaderFunctionType type)
         {
+            DeclaringType = declaringType;
             Name = name;
             ReturnType = returnType;
             Parameters = parameters;
@@ -25,14 +31,40 @@ namespace ShaderGen
 
         public ShaderFunction WithReturnType(TypeReference returnType)
         {
-            return new ShaderFunction(Name, returnType, Parameters, Type);
+            return new ShaderFunction(DeclaringType, Name, returnType, Parameters, Type);
         }
 
         public ShaderFunction WithParameter(int index, TypeReference typeReference)
         {
             ParameterDefinition[] parameters = (ParameterDefinition[])Parameters.Clone();
             parameters[index] = new ParameterDefinition(parameters[index].Name, typeReference);
-            return new ShaderFunction(Name, ReturnType, parameters, Type);
+            return new ShaderFunction(DeclaringType, Name, ReturnType, parameters, Type);
+        }
+
+        public static ShaderFunction GetShaderFunction(Compilation compilation, MethodDeclarationSyntax node)
+        {
+            string functionName = node.Identifier.ToFullString();
+            List<ParameterDefinition> parameters = new List<ParameterDefinition>();
+            foreach (ParameterSyntax ps in node.ParameterList.Parameters)
+            {
+                parameters.Add(ParameterDefinition.GetParameterDefinition(compilation, ps));
+            }
+
+            TypeReference returnType = new TypeReference(compilation.GetSemanticModel(node.SyntaxTree).GetFullTypeName(node.ReturnType));
+
+            bool isVertexShader, isFragmentShader = false;
+            isVertexShader = Utilities.GetMethodAttributes(node, "VertexShader").Any();
+            if (!isVertexShader)
+            {
+                isFragmentShader = Utilities.GetMethodAttributes(node, "FragmentShader").Any();
+            }
+
+            ShaderFunctionType type = isVertexShader
+                ? ShaderFunctionType.VertexEntryPoint : isFragmentShader
+                ? ShaderFunctionType.FragmentEntryPoint : ShaderFunctionType.Normal;
+
+            string nestedTypePrefix = Utilities.GetFullNestedTypePrefix(node, out bool nested);
+            return new ShaderFunction(nestedTypePrefix, functionName, returnType, parameters.ToArray(), type);
         }
     }
 }

+ 4 - 0
src/ShaderGen/ShaderGenerator.cs

@@ -185,6 +185,10 @@ namespace ShaderGen
                 language.InitContext(ss.Name);
             }
 
+            FunctionCallGraphDiscoverer fcgd = new FunctionCallGraphDiscoverer(_compilation, ss.VertexShader);
+            fcgd.GenerateFullGraph();
+            TypeAndMethodName[] orderedCalls = fcgd.GetOrderedCallList();
+
             ShaderSyntaxWalker walker = new ShaderSyntaxWalker(_compilation, _languages.ToArray(), ss);
             foreach (SyntaxTree tree in treesToVisit)
             {

+ 20 - 8
src/ShaderGen/ShaderMethodVisitor.cs

@@ -5,18 +5,21 @@ using System;
 using System.Linq;
 using Microsoft.CodeAnalysis;
 using System.Collections.Generic;
+using System.Diagnostics;
 
 namespace ShaderGen
 {
     public partial class ShaderMethodVisitor : CSharpSyntaxVisitor<string>
     {
         protected readonly Compilation _compilation;
+        protected readonly string _setName;
         protected readonly LanguageBackend _backend;
         protected readonly ShaderFunction _shaderFunction;
 
-        public ShaderMethodVisitor(Compilation compilation, ShaderFunction shaderFunction, LanguageBackend backend)
+        public ShaderMethodVisitor(Compilation compilation, string setName, ShaderFunction shaderFunction, LanguageBackend backend)
         {
             _compilation = compilation;
+            _setName = setName;
             _shaderFunction = shaderFunction;
             _backend = backend;
         }
@@ -57,7 +60,11 @@ namespace ShaderGen
         protected virtual string GetFunctionDeclStr()
         {
             string returnType = _backend.CSharpToShaderType(_shaderFunction.ReturnType.Name);
-            return $"{returnType} {_shaderFunction.Name}({GetParameterDeclList()})";
+            string fullDeclType = _backend.CSharpToShaderType(_shaderFunction.DeclaringType);
+            string funcName = _shaderFunction.IsEntryPoint
+                ? _shaderFunction.Name
+                : fullDeclType + "_" + _shaderFunction.Name;
+            return $"{returnType} {funcName}({GetParameterDeclList()})";
         }
 
         public override string VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
@@ -106,7 +113,7 @@ namespace ShaderGen
                 SymbolInfo symbolInfo = GetModel(node).GetSymbolInfo(ins);
                 string type = symbolInfo.Symbol.ContainingType.ToDisplayString();
                 string method = symbolInfo.Symbol.Name;
-                return _backend.FormatInvocation(type, method, parameterInfos);
+                return _backend.FormatInvocation(_setName, type, method, parameterInfos);
             }
             else if (node.Expression is MemberAccessExpressionSyntax maes)
             {
@@ -118,22 +125,27 @@ namespace ShaderGen
                     List<InvocationParameterInfo> pis = new List<InvocationParameterInfo>();
                     if (ims.IsExtensionMethod)
                     {
+                        string identifier = null;
                         // Extension method invocation, ie: swizzle:
-                        if (!(maes.Expression is MemberAccessExpressionSyntax subExpression))
+                        if (maes.Expression is MemberAccessExpressionSyntax subExpression)
                         {
-                            throw new NotImplementedException(
-                                "Extension methods should have MemberAccessExpressionSyntax expressions.");
+                            identifier = Visit(subExpression);
+                        }
+                        else if (maes.Expression is IdentifierNameSyntax identNameSyntax)
+                        {
+                            identifier = Visit(identNameSyntax);
                         }
 
+                        Debug.Assert(identifier != null);
                         // Might need FullTypeName here too.
                         pis.Add(new InvocationParameterInfo()
                         {
-                            Identifier = Visit(subExpression)
+                            Identifier = identifier
                         });
                     }
 
                     pis.AddRange(GetParameterInfos(node.ArgumentList));
-                    return _backend.FormatInvocation(containingType, methodName, pis.ToArray());
+                    return _backend.FormatInvocation(_setName, containingType, methodName, pis.ToArray());
                 }
 
                 throw new NotImplementedException();

+ 3 - 9
src/ShaderGen/ShaderSyntaxWalker.cs

@@ -32,7 +32,7 @@ namespace ShaderGen
             List<ParameterDefinition> parameters = new List<ParameterDefinition>();
             foreach (ParameterSyntax ps in node.ParameterList.Parameters)
             {
-                parameters.Add(GetParameterDefinition(ps));
+                parameters.Add(ParameterDefinition.GetParameterDefinition(_compilation, ps));
             }
 
             TypeReference returnType = new TypeReference(GetModel(node).GetFullTypeName(node.ReturnType));
@@ -48,18 +48,12 @@ namespace ShaderGen
                 ? ShaderFunctionType.VertexEntryPoint : isFragmentShader
                 ? ShaderFunctionType.FragmentEntryPoint : ShaderFunctionType.Normal;
 
-            ShaderFunction sf = new ShaderFunction(functionName, returnType, parameters.ToArray(), type);
+            string nestedTypePrefix = Utilities.GetFullNestedTypePrefix(node, out bool nested);
+            ShaderFunction sf = new ShaderFunction(nestedTypePrefix, functionName, returnType, parameters.ToArray(), type);
             ShaderFunctionAndBlockSyntax sfab = new ShaderFunctionAndBlockSyntax(sf, node.Body);
             foreach (LanguageBackend b in _backends) { b.AddFunction(_shaderSet.Name, sfab); }
         }
 
-        private ParameterDefinition GetParameterDefinition(ParameterSyntax ps)
-        {
-            string fullType = GetModel(ps).GetFullTypeName(ps.Type);
-            string name = ps.Identifier.ToFullString();
-            return new ParameterDefinition(name, new TypeReference(fullType));
-        }
-
         public override void VisitStructDeclaration(StructDeclarationSyntax node)
         {
             TryGetStructDefinition(GetModel(node), node, out var sd);

+ 11 - 2
src/ShaderGen/TypeAndMethodName.cs

@@ -1,6 +1,8 @@
-namespace ShaderGen
+using System;
+
+namespace ShaderGen
 {
-    internal class TypeAndMethodName
+    internal class TypeAndMethodName : IEquatable<TypeAndMethodName>
     {
         public string TypeName;
         public string MethodName;
@@ -24,5 +26,12 @@
             typeAndMethodName = new TypeAndMethodName { TypeName = typeName, MethodName = parts[parts.Length - 1] };
             return true;
         }
+
+        public bool Equals(TypeAndMethodName other)
+        {
+            return TypeName == other.TypeName && MethodName == other.MethodName;
+        }
+
+        public override string ToString() => FullName;
     }
 }