Ver Fonte

Merge pull request #46 from OpenSAGE/const-values

Allow const variable declaration and usage
Eric Mellino há 7 anos atrás
pai
commit
67f9afa6ad

+ 1 - 1
src/ShaderGen.Tests/ShaderModelTests.cs

@@ -97,7 +97,7 @@ namespace ShaderGen.Tests
             ShaderModel shaderModel = set.Model;
 
             Assert.Single(shaderModel.AllResources);
-            Assert.Equal(144, shaderModel.GetTypeSize(shaderModel.AllResources[0].ValueType));
+            Assert.Equal(208, shaderModel.GetTypeSize(shaderModel.AllResources[0].ValueType));
         }
 
         [Fact]

+ 13 - 3
src/ShaderGen.Tests/TestAssets/PointLightInfoStructs.cs

@@ -6,12 +6,19 @@ namespace TestShaders
     public class PointLightTestShaders
     {
         public PointLightsInfo PointLights;
+        public const int MyOtherConst = 20;
 
         [VertexShader] SystemPosition4 VS(Position4 input)
         {
+            const int MyConst = 10;
+
             SystemPosition4 output;
-            PointLightInfo a = PointLights.PointLights[0];
-            Vector4 position = new Vector4(a.Position.XYZ(), 10);
+            Vector4 color = Vector4.Zero;
+            for (int i = 0; i < PointLightsInfo.MaxLights; i++)
+            {
+                PointLightInfo a = PointLights.PointLights[i];
+                color += new Vector4(a.Color, MyConst);
+            }
             output.Position = input.Position;
             return output;
         }
@@ -27,8 +34,11 @@ namespace TestShaders
 
     public struct PointLightsInfo
     {
+        public const int MaxLights = 4;
+
         public int NumActiveLights;
         public Vector3 _padding;
-        [ArraySize(4)] public PointLightInfo[] PointLights;
+        [ArraySize(MaxLights)] public PointLightInfo[] PointLights;
+        [ArraySize(2)] public PointLightInfo[] PointLights2;
     }
 }

+ 16 - 1
src/ShaderGen/ShaderMethodVisitor.cs

@@ -161,6 +161,11 @@ namespace ShaderGen
 
         public override string VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
         {
+            if (node.Modifiers.Any(x => x.IsKind(SyntaxKind.ConstKeyword)))
+            {
+                return " "; // TODO: Can't return empty string here because of validation check in VisitBlock
+            }
+
             return Visit(node.Declaration);
         }
 
@@ -211,7 +216,7 @@ namespace ShaderGen
                 }
 
                 // Static member access
-                if (symbol.Kind == SymbolKind.Property)
+                if (symbol.Kind == SymbolKind.Property || symbol.Kind == SymbolKind.Field)
                 {
                     return Visit(node.Name);
                 }
@@ -415,6 +420,16 @@ namespace ShaderGen
             {
                 return _backend.FormatInvocation(_setName, containingTypeName, symbol.Name, Array.Empty<InvocationParameterInfo>());
             }
+            else if (symbol is IFieldSymbol fs && fs.HasConstantValue)
+            {
+                // TODO: Share code to format constant values.
+                return fs.ConstantValue.ToString();
+            }
+            else if (symbol is ILocalSymbol ls && ls.HasConstantValue)
+            {
+                // TODO: Share code to format constant values.
+                return ls.ConstantValue.ToString();
+            }
 
             string mapped = _backend.CSharpToShaderIdentifierName(symbolInfo);
             return _backend.CorrectIdentifier(mapped);

+ 22 - 28
src/ShaderGen/ShaderSyntaxWalker.cs

@@ -60,7 +60,7 @@ namespace ShaderGen
             List<FieldDefinition> fields = new List<FieldDefinition>();
             foreach (MemberDeclarationSyntax member in node.Members)
             {
-                if (member is FieldDeclarationSyntax fds)
+                if (member is FieldDeclarationSyntax fds && !fds.Modifiers.Any(x => x.IsKind(SyntaxKind.ConstKeyword)))
                 {
                     VariableDeclarationSyntax varDecl = fds.Declaration;
                     foreach (VariableDeclaratorSyntax vds in varDecl.Variables)
@@ -70,7 +70,7 @@ namespace ShaderGen
                         int arrayElementCount = 0;
                         if (isArray)
                         {
-                            arrayElementCount = GetArrayCountValue(vds);
+                            arrayElementCount = GetArrayCountValue(vds, model);
                         }
 
                         TypeReference tr = new TypeReference(typeName, model.GetTypeInfo(varDecl.Type));
@@ -85,7 +85,7 @@ namespace ShaderGen
             return true;
         }
 
-        private static int GetArrayCountValue(VariableDeclaratorSyntax vds)
+        private static int GetArrayCountValue(VariableDeclaratorSyntax vds, SemanticModel semanticModel)
         {
             AttributeSyntax[] arraySizeAttrs = Utilities.GetMemberAttributes(vds, "ArraySize");
             if (arraySizeAttrs.Length != 1)
@@ -94,43 +94,27 @@ namespace ShaderGen
                     "Array fields in structs must have a constant size specified by an ArraySizeAttribute.");
             }
             AttributeSyntax arraySizeAttr = arraySizeAttrs[0];
-            return GetAttributeArgumentIntValue(arraySizeAttr, 0);
+            return GetAttributeArgumentIntValue(arraySizeAttr, 0, semanticModel);
         }
 
-        private static int GetAttributeArgumentIntValue(AttributeSyntax attr, int index)
+        private static int GetAttributeArgumentIntValue(AttributeSyntax attr, int index, SemanticModel semanticModel)
         {
             if (attr.ArgumentList.Arguments.Count < index + 1)
             {
                 throw new ShaderGenerationException(
                     "Too few arguments in attribute " + attr.ToFullString() + ". Required + " + (index + 1));
             }
-            string fullArg0 = attr.ArgumentList.Arguments[index].ToFullString();
-            if (int.TryParse(fullArg0, out int ret))
-            {
-                return ret;
-            }
-            else
-            {
-                throw new ShaderGenerationException("Incorrectly formatted attribute: " + attr.ToFullString());
-            }
+            return GetConstantIntFromExpression(attr.ArgumentList.Arguments[index].Expression, semanticModel);
         }
 
-        private static uint GetAttributeArgumentUIntValue(AttributeSyntax attr, int index)
+        private static int GetConstantIntFromExpression(ExpressionSyntax expression, SemanticModel semanticModel)
         {
-            if (attr.ArgumentList.Arguments.Count < index + 1)
-            {
-                throw new ShaderGenerationException(
-                    "Too few arguments in attribute " + attr.ToFullString() + ". Required + " + (index + 1));
-            }
-            string fullArg0 = attr.ArgumentList.Arguments[index].ToFullString();
-            if (uint.TryParse(fullArg0, out uint ret))
+            var constantValue = semanticModel.GetConstantValue(expression);
+            if (!constantValue.HasValue)
             {
-                return ret;
-            }
-            else
-            {
-                throw new ShaderGenerationException("Incorrectly formatted attribute: " + attr.ToFullString());
+                throw new ShaderGenerationException("Expression did not contain a constant value: " + expression.ToFullString());
             }
+            return (int) constantValue.Value;
         }
 
         private static SemanticType GetSemanticType(VariableDeclaratorSyntax vds)
@@ -196,6 +180,16 @@ namespace ShaderGen
             return attrs.Length == 1;
         }
 
+        public override void VisitFieldDeclaration(FieldDeclarationSyntax node)
+        {
+            if (node.Modifiers.Any(x => x.IsKind(SyntaxKind.ConstKeyword)))
+            {
+                return;
+            }
+
+            base.VisitFieldDeclaration(node);
+        }
+
         public override void VisitVariableDeclaration(VariableDeclarationSyntax node)
         {
             if (node.Variables.Count != 1)
@@ -221,7 +215,7 @@ namespace ShaderGen
             int set = 0; // Default value if not otherwise specified.
             if (GetResourceDecl(node, out AttributeSyntax resourceSetDecl))
             {
-                set = GetAttributeArgumentIntValue(resourceSetDecl, 0);
+                set = GetAttributeArgumentIntValue(resourceSetDecl, 0, GetModel(node));
             }
 
             int resourceBinding = GetAndIncrementBinding(set);