Prechádzať zdrojové kódy

Add full support for OpenGL ES shading language version 300

Eric Mellino 7 rokov pred
rodič
commit
4136d6d9c3

+ 6 - 0
src/ShaderGen.App/Program.cs

@@ -125,12 +125,14 @@ namespace ShaderGen.App
 
             HlslBackend hlsl = new HlslBackend(compilation);
             Glsl330Backend glsl330 = new Glsl330Backend(compilation);
+            GlslEs300Backend glsles300 = new GlslEs300Backend(compilation);
             Glsl450Backend glsl450 = new Glsl450Backend(compilation);
             MetalBackend metal = new MetalBackend(compilation);
             LanguageBackend[] languages = new LanguageBackend[]
             {
                 hlsl,
                 glsl330,
+                glsles300,
                 glsl450,
                 metal,
             };
@@ -447,6 +449,10 @@ namespace ShaderGen.App
             {
                 return "330.glsl";
             }
+            else if (lang.GetType() == typeof(GlslEs300Backend))
+            {
+                return "300.glsles";
+            }
             else if (lang.GetType() == typeof(Glsl450Backend))
             {
                 return "450.glsl";

+ 33 - 1
src/ShaderGen.Tests/ShaderGeneratorTests.cs

@@ -114,6 +114,38 @@ namespace ShaderGen.Tests
             }
         }
 
+        [Theory]
+        [MemberData(nameof(ShaderSets))]
+        public void GlslEs300EndToEnd(string vsName, string fsName)
+        {
+            Compilation compilation = TestUtil.GetTestProjectCompilation();
+            GlslEs300Backend backend = new GlslEs300Backend(compilation);
+            ShaderGenerator sg = new ShaderGenerator(
+                compilation,
+                vsName,
+                fsName,
+                backend);
+
+            ShaderGenerationResult result = sg.GenerateShaders();
+            IReadOnlyList<GeneratedShaderSet> sets = result.GetOutput(backend);
+            Assert.Equal(1, sets.Count);
+            GeneratedShaderSet set = sets[0];
+            ShaderModel shaderModel = set.Model;
+
+            if (vsName != null)
+            {
+                ShaderFunction vsFunction = shaderModel.GetFunction(vsName);
+                string vsCode = set.VertexShaderCode;
+                GlsLangValidatorTool.AssertCompilesCode(vsCode, "vert", false);
+            }
+            if (fsName != null)
+            {
+                ShaderFunction fsFunction = shaderModel.GetFunction(fsName);
+                string fsCode = set.FragmentShaderCode;
+                GlsLangValidatorTool.AssertCompilesCode(fsCode, "frag", false);
+            }
+        }
+
         [Theory]
         [MemberData(nameof(ShaderSets))]
         public void Glsl450EndToEnd(string vsName, string fsName)
@@ -259,7 +291,7 @@ namespace ShaderGen.Tests
 
         [Theory]
         [MemberData(nameof(ErrorSets))]
-        public void ExceptedException(string vsName, string fsName)
+        public void ExpectedException(string vsName, string fsName)
         {
             Compilation compilation = TestUtil.GetTestProjectCompilation();
             Glsl330Backend backend = new Glsl330Backend(compilation);

+ 12 - 3
src/ShaderGen/Glsl/GlslBackendBase.cs

@@ -22,7 +22,8 @@ namespace ShaderGen.Glsl
             StringBuilder fb = new StringBuilder();
             foreach (FieldDefinition field in sd.Fields)
             {
-                fb.Append(CSharpToShaderType(field.Type.Name.Trim()));
+                string fieldTypeStr = GetStructureFieldType(field);
+                fb.Append(fieldTypeStr);
                 fb.Append(' ');
                 fb.Append(CorrectIdentifier(field.Name.Trim()));
                 int arrayCount = field.ArrayElementCount;
@@ -39,6 +40,10 @@ namespace ShaderGen.Glsl
             sb.AppendLine();
         }
 
+        protected virtual string GetStructureFieldType(FieldDefinition field)
+        {
+            return CSharpToShaderType(field.Type.Name.Trim());
+        }
 
         protected override MethodProcessResult GenerateFullTextCore(string setName, ShaderFunction function)
         {
@@ -55,8 +60,6 @@ namespace ShaderGen.Glsl
 
             ValidateRequiredSemantics(setName, entryPoint.Function, function.Type);
 
-            WriteVersionHeader(function, sb);
-
             StructureDefinition[] orderedStructures
                 = StructureDependencyGraph.GetOrderedStructureList(Compilation, context.Structures);
 
@@ -123,6 +126,12 @@ namespace ShaderGen.Glsl
 
             WriteMainFunction(setName, sb, entryPoint.Function);
 
+            // Append version last because it relies on information from parsing the shader.
+            StringBuilder versionSB = new StringBuilder();
+            WriteVersionHeader(function, versionSB);
+
+            sb.Insert(0, versionSB.ToString());
+
             return new MethodProcessResult(sb.ToString(), resourcesUsed);
         }
 

+ 143 - 0
src/ShaderGen/Glsl/GlslEs300Backend.cs

@@ -0,0 +1,143 @@
+using System.Diagnostics;
+using System.Text;
+using Microsoft.CodeAnalysis;
+
+namespace ShaderGen.Glsl
+{
+    public class GlslEs300Backend : GlslBackendBase
+    {
+        public GlslEs300Backend(Compilation compilation) : base(compilation)
+        {
+        }
+
+        protected override string CSharpToShaderTypeCore(string fullType)
+        {
+            return GlslKnownTypes.GetMappedName(fullType, false)
+                .Replace(".", "_")
+                .Replace("+", "_");
+        }
+
+        protected override void WriteVersionHeader(ShaderFunction function, StringBuilder sb)
+        {
+            bool useVersion320 = function.UsesTexture2DMS;
+            string versionNumber = useVersion320 ? "320" : "300";
+            string version = $"{versionNumber} es";
+            sb.AppendLine($"#version {version}");
+            sb.AppendLine($"precision mediump float;");
+            sb.AppendLine($"precision mediump int;");
+            sb.AppendLine($"precision mediump sampler2D;");
+            if (useVersion320)
+            {
+                sb.AppendLine($"precision mediump sampler2DMS;");
+            }
+            sb.AppendLine();
+        }
+
+        protected override void WriteSampler(StringBuilder sb, ResourceDefinition rd)
+        {
+        }
+
+        protected override void WriteTexture2D(StringBuilder sb, ResourceDefinition rd)
+        {
+            sb.AppendLine($"uniform sampler2D {CorrectIdentifier(rd.Name)};");
+            sb.AppendLine();
+        }
+
+        protected override void WriteTextureCube(StringBuilder sb, ResourceDefinition rd)
+        {
+            sb.AppendLine($"uniform samplerCube {CorrectIdentifier(rd.Name)};");
+            sb.AppendLine();
+        }
+
+        protected override void WriteTexture2DMS(StringBuilder sb, ResourceDefinition rd)
+        {
+            sb.AppendLine($"uniform sampler2DMS {CorrectIdentifier(rd.Name)};");
+            sb.AppendLine();
+        }
+
+        protected override void WriteUniform(StringBuilder sb, ResourceDefinition rd)
+        {
+            sb.AppendLine($"layout(std140) uniform {rd.Name}");
+            sb.AppendLine("{");
+            sb.AppendLine($"    {CSharpToShaderType(rd.ValueType.Name)} field_{CorrectIdentifier(rd.Name.Trim())};");
+            sb.AppendLine("};");
+            sb.AppendLine();
+        }
+
+        protected override void WriteStructuredBuffer(StringBuilder sb, ResourceDefinition rd, bool isReadOnly)
+        {
+            string readOnlyStr = isReadOnly ? " readonly" : " ";
+            sb.AppendLine($"layout(std140){readOnlyStr} buffer {rd.Name}");
+            sb.AppendLine("{");
+            sb.AppendLine($"    {CSharpToShaderType(rd.ValueType.Name)} field_{CorrectIdentifier(rd.Name.Trim())}[];");
+            sb.AppendLine("};");
+        }
+
+        protected override string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos)
+        {
+            return GlslEs300KnownFunctions.TranslateInvocation(type, method, parameterInfos);
+        }
+
+        internal override string CorrectBinaryExpression(
+            string leftExpr,
+            string leftExprType,
+            string operatorToken,
+            string rightExpr,
+            string rightExprType)
+        {
+            if (IsIntegerType(leftExprType) && !IsIntegerType(rightExprType))
+            {
+                leftExpr = $"float({leftExpr})";
+            }
+            else if (IsIntegerType(rightExprType) && !IsIntegerType(leftExprType))
+            {
+                rightExpr = $"float({rightExpr})";
+            }
+
+            return $"{leftExpr} {operatorToken} {rightExpr}";
+        }
+
+        private bool IsIntegerType(string exprType)
+        {
+            return exprType == "System.Int32" || exprType == "System.UInt32";
+        }
+
+        internal override string CorrectAssignedValue(string leftExprType, string value, string valueType)
+        {
+            if (valueType == "System.Int32" && leftExprType != "System.Int32")
+            {
+                value = $"float({value})";
+            }
+
+            return $"{value}";
+        }
+
+        protected override void WriteInOutVariable(
+            StringBuilder sb,
+            bool isInVar,
+            bool isVertexStage,
+            string normalizedType,
+            string normalizedIdentifier,
+            int index)
+        {
+            string qualifier = isInVar ? "in" : "out";
+            string identifier;
+            if ((isVertexStage && isInVar) || (!isVertexStage && !isInVar))
+            {
+                identifier = normalizedIdentifier;
+            }
+            else
+            {
+                Debug.Assert(isVertexStage || isInVar);
+                identifier = $"fsin_{index}";
+            }
+
+            sb.AppendLine($"{qualifier} {normalizedType} {identifier};");
+        }
+
+        protected override void EmitGlPositionCorrection(StringBuilder sb)
+        {
+            sb.AppendLine($"        gl_Position.z = gl_Position.z * 2.0 - gl_Position.w;");
+        }
+    }
+}

+ 414 - 0
src/ShaderGen/Glsl/GlslEs300KnownFunctions.cs

@@ -0,0 +1,414 @@
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using ShaderGen.Hlsl;
+
+namespace ShaderGen.Glsl
+{
+    public static class GlslEs300KnownFunctions
+    {
+        private static Dictionary<string, TypeInvocationTranslator> s_mappings = GetMappings();
+
+        private static Dictionary<string, TypeInvocationTranslator> GetMappings()
+        {
+            Dictionary<string, TypeInvocationTranslator> ret = new Dictionary<string, TypeInvocationTranslator>();
+
+            Dictionary<string, InvocationTranslator> builtinMappings = new Dictionary<string, InvocationTranslator>()
+            {
+                { "Abs", SimpleNameTranslator("abs") },
+                { "Pow", SimpleNameFloatParameterTranslator("pow") },
+                { "Acos", SimpleNameTranslator("acos") },
+                { "Cos", SimpleNameTranslator("cos") },
+                { "Frac", SimpleNameTranslator("fract") },
+                { "Lerp", SimpleNameTranslator("mix") },
+                { "Sin", SimpleNameTranslator("sin") },
+                { "Tan", SimpleNameTranslator("tan") },
+                { "Clamp", SimpleNameFloatParameterTranslator("clamp") },
+                { "Mod", SimpleNameFloatParameterTranslator("mod") },
+                { "Mul", MatrixMul },
+                { "Sample", Sample2D },
+                { "Load", Load },
+                { "Discard", Discard },
+                { "Saturate", Saturate },
+                { nameof(ShaderBuiltins.ClipToTextureCoordinates), ClipToTextureCoordinates },
+                { "VertexID", VertexID },
+                { "InstanceID", InstanceID },
+                { "DispatchThreadID", DispatchThreadID },
+                { "GroupThreadID", GroupThreadID },
+                { "IsFrontFace", IsFrontFace },
+            };
+            ret.Add("ShaderGen.ShaderBuiltins", new DictionaryTypeInvocationTranslator(builtinMappings));
+
+            Dictionary<string, InvocationTranslator> v2Mappings = new Dictionary<string, InvocationTranslator>()
+            {
+                { "Abs", SimpleNameTranslator("abs") },
+                { "Add", BinaryOpTranslator("+") },
+                { "Clamp", SimpleNameTranslator("clamp") },
+                { "Cos", SimpleNameTranslator("cos") },
+                { "Distance", SimpleNameTranslator("distance") },
+                { "DistanceSquared", DistanceSquared },
+                { "Divide", BinaryOpTranslator("/") },
+                { "Dot", SimpleNameTranslator("dot") },
+                { "Lerp", SimpleNameTranslator("mix") },
+                { "Max", SimpleNameTranslator("max") },
+                { "Min", SimpleNameTranslator("min") },
+                { "Multiply", BinaryOpTranslator("*") },
+                { "Negate", Negate },
+                { "Normalize", SimpleNameTranslator("normalize") },
+                { "Reflect", SimpleNameTranslator("reflect") },
+                { "Sin", SimpleNameTranslator("sin") },
+                { "SquareRoot", SimpleNameTranslator("sqrt") },
+                { "Subtract", BinaryOpTranslator("-") },
+                { "Length", SimpleNameTranslator("length") },
+                { "LengthSquared", LengthSquared },
+                { "ctor", VectorCtor },
+                { "Zero", VectorStaticAccessor },
+                { "One", VectorStaticAccessor },
+                { "UnitX", VectorStaticAccessor },
+                { "UnitY", VectorStaticAccessor },
+                { "Transform", Vector2Transform }
+            };
+            ret.Add("System.Numerics.Vector2", new DictionaryTypeInvocationTranslator(v2Mappings));
+
+            Dictionary<string, InvocationTranslator> v3Mappings = new Dictionary<string, InvocationTranslator>()
+            {
+                { "Abs", SimpleNameTranslator("abs") },
+                { "Add", BinaryOpTranslator("+") },
+                { "Clamp", SimpleNameTranslator("clamp") },
+                { "Cos", SimpleNameTranslator("cos") },
+                { "Cross", SimpleNameTranslator("cross") },
+                { "Distance", SimpleNameTranslator("distance") },
+                { "DistanceSquared", DistanceSquared },
+                { "Divide", BinaryOpTranslator("/") },
+                { "Dot", SimpleNameTranslator("dot") },
+                { "Lerp", SimpleNameTranslator("mix") },
+                { "Max", SimpleNameTranslator("max") },
+                { "Min", SimpleNameTranslator("min") },
+                { "Multiply", BinaryOpTranslator("*") },
+                { "Negate", Negate },
+                { "Normalize", SimpleNameTranslator("normalize") },
+                { "Reflect", SimpleNameTranslator("reflect") },
+                { "Sin", SimpleNameTranslator("sin") },
+                { "SquareRoot", SimpleNameTranslator("sqrt") },
+                { "Subtract", BinaryOpTranslator("-") },
+                { "Length", SimpleNameTranslator("length") },
+                { "LengthSquared", LengthSquared },
+                { "ctor", VectorCtor },
+                { "Zero", VectorStaticAccessor },
+                { "One", VectorStaticAccessor },
+                { "UnitX", VectorStaticAccessor },
+                { "UnitY", VectorStaticAccessor },
+                { "UnitZ", VectorStaticAccessor },
+                { "Transform", Vector3Transform }
+            };
+            ret.Add("System.Numerics.Vector3", new DictionaryTypeInvocationTranslator(v3Mappings));
+
+            Dictionary<string, InvocationTranslator> v4Mappings = new Dictionary<string, InvocationTranslator>()
+            {
+                { "Abs", SimpleNameTranslator("abs") },
+                { "Add", BinaryOpTranslator("+") },
+                { "Clamp", SimpleNameTranslator("clamp") },
+                { "Cos", SimpleNameTranslator("cos") },
+                { "Distance", SimpleNameTranslator("distance") },
+                { "DistanceSquared", DistanceSquared },
+                { "Divide", BinaryOpTranslator("/") },
+                { "Dot", SimpleNameTranslator("dot") },
+                { "Lerp", SimpleNameTranslator("mix") },
+                { "Max", SimpleNameTranslator("max") },
+                { "Min", SimpleNameTranslator("min") },
+                { "Multiply", BinaryOpTranslator("*") },
+                { "Negate", Negate },
+                { "Normalize", SimpleNameTranslator("normalize") },
+                { "Reflect", SimpleNameTranslator("reflect") },
+                { "Sin", SimpleNameTranslator("sin") },
+                { "SquareRoot", SimpleNameTranslator("sqrt") },
+                { "Subtract", BinaryOpTranslator("-") },
+                { "Length", SimpleNameTranslator("length") },
+                { "LengthSquared", LengthSquared },
+                { "ctor", VectorCtor },
+                { "Zero", VectorStaticAccessor },
+                { "One", VectorStaticAccessor },
+                { "UnitX", VectorStaticAccessor },
+                { "UnitY", VectorStaticAccessor },
+                { "UnitZ", VectorStaticAccessor },
+                { "UnitW", VectorStaticAccessor },
+                { "Transform", Vector4Transform }
+            };
+            ret.Add("System.Numerics.Vector4", new DictionaryTypeInvocationTranslator(v4Mappings));
+
+            Dictionary<string, InvocationTranslator> m4x4Mappings = new Dictionary<string, InvocationTranslator>()
+            {
+                { "ctor", MatrixCtor }
+            };
+            ret.Add("System.Numerics.Matrix4x4", new DictionaryTypeInvocationTranslator(m4x4Mappings));
+
+            Dictionary<string, InvocationTranslator> mathfMappings = new Dictionary<string, InvocationTranslator>()
+            {
+                { "Cos", SimpleNameTranslator("cos") },
+                { "Max", SimpleNameTranslator("max") },
+                { "Min", SimpleNameTranslator("min") },
+                { "Pow", SimpleNameFloatParameterTranslator("pow") },
+                { "Sin", SimpleNameTranslator("sin") },
+            };
+            ret.Add("System.MathF", new DictionaryTypeInvocationTranslator(mathfMappings));
+
+            ret.Add("ShaderGen.ShaderSwizzle", new SwizzleTranslator());
+
+            return ret;
+        }
+
+        private static string MatrixCtor(string typeName, string methodName, InvocationParameterInfo[] p)
+        {
+            string paramList = string.Join(", ",
+                p[0].Identifier, p[4].Identifier, p[8].Identifier, p[12].Identifier,
+                p[1].Identifier, p[5].Identifier, p[9].Identifier, p[13].Identifier,
+                p[2].Identifier, p[6].Identifier, p[10].Identifier, p[14].Identifier,
+                p[3].Identifier, p[7].Identifier, p[11].Identifier, p[15].Identifier);
+
+            return $"mat4({paramList})";
+        }
+
+        public static string TranslateInvocation(string type, string method, InvocationParameterInfo[] parameters)
+        {
+            if (s_mappings.TryGetValue(type, out var dict))
+            {
+                if (dict.GetTranslator(method, parameters, out InvocationTranslator mappedValue))
+                {
+                    return mappedValue(type, method, parameters);
+                }
+            }
+
+            throw new ShaderGenerationException($"Reference to unknown function: {type}.{method}");
+        }
+
+        private static InvocationTranslator SimpleNameTranslator(string nameTarget)
+        {
+            return (type, method, parameters) =>
+            {
+                return $"{nameTarget}({InvocationParameterInfo.GetInvocationParameterList(parameters)})";
+            };
+        }
+
+        private static InvocationTranslator SimpleNameFloatParameterTranslator(string nameTarget)
+        {
+            return (type, method, parameters) =>
+            {
+                IEnumerable<string> castedParams = parameters.Select(ipi =>
+                {
+                    if (ipi.FullTypeName == "float" || ipi.FullTypeName == "int" || ipi.FullTypeName == "uint")
+                    {
+                        return $"float({ipi.Identifier})";
+                    }
+                    else
+                    {
+                        return ipi.Identifier;
+                    }
+
+                });
+                return $"{nameTarget}({string.Join(", ", castedParams)})";
+            };
+        }
+
+        private static string LengthSquared(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"dot({parameters[0].Identifier}, {parameters[0].Identifier})";
+        }
+
+        private static string Negate(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"-{parameters[0].Identifier}";
+        }
+
+        private static string DistanceSquared(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"dot({parameters[0].Identifier} - {parameters[1].Identifier}, {parameters[0].Identifier} - {parameters[1].Identifier})";
+        }
+
+        private static InvocationTranslator BinaryOpTranslator(string op)
+        {
+            return (type, method, parameters) =>
+            {
+                return $"{parameters[0].Identifier} {op} {parameters[1].Identifier}";
+            };
+        }
+
+        private static string MatrixMul(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"{parameters[0].Identifier} * {parameters[1].Identifier}";
+        }
+
+        private static string Sample2D(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"texture({parameters[0].Identifier}, {parameters[2].Identifier})";
+        }
+
+        private static string Load(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"texelFetch({parameters[0].Identifier}, ivec2({parameters[2].Identifier}), {parameters[3].Identifier})";
+        }
+
+        private static string Discard(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"discard;";
+        }
+
+        private static string Saturate(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            if (parameters.Length == 1)
+            {
+                return $"clamp({parameters[0].Identifier}, 0.f, 1.f)";
+            }
+            else
+            {
+                throw new ShaderGenerationException("Unhandled number of arguments to ShaderBuiltins.Discard.");
+            }
+        }
+
+        private static string ClipToTextureCoordinates(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            string target = parameters[0].Identifier;
+            return $"vec2(({target}.x / {target}.w) / 2.f + 0.5f, ({target}.y / {target}.w) / 2.f + 0.5f)";
+        }
+
+        private static string VertexID(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return "uint(gl_VertexID)";
+        }
+
+        private static string InstanceID(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return "uint(gl_InstanceID)";
+        }
+
+        private static string DispatchThreadID(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return "gl_GlobalInvocationID";
+        }
+
+        private static string GroupThreadID(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return "gl_LocalInvocationID";
+        }
+
+        private static string IsFrontFace(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return "gl_FrontFacing";
+        }
+
+        private static string VectorCtor(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            GetVectorTypeInfo(typeName, out string shaderType, out int elementCount);
+            string paramList;
+            if (parameters.Length == 0)
+            {
+                paramList = string.Join(", ", Enumerable.Repeat("0", elementCount));
+            }
+            else if (parameters.Length == 1)
+            {
+                paramList = string.Join(", ", Enumerable.Repeat(parameters[0].Identifier, elementCount));
+            }
+            else
+            {
+                StringBuilder sb = new StringBuilder();
+                for (int i = 0; i < parameters.Length; i++)
+                {
+                    InvocationParameterInfo ipi = parameters[i];
+                    sb.Append(ipi.Identifier);
+
+                    if (i != parameters.Length - 1)
+                    {
+                        sb.Append(", ");
+                    }
+                }
+
+                paramList = sb.ToString();
+            }
+
+            return $"{shaderType}({paramList})";
+        }
+
+        private static string VectorStaticAccessor(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            Debug.Assert(parameters.Length == 0);
+            GetVectorTypeInfo(typeName, out string shaderType, out int elementCount);
+            if (methodName == "Zero")
+            {
+                return $"{shaderType}({string.Join(", ", Enumerable.Repeat("0", elementCount))})";
+            }
+            else if (methodName == "One")
+            {
+                return $"{shaderType}({string.Join(", ", Enumerable.Repeat("1", elementCount))})";
+            }
+            else if (methodName == "UnitX")
+            {
+                string paramList;
+                if (elementCount == 2) { paramList = "1, 0"; }
+                else if (elementCount == 3) { paramList = "1, 0, 0"; }
+                else { paramList = "1, 0, 0, 0"; }
+                return $"{shaderType}({paramList})";
+            }
+            else if (methodName == "UnitY")
+            {
+                string paramList;
+                if (elementCount == 2) { paramList = "0, 1"; }
+                else if (elementCount == 3) { paramList = "0, 1, 0"; }
+                else { paramList = "0, 1, 0, 0"; }
+                return $"{shaderType}({paramList})";
+            }
+            else if (methodName == "UnitZ")
+            {
+                string paramList;
+                if (elementCount == 3) { paramList = "0, 0, 1"; }
+                else { paramList = "0, 0, 1, 0"; }
+                return $"{shaderType}({paramList})";
+            }
+            else if (methodName == "UnitW")
+            {
+                return $"{shaderType}(0, 0, 0, 1)";
+            }
+            else
+            {
+                Debug.Fail("Invalid static vector accessor: " + methodName);
+                return null;
+            }
+        }
+
+        private static string Vector2Transform(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"({parameters[1].Identifier} * vec4({parameters[0].Identifier}, 0.f, 1.f)).xy";
+        }
+
+        private static string Vector3Transform(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            return $"({parameters[1].Identifier} * vec4({parameters[0].Identifier}, 1.f)).xyz";
+        }
+
+        private static string Vector4Transform(string typeName, string methodName, InvocationParameterInfo[] parameters)
+        {
+            string vecParam;
+            if (parameters[0].FullTypeName == "System.Numerics.Vector2")
+            {
+                vecParam = $"vec4({parameters[0].Identifier}, 0.f, 1.f)";
+            }
+            else if (parameters[0].FullTypeName == "System.Numerics.Vector3")
+            {
+                vecParam = $"vec4({parameters[0].Identifier}, 1.f)";
+            }
+            else
+            {
+                vecParam = parameters[0].Identifier;
+            }
+
+            return $"{parameters[1].Identifier} * {vecParam}";
+        }
+
+        private static void GetVectorTypeInfo(string name, out string shaderType, out int elementCount)
+        {
+            if (name == "System.Numerics.Vector2") { shaderType = "vec2"; elementCount = 2; }
+            else if (name == "System.Numerics.Vector3") { shaderType = "vec3"; elementCount = 3; }
+            else if (name == "System.Numerics.Vector4") { shaderType = "vec4"; elementCount = 4; }
+            else { throw new ShaderGenerationException("VectorCtor translator was called on an invalid type: " + name); }
+        }
+    }
+}

+ 0 - 1
src/ShaderGen/Glsl/GlslKnownTypes.cs

@@ -30,7 +30,6 @@ namespace ShaderGen.Glsl
             { "ShaderGen.TextureCubeResource", "samplerCube" },
         };
 
-
         private static readonly Dictionary<string, string> s_knownTypesVulkan = new Dictionary<string, string>()
         {
             { "ShaderGen.Texture2DResource", "texture2D" },

+ 18 - 0
src/ShaderGen/LanguageBackend.cs

@@ -106,6 +106,14 @@ namespace ShaderGen
                 computeResources);
         }
 
+        internal virtual string CorrectAssignedValue(
+            string leftExprType,
+            string rightExpr,
+            string rightExprType)
+        {
+            return rightExpr;
+        }
+
         private void ForceTypeDiscovery(string setName, TypeReference fd)
         {
             if (ShaderPrimitiveTypes.IsPrimitiveType(fd.Name))
@@ -268,6 +276,16 @@ namespace ShaderGen
             return result;
         }
 
+        internal virtual string CorrectBinaryExpression(
+            string leftExpr,
+            string leftExprType,
+            string operatorToken,
+            string rightExpr,
+            string rightExprType)
+        {
+            return $"{leftExpr} {operatorToken} {rightExpr}";
+        }
+
         internal virtual string CorrectFieldAccess(SymbolInfo symbolInfo)
         {
             string mapped = CSharpToShaderIdentifierName(symbolInfo);

+ 1 - 0
src/ShaderGen/ShaderFunction.cs

@@ -15,6 +15,7 @@
         public bool UsesDispatchThreadID { get; internal set; }
         public bool UsesGroupThreadID { get; internal set; }
         public bool UsesFrontFace { get; internal set; }
+        public bool UsesTexture2DMS { get; internal set; }
 
         public ShaderFunction(
             string declaringType,

+ 41 - 19
src/ShaderGen/ShaderMethodVisitor.cs

@@ -83,7 +83,7 @@ namespace ShaderGen
 
         public override string VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
         {
-            return Visit(node.Declaration) + ";";
+            return Visit(node.Declaration);
         }
 
         public override string VisitEqualsValueClause(EqualsValueClauseSyntax node)
@@ -98,14 +98,15 @@ namespace ShaderGen
             {
                 throw new ShaderGenerationException(
                     "Modulus operator not supported in shader functions. Use ShaderBuiltins.Mod instead.");
-
             }
 
-            return base.Visit(node.Left)
-                + " "
-                + token
-                + base.Visit(node.Right)
-                + ";";
+            string leftExpr = base.Visit(node.Left);
+            string leftExprType = Utilities.GetFullTypeName(GetModel(node), node.Left);
+            string rightExpr = base.Visit(node.Right);
+            string rightExprType = Utilities.GetFullTypeName(GetModel(node), node.Right);
+
+            string assignedValue = _backend.CorrectAssignedValue(leftExprType, rightExpr, rightExprType);
+            return $"{leftExpr} {token} {assignedValue};";
         }
 
         public override string VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
@@ -233,9 +234,13 @@ namespace ShaderGen
                     "Modulus operator not supported in shader functions. Use ShaderBuiltins.Mod instead.");
             }
 
-            return Visit(node.Left) + " "
-                + node.OperatorToken + " "
-                + Visit(node.Right);
+            string leftExpr = Visit(node.Left);
+            string leftExprType = Utilities.GetFullTypeName(GetModel(node), node.Left);
+            string operatorToken = node.OperatorToken.ToString();
+            string rightExpr = Visit(node.Right);
+            string rightExprType = Utilities.GetFullTypeName(GetModel(node), node.Right);
+
+            return _backend.CorrectBinaryExpression(leftExpr, leftExprType, operatorToken, rightExpr, rightExprType);
         }
 
         public override string VisitParenthesizedExpression(ParenthesizedExpressionSyntax node)
@@ -286,6 +291,8 @@ namespace ShaderGen
                 string symbolName = symbol.Name;
                 ResourceDefinition referencedResource = _backend.GetContext(_setName).Resources.Single(rd => rd.Name == symbolName);
                 _resourcesUsed.Add(referencedResource);
+                _shaderFunction.UsesTexture2DMS |= referencedResource.ValueType.Name == "ShaderGen.Texture2DMSResource";
+
                 return _backend.CorrectFieldAccess(symbolInfo);
             }
             else if (symbol.Kind == SymbolKind.Property)
@@ -367,7 +374,7 @@ namespace ShaderGen
             string declaration = Visit(node.Declaration);
             string incrementers = string.Join(", ", node.Incrementors.Select(es => Visit(es)));
             string condition = Visit(node.Condition);
-            sb.AppendLine($"for ({declaration}; {condition}; {incrementers})");
+            sb.AppendLine($"for ({declaration} {condition}; {incrementers})");
             sb.AppendLine(Visit(node.Statement));
             return sb.ToString();
         }
@@ -425,17 +432,32 @@ namespace ShaderGen
                 throw new NotImplementedException();
             }
 
-            string csName = _compilation.GetSemanticModel(node.Type.SyntaxTree).GetFullTypeName(node.Type);
-            string mappedType = _backend.CSharpToShaderType(csName);
-            string initializerStr = Visit(node.Variables[0].Initializer);
-            string result = mappedType + " "
-                + _backend.CorrectIdentifier(node.Variables[0].Identifier.ToString());
-            if (!string.IsNullOrEmpty(initializerStr))
+            StringBuilder sb = new StringBuilder();
+
+            string varType = _compilation.GetSemanticModel(node.Type.SyntaxTree).GetFullTypeName(node.Type);
+            string mappedType = _backend.CSharpToShaderType(varType);
+
+            sb.Append(mappedType);
+            sb.Append(' ');
+            VariableDeclaratorSyntax varDeclarator = node.Variables[0];
+            string identifier = _backend.CorrectIdentifier(varDeclarator.Identifier.ToString());
+            sb.Append(identifier);
+
+            if (varDeclarator.Initializer != null)
             {
-                result += " " + initializerStr;
+                sb.Append(' ');
+                sb.Append(varDeclarator.Initializer.EqualsToken.ToString());
+                sb.Append(' ');
+
+                string rightExpr = base.Visit(varDeclarator.Initializer.Value);
+                string rightExprType = Utilities.GetFullTypeName(GetModel(node), varDeclarator.Initializer.Value);
+
+                sb.Append(_backend.CorrectAssignedValue(varType, rightExpr, rightExprType));
             }
 
-            return result;
+            sb.Append(';');
+
+            return sb.ToString();
         }
 
         public override string VisitElementAccessExpression(ElementAccessExpressionSyntax node)