Kaynağa Gözat

Finish up support for compute shaders and (RW)StructuredBuffers.

Eric Mellino 8 yıl önce
ebeveyn
işleme
12da2a42b2
31 değiştirilmiş dosya ile 361 ekleme ve 114 silme
  1. 45 9
      src/ShaderGen.App/Program.cs
  2. 1 1
      src/ShaderGen.Build/ShaderGen.Build.csproj
  3. 5 2
      src/ShaderGen.Primitives/ComputeShaderAttribute.cs
  4. 12 0
      src/ShaderGen.Primitives/ComputeShaderSetAttribute.cs
  5. 6 0
      src/ShaderGen.Primitives/RWStructuredBuffer.cs
  6. 1 1
      src/ShaderGen.Primitives/ShaderGen.Primitives.csproj
  7. 1 0
      src/ShaderGen.Primitives/StructuredBuffer.cs
  8. 17 1
      src/ShaderGen.Tests/ShaderGeneratorTests.cs
  9. 1 1
      src/ShaderGen.Tests/ShaderSetDiscovererTests.cs
  10. 3 2
      src/ShaderGen.Tests/TestAssets/BuiltInVariables.cs
  11. 2 0
      src/ShaderGen.Tests/TestAssets/ShaderSets.cs
  12. 18 0
      src/ShaderGen.Tests/TestAssets/SimpleCompute.cs
  13. 8 2
      src/ShaderGen/GeneratedShaderSet.cs
  14. 12 2
      src/ShaderGen/Glsl330Backend.cs
  15. 14 3
      src/ShaderGen/Glsl450Backend.cs
  16. 24 10
      src/ShaderGen/GlslBackendBase.cs
  17. 7 0
      src/ShaderGen/GlslKnownIdentifiers.cs
  18. 22 1
      src/ShaderGen/HlslBackend.cs
  19. 7 0
      src/ShaderGen/HlslKnownIdentifiers.cs
  20. 2 0
      src/ShaderGen/LanguageBackend.cs
  21. 2 2
      src/ShaderGen/ResourceDefinition.cs
  22. 5 39
      src/ShaderGen/ShaderFunction.cs
  23. 0 12
      src/ShaderGen/ShaderFunctionAndBlockSyntax.cs
  24. 1 1
      src/ShaderGen/ShaderGen.csproj
  25. 25 12
      src/ShaderGen/ShaderGenerator.cs
  26. 19 1
      src/ShaderGen/ShaderMethodVisitor.cs
  27. 2 0
      src/ShaderGen/ShaderResourceKind.cs
  28. 13 1
      src/ShaderGen/ShaderSetDiscoverer.cs
  29. 7 0
      src/ShaderGen/ShaderSetInfo.cs
  30. 78 11
      src/ShaderGen/ShaderSyntaxWalker.cs
  31. 1 0
      src/ShaderGen/ShaderType.cs

+ 45 - 9
src/ShaderGen.App/Program.cs

@@ -172,7 +172,12 @@ namespace ShaderGen.App
                         string vsOutName = name + "-vertex." + extension;
                         string vsOutPath = Path.Combine(outputPath, vsOutName);
                         File.WriteAllText(vsOutPath, set.VertexShaderCode, outputEncoding);
-                        bool succeeded = CompileCode(lang, vsOutPath, set.VertexFunction.Name, true, out string genPath);
+                        bool succeeded = CompileCode(
+                            lang,
+                            vsOutPath,
+                            set.VertexFunction.Name,
+                            ShaderFunctionType.VertexEntryPoint,
+                            out string genPath);
                         if (succeeded)
                         {
                             generatedFilePaths.Add(genPath);
@@ -187,7 +192,12 @@ namespace ShaderGen.App
                         string fsOutName = name + "-fragment." + extension;
                         string fsOutPath = Path.Combine(outputPath, fsOutName);
                         File.WriteAllText(fsOutPath, set.FragmentShaderCode, outputEncoding);
-                        bool succeeded = CompileCode(lang, fsOutPath, set.FragmentFunction.Name, false, out string genPath);
+                        bool succeeded = CompileCode(
+                            lang,
+                            fsOutPath,
+                            set.FragmentFunction.Name,
+                            ShaderFunctionType.FragmentEntryPoint,
+                            out string genPath);
                         if (succeeded)
                         {
                             generatedFilePaths.Add(genPath);
@@ -197,6 +207,26 @@ namespace ShaderGen.App
                             generatedFilePaths.Add(fsOutPath);
                         }
                     }
+                    if (set.ComputeShaderCode != null)
+                    {
+                        string csOutName = name + "-compute." + extension;
+                        string csOutPath = Path.Combine(outputPath, csOutName);
+                        File.WriteAllText(csOutPath, set.ComputeShaderCode, outputEncoding);
+                        bool succeeded = CompileCode(
+                            lang,
+                            csOutPath,
+                            set.ComputeFunction.Name,
+                            ShaderFunctionType.ComputeEntryPoint,
+                            out string genPath);
+                        if (succeeded)
+                        {
+                            generatedFilePaths.Add(genPath);
+                        }
+                        if (!succeeded || listAllFiles)
+                        {
+                            generatedFilePaths.Add(csOutPath);
+                        }
+                    }
                 }
             }
 
@@ -205,16 +235,16 @@ namespace ShaderGen.App
             return 0;
         }
 
-        private static bool CompileCode(LanguageBackend lang, string shaderPath, string entryPoint, bool isVertex, out string path)
+        private static bool CompileCode(LanguageBackend lang, string shaderPath, string entryPoint, ShaderFunctionType type, out string path)
         {
             Type langType = lang.GetType();
             if (langType == typeof(HlslBackend) && IsFxcAvailable())
             {
-                return CompileHlsl(shaderPath, entryPoint, isVertex, out path);
+                return CompileHlsl(shaderPath, entryPoint, type, out path);
             }
             else if (langType == typeof(Glsl450Backend) && IsGlslangValidatorAvailable())
             {
-                return CompileSpirv(shaderPath, entryPoint, isVertex, out path);
+                return CompileSpirv(shaderPath, entryPoint, type, out path);
             }
             else
             {
@@ -223,12 +253,15 @@ namespace ShaderGen.App
             }
         }
 
-        private static bool CompileHlsl(string shaderPath, string entryPoint, bool isVertex, out string path)
+        private static bool CompileHlsl(string shaderPath, string entryPoint, ShaderFunctionType type, out string path)
         {
             try
             {
+                string profile = type == ShaderFunctionType.VertexEntryPoint ? "vs_5_0"
+                    : type == ShaderFunctionType.FragmentEntryPoint ? "ps_5_0"
+                    : "cs_5_0";
                 string outputPath = shaderPath + ".bytes";
-                string args = $"/T {(isVertex ? "vs_5_0" : "ps_5_0")} /E {entryPoint} {shaderPath} /Fo {outputPath}";
+                string args = $"/T {profile} /E {entryPoint} {shaderPath} /Fo {outputPath}";
                 string fxcPath = FindFxcExe();
                 ProcessStartInfo psi = new ProcessStartInfo(fxcPath, args);
                 psi.RedirectStandardOutput = true;
@@ -255,10 +288,13 @@ namespace ShaderGen.App
             return false;
         }
 
-        private static bool CompileSpirv(string shaderPath, string entryPoint, bool isVertex, out string path)
+        private static bool CompileSpirv(string shaderPath, string entryPoint, ShaderFunctionType type, out string path)
         {
+            string stage = type == ShaderFunctionType.VertexEntryPoint ? "vert"
+                : type == ShaderFunctionType.FragmentEntryPoint ? "frag"
+                : "comp";
             string outputPath = shaderPath + ".spv";
-            string args = $"-V -S {(isVertex ? "vert" : "frag")} {shaderPath} -o {outputPath}";
+            string args = $"-V -S {stage} {shaderPath} -o {outputPath}";
             try
             {
 

+ 1 - 1
src/ShaderGen.Build/ShaderGen.Build.csproj

@@ -9,7 +9,7 @@
     <PackageId>ShaderGen.Build</PackageId>
     <Description>Build-time plugin which generates shader code during a post-build event.</Description>
     <PackageTags>Shader GLSL HLSL SPIR-V Graphics OpenGL Vulkan Direct3D Game</PackageTags>
-    <PackageVersion>1.0.26</PackageVersion>
+    <PackageVersion>1.0.27</PackageVersion>
   </PropertyGroup>
 
   <ItemGroup>

+ 5 - 2
src/ShaderGen.Primitives/ComputeShaderAttribute.cs

@@ -1,6 +1,9 @@
-namespace ShaderGen
+using System;
+
+namespace ShaderGen
 {
-    public class ComputeShaderAttribute
+    [AttributeUsage(AttributeTargets.Method)]
+    public class ComputeShaderAttribute : Attribute
     {
         public uint GroupCountX { get; }
         public uint GroupCountY { get; }

+ 12 - 0
src/ShaderGen.Primitives/ComputeShaderSetAttribute.cs

@@ -0,0 +1,12 @@
+using System;
+
+namespace ShaderGen
+{
+    [AttributeUsage(AttributeTargets.Assembly)]
+    public class ComputeShaderSetAttribute : Attribute
+    {
+        public ComputeShaderSetAttribute(string setName, string computeShaderFunctionName)
+        {
+        }
+    }
+}

+ 6 - 0
src/ShaderGen.Primitives/RWStructuredBuffer.cs

@@ -7,5 +7,11 @@
             get => throw new ShaderBuiltinException();
             set => throw new ShaderBuiltinException();
         }
+
+        public T this[uint index]
+        {
+            get => throw new ShaderBuiltinException();
+            set => throw new ShaderBuiltinException();
+        }
     }
 }

+ 1 - 1
src/ShaderGen.Primitives/ShaderGen.Primitives.csproj

@@ -3,7 +3,7 @@
   <PropertyGroup>
     <TargetFramework>netstandard2.0</TargetFramework>
     <RootNamespace>ShaderGen</RootNamespace>
-    <AssemblyVersion>1.0.10</AssemblyVersion>
+    <AssemblyVersion>1.0.11</AssemblyVersion>
     <!-- Package stuff -->
     <PackageId>ShaderGen.Primitives</PackageId>
     <Description>C# attributes and primitives for generating shader code via ShaderGen.</Description>

+ 1 - 0
src/ShaderGen.Primitives/StructuredBuffer.cs

@@ -3,5 +3,6 @@
     public class StructuredBuffer<T> where T : struct
     {
         public T this[int index] => throw new ShaderBuiltinException();
+        public T this[uint index] => throw new ShaderBuiltinException();
     }
 }

+ 17 - 1
src/ShaderGen.Tests/ShaderGeneratorTests.cs

@@ -33,6 +33,11 @@ namespace ShaderGen.Tests
             yield return new object[] { "TestShaders.BuiltInVariables.VS", null };
         }
 
+        private static IEnumerable<object[]> ComputeShaders()
+        {
+            yield return new object[] { "TestShaders.SimpleCompute.CS" };
+        }
+
         [Theory]
         [MemberData(nameof(ShaderSets))]
         public void HlslEndToEnd(string vsName, string fsName)
@@ -172,7 +177,18 @@ namespace ShaderGen.Tests
                             GlsLangValidatorTool.AssertCompilesCode(set.FragmentShaderCode, "frag", is450);
                         }
                     }
-
+                    if (set.ComputeFunction != null)
+                    {
+                        if (backend is HlslBackend)
+                        {
+                            FxcTool.AssertCompilesCode(set.ComputeShaderCode, "cs_5_0", set.ComputeFunction.Name);
+                        }
+                        else
+                        {
+                            bool is450 = backend is Glsl450Backend;
+                            GlsLangValidatorTool.AssertCompilesCode(set.ComputeShaderCode, "comp", is450);
+                        }
+                    }
                 }
             }
         }

+ 1 - 1
src/ShaderGen.Tests/ShaderSetDiscovererTests.cs

@@ -16,7 +16,7 @@ namespace ShaderGen.Tests
             ShaderGenerator sg = new ShaderGenerator(compilation, new[] { backend });
             ShaderGenerationResult result = sg.GenerateShaders();
             IReadOnlyList<GeneratedShaderSet> hlslSets = result.GetOutput(backend);
-            Assert.Equal(3, hlslSets.Count);
+            Assert.Equal(4, hlslSets.Count);
             GeneratedShaderSet set = hlslSets[0];
             Assert.Equal("VertexAndFragment", set.Name);
             ShaderModel shaderModel = set.Model;

+ 3 - 2
src/ShaderGen.Tests/TestAssets/BuiltInVariables.cs

@@ -1,5 +1,6 @@
 using ShaderGen;
 using System.Numerics;
+using static ShaderGen.ShaderBuiltins;
 
 namespace TestShaders
 {
@@ -9,10 +10,10 @@ namespace TestShaders
         SystemPosition4 VS()
         {
             uint vertexID = ShaderBuiltins.VertexID;
-            uint instanceID = ShaderBuiltins.InstanceID;
+            uint instanceID = InstanceID;
 
             SystemPosition4 output;
-            output.Position = new Vector4(vertexID, instanceID, 0, 1);
+            output.Position = new Vector4(vertexID, instanceID, ShaderBuiltins.VertexID, 1);
             return output;
         }
     }

+ 2 - 0
src/ShaderGen.Tests/TestAssets/ShaderSets.cs

@@ -8,3 +8,5 @@
 
 [assembly: ShaderSet("VertexOnly", "TestShaders.TestVertexShader.VS", null)]
 [assembly: ShaderSet("FragmentOnly", null, "TestShaders.VertexAndFragment.FS")]
+
+[assembly: ComputeShaderSet("SimpleCompute", "TestShaders.SimpleCompute.CS")]

+ 18 - 0
src/ShaderGen.Tests/TestAssets/SimpleCompute.cs

@@ -0,0 +1,18 @@
+using ShaderGen;
+using System.Numerics;
+using static ShaderGen.ShaderBuiltins;
+
+namespace TestShaders
+{
+    public class SimpleCompute
+    {
+        public StructuredBuffer<Vector4> StructuredInput;
+        public RWStructuredBuffer<Vector4> StructuredInOut;
+
+        [ComputeShader(1, 1, 1)]
+        public void CS()
+        {
+            StructuredInOut[DispatchThreadID.X] = StructuredInput[DispatchThreadID.Y];
+        }
+    }
+}

+ 8 - 2
src/ShaderGen/GeneratedShaderSet.cs

@@ -13,28 +13,34 @@ namespace ShaderGen
         public string Name { get; }
         public string VertexShaderCode { get; }
         public string FragmentShaderCode { get; }
+        public string ComputeShaderCode { get; }
         public ShaderFunction VertexFunction { get; }
         public ShaderFunction FragmentFunction { get; }
+        public ShaderFunction ComputeFunction { get; }
         public ShaderModel Model { get; }
 
         public GeneratedShaderSet(
             string name,
             string vsCode,
             string fsCode,
+            string csCode,
             ShaderFunction vertexfunction,
             ShaderFunction fragmentFunction,
+            ShaderFunction computeFunction,
             ShaderModel model)
         {
-            if (string.IsNullOrEmpty(vsCode) && string.IsNullOrEmpty(fsCode))
+            if (string.IsNullOrEmpty(vsCode) && string.IsNullOrEmpty(fsCode) && string.IsNullOrEmpty(csCode))
             {
-                throw new ShaderGenerationException("At least one of vsCode or fsCode must be non-empty");
+                throw new ShaderGenerationException("At least one of vsCode, fsCode, or csCode must be non-empty");
             }
 
             Name = name;
             VertexShaderCode = vsCode;
             FragmentShaderCode = fsCode;
+            ComputeShaderCode = csCode;
             VertexFunction = vertexfunction;
             FragmentFunction = fragmentFunction;
+            ComputeFunction = computeFunction;
             Model = model;
         }
     }

+ 12 - 2
src/ShaderGen/Glsl330Backend.cs

@@ -17,9 +17,10 @@ namespace ShaderGen
                 .Replace("+", "_");
         }
 
-        protected override void WriteVersionHeader(StringBuilder sb)
+        protected override void WriteVersionHeader(ShaderFunction function, StringBuilder sb)
         {
-            sb.AppendLine("#version 330 core");
+            string version = function.Type == ShaderFunctionType.ComputeEntryPoint ? "430" : "330 core";
+            sb.AppendLine($"#version {version}");
             sb.AppendLine();
         }
 
@@ -54,6 +55,15 @@ namespace ShaderGen
             sb.AppendLine();
         }
 
+        protected override void WriteStructuredBuffer(StringBuilder sb, ResourceDefinition rd, bool isReadOnly)
+        {
+            string readOnlyStr = isReadOnly ? " readonly" : " ";
+            sb.AppendLine($"layout(std140){readOnlyStr} buffer {rd.Name}");
+            sb.AppendLine("{");
+            sb.AppendLine($"    {CSharpToShaderType(rd.ValueType.Name)} field_{CorrectIdentifier(rd.Name.Trim())}[];");
+            sb.AppendLine("};");
+        }
+
         protected override string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos)
         {
             return Glsl330KnownFunctions.TranslateInvocation(type, method, parameterInfos);

+ 14 - 3
src/ShaderGen/Glsl450Backend.cs

@@ -17,7 +17,7 @@ namespace ShaderGen
                 .Replace("+", "_");
         }
 
-        protected override void WriteVersionHeader(StringBuilder sb)
+        protected override void WriteVersionHeader(ShaderFunction function, StringBuilder sb)
         {
             sb.AppendLine("#version 450");
             sb.AppendLine("#extension GL_ARB_separate_shader_objects : enable");
@@ -34,6 +34,16 @@ namespace ShaderGen
             sb.AppendLine();
         }
 
+        protected override void WriteStructuredBuffer(StringBuilder sb, ResourceDefinition rd, bool isReadOnly)
+        {
+            string layout = FormatLayoutStr(rd, "std140");
+            string readOnlyStr = isReadOnly ? " readonly" : " ";
+            sb.AppendLine($"{layout}{readOnlyStr} buffer {rd.Name}");
+            sb.AppendLine("{");
+            sb.AppendLine($"    {CSharpToShaderType(rd.ValueType.Name)} field_{CorrectIdentifier(rd.Name.Trim())}[];");
+            sb.AppendLine("};");
+        }
+
         protected override void WriteSampler(StringBuilder sb, ResourceDefinition rd)
         {
             sb.Append(FormatLayoutStr(rd));
@@ -98,9 +108,10 @@ namespace ShaderGen
             return Glsl450KnownFunctions.TranslateInvocation(type, method, parameterInfos);
         }
 
-        private string FormatLayoutStr(ResourceDefinition rd)
+        private string FormatLayoutStr(ResourceDefinition rd, string storageSpec = null)
         {
-            return $"layout(set = {rd.Set}, binding = {rd.Binding})";
+            string storageSpecPart = storageSpec != null ? $"{storageSpec}, " : string.Empty;
+            return $"layout({storageSpecPart}set = {rd.Set}, binding = {rd.Binding})";
         }
 
         protected override void EmitGlPositionCorrection(StringBuilder sb)

+ 24 - 10
src/ShaderGen/GlslBackendBase.cs

@@ -1,10 +1,7 @@
 using Microsoft.CodeAnalysis;
-using System;
 using System.Text;
 using System.Collections.Generic;
 using System.Linq;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
-using System.IO;
 using System.Diagnostics;
 
 namespace ShaderGen
@@ -12,6 +9,7 @@ namespace ShaderGen
     public abstract class GlslBackendBase : LanguageBackend
     {
         protected readonly HashSet<string> _uniformNames = new HashSet<string>();
+        protected readonly HashSet<string> _ssboNames = new HashSet<string>();
 
         public GlslBackendBase(Compilation compilation) : base(compilation)
         {
@@ -56,7 +54,7 @@ namespace ShaderGen
 
             ValidateRequiredSemantics(setName, entryPoint.Function, function.Type);
 
-            WriteVersionHeader(sb);
+            WriteVersionHeader(function, sb);
 
             StructureDefinition[] orderedStructures
                 = StructureDependencyGraph.GetOrderedStructureList(Compilation, context.Structures);
@@ -85,6 +83,10 @@ namespace ShaderGen
                     case ShaderResourceKind.Sampler:
                         WriteSampler(sb, rd);
                         break;
+                    case ShaderResourceKind.StructuredBuffer:
+                    case ShaderResourceKind.RWStructuredBuffer:
+                        WriteStructuredBuffer(sb, rd, rd.ResourceKind == ShaderResourceKind.StructuredBuffer);
+                        break;
                     default: throw new ShaderGenerationException("Illegal resource kind: " + rd.ResourceKind);
                 }
             }
@@ -184,7 +186,8 @@ namespace ShaderGen
             }
             else
             {
-                Debug.Assert(entryFunction.Type == ShaderFunctionType.FragmentEntryPoint);
+                Debug.Assert(entryFunction.Type == ShaderFunctionType.FragmentEntryPoint
+                    || entryFunction.Type == ShaderFunctionType.ComputeEntryPoint);
 
                 if (mappedReturnType == "vec4")
                 {
@@ -248,7 +251,7 @@ namespace ShaderGen
             }
             else
             {
-                sb.Append($"    {invocationStr};");
+                sb.AppendLine($"    {invocationStr};");
             }
 
             // Assign output fields to synthetic "out" variables with normalized "fsin_#" names.
@@ -277,9 +280,8 @@ namespace ShaderGen
                 sb.AppendLine($"    gl_Position = {CorrectIdentifier("output")}.{CorrectIdentifier(systemPositionField.Name)};");
                 EmitGlPositionCorrection(sb);
             }
-            else
+            else if (entryFunction.Type == ShaderFunctionType.FragmentEntryPoint)
             {
-                Debug.Assert(entryFunction.Type == ShaderFunctionType.FragmentEntryPoint);
                 if (mappedReturnType == "vec4")
                 {
                     sb.AppendLine($"    _outputColor_ = {CorrectIdentifier("output")};");
@@ -319,6 +321,11 @@ namespace ShaderGen
             {
                 _uniformNames.Add(rd.Name);
             }
+            if (rd.ResourceKind == ShaderResourceKind.StructuredBuffer
+                || rd.ResourceKind == ShaderResourceKind.RWStructuredBuffer)
+            {
+                _ssboNames.Add(rd.Name);
+            }
 
             base.AddResource(setName, rd);
         }
@@ -328,7 +335,7 @@ namespace ShaderGen
             string originalName = symbolInfo.Symbol.Name;
             string mapped = CSharpToShaderIdentifierName(symbolInfo);
             string identifier = CorrectIdentifier(mapped);
-            if (_uniformNames.Contains(originalName))
+            if (_uniformNames.Contains(originalName) || _ssboNames.Contains(originalName))
             {
                 return "field_" + identifier;
             }
@@ -338,17 +345,24 @@ namespace ShaderGen
             }
         }
 
+        internal override string GetComputeGroupCountsDeclaration(UInt3 groupCounts)
+        {
+            return $"layout(local_size_x = {groupCounts.X}, local_size_y = {groupCounts.Y}, local_size_z = {groupCounts.Z}) in;";
+        }
+
         private static readonly HashSet<string> s_glslKeywords = new HashSet<string>()
         {
             "input", "output",
         };
 
-        protected abstract void WriteVersionHeader(StringBuilder sb);
+        protected abstract void WriteVersionHeader(ShaderFunction function, StringBuilder sb);
         protected abstract void WriteUniform(StringBuilder sb, ResourceDefinition rd);
         protected abstract void WriteSampler(StringBuilder sb, ResourceDefinition rd);
         protected abstract void WriteTexture2D(StringBuilder sb, ResourceDefinition rd);
         protected abstract void WriteTextureCube(StringBuilder sb, ResourceDefinition rd);
         protected abstract void WriteTexture2DMS(StringBuilder sb, ResourceDefinition rd);
+        protected abstract void WriteStructuredBuffer(StringBuilder sb, ResourceDefinition rd, bool isReadOnly);
+
         protected abstract void WriteInOutVariable(
             StringBuilder sb,
             bool isInVar,

+ 7 - 0
src/ShaderGen/GlslKnownIdentifiers.cs

@@ -34,6 +34,13 @@ namespace ShaderGen
             };
             ret.Add("System.Numerics.Vector4", v4Mappings);
 
+            Dictionary<string, string> uint3Mappings = new Dictionary<string, string>()
+            {
+                { "X", "x" },
+                { "Y", "y" },
+                { "Z", "z" },
+            };
+            ret.Add("ShaderGen.UInt3", uint3Mappings);
 
             return ret;
         }

+ 22 - 1
src/ShaderGen/HlslBackend.cs

@@ -125,6 +125,16 @@ namespace ShaderGen
             sb.AppendLine();
         }
 
+        private void WriteStructuredBuffer(StringBuilder sb, ResourceDefinition rd, int binding)
+        {
+            sb.AppendLine($"StructuredBuffer<{CSharpToShaderType(rd.ValueType.Name)}> {CorrectIdentifier(rd.Name)}: register(t{binding});");
+        }
+
+        private void WriteRWStructuredBuffer(StringBuilder sb, ResourceDefinition rd, int binding)
+        {
+            sb.AppendLine($"RWStructuredBuffer<{CSharpToShaderType(rd.ValueType.Name)}> {CorrectIdentifier(rd.Name)}: register(u{binding});");
+        }
+
         protected override string GenerateFullTextCore(string setName, ShaderFunction function)
         {
             Debug.Assert(function.IsEntryPoint);
@@ -156,7 +166,7 @@ namespace ShaderGen
             List<ResourceDefinition[]> resourcesBySet = setContext.Resources.GroupBy(rd => rd.Set)
                 .Select(g => g.ToArray()).ToList();
 
-            int uniformBinding = 0, textureBinding = 0, samplerBinding = 0;
+            int uniformBinding = 0, textureBinding = 0, samplerBinding = 0, uavBinding = function.ColorOutputCount;
             int setIndex = 0;
             foreach (ResourceDefinition[] set in resourcesBySet)
             {
@@ -182,6 +192,12 @@ namespace ShaderGen
                         case ShaderResourceKind.Sampler:
                             WriteSampler(sb, rd, samplerBinding++);
                             break;
+                        case ShaderResourceKind.StructuredBuffer:
+                            WriteStructuredBuffer(sb, rd, textureBinding++);
+                            break;
+                        case ShaderResourceKind.RWStructuredBuffer:
+                            WriteRWStructuredBuffer(sb, rd, uavBinding++);
+                            break;
                         default: throw new ShaderGenerationException("Illegal resource kind: " + rd.ResourceKind);
                     }
                 }
@@ -228,6 +244,11 @@ namespace ShaderGen
             return HlslKnownIdentifiers.GetMappedIdentifier(typeName, identifier);
         }
 
+        internal override string GetComputeGroupCountsDeclaration(UInt3 groupCounts)
+        {
+            return $"[numthreads({groupCounts.X}, {groupCounts.Y}, {groupCounts.Z})]";
+        }
+
         private struct HlslSemanticTracker
         {
             public int Position;

+ 7 - 0
src/ShaderGen/HlslKnownIdentifiers.cs

@@ -34,6 +34,13 @@ namespace ShaderGen
             };
             ret.Add("System.Numerics.Vector4", v4Mappings);
 
+            Dictionary<string, string> uint3Mappings = new Dictionary<string, string>()
+            {
+                { "X", "x" },
+                { "Y", "y" },
+                { "Z", "z" },
+            };
+            ret.Add("ShaderGen.UInt3", uint3Mappings);
 
             return ret;
         }

+ 2 - 0
src/ShaderGen/LanguageBackend.cs

@@ -47,6 +47,7 @@ namespace ShaderGen
             return ret;
         }
 
+
         internal ShaderModel GetShaderModel(string setName)
         {
             BackendContext context = GetContext(setName);
@@ -266,6 +267,7 @@ namespace ShaderGen
         protected abstract string CSharpToIdentifierNameCore(string typeName, string identifier);
         protected abstract string GenerateFullTextCore(string setName, ShaderFunction function);
         protected abstract string FormatInvocationCore(string setName, string type, string method, InvocationParameterInfo[] parameterInfos);
+        internal abstract string GetComputeGroupCountsDeclaration(UInt3 groupCounts);
 
         internal string CorrectLiteral(string literal)
         {

+ 2 - 2
src/ShaderGen/ResourceDefinition.cs

@@ -8,12 +8,12 @@
         public TypeReference ValueType { get; }
         public ShaderResourceKind ResourceKind { get; }
 
-        public ResourceDefinition(string name, int set, int binding, TypeReference type, ShaderResourceKind kind)
+        public ResourceDefinition(string name, int set, int binding, TypeReference valueType, ShaderResourceKind kind)
         {
             Name = name;
             Set = set;
             Binding = binding;
-            ValueType = type;
+            ValueType = valueType;
             ResourceKind = kind;
         }
     }

+ 5 - 39
src/ShaderGen/ShaderFunction.cs

@@ -11,9 +11,11 @@ namespace ShaderGen
         public string DeclaringType { get; }
         public string Name { get; }
         public TypeReference ReturnType { get; }
+        public int ColorOutputCount { get; } // TODO: This always returns 0.
         public ParameterDefinition[] Parameters { get; }
         public bool IsEntryPoint => Type != ShaderFunctionType.Normal;
         public ShaderFunctionType Type { get; }
+        public UInt3 ComputeGroupCounts { get; }
         public bool UsesVertexID { get; internal set; }
         public bool UsesInstanceID { get; internal set; }
         public bool UsesDispatchThreadID { get; internal set; }
@@ -24,51 +26,15 @@ namespace ShaderGen
             string name,
             TypeReference returnType,
             ParameterDefinition[] parameters,
-            ShaderFunctionType type)
+            ShaderFunctionType type,
+            UInt3 computeGroupCounts)
         {
             DeclaringType = declaringType;
             Name = name;
             ReturnType = returnType;
             Parameters = parameters;
             Type = type;
-        }
-
-        public ShaderFunction WithReturnType(TypeReference returnType)
-        {
-            return new ShaderFunction(DeclaringType, Name, returnType, Parameters, Type);
-        }
-
-        public ShaderFunction WithParameter(int index, TypeReference typeReference)
-        {
-            ParameterDefinition[] parameters = (ParameterDefinition[])Parameters.Clone();
-            parameters[index] = new ParameterDefinition(parameters[index].Name, typeReference);
-            return new ShaderFunction(DeclaringType, Name, ReturnType, parameters, Type);
-        }
-
-        public static ShaderFunction GetShaderFunction(Compilation compilation, MethodDeclarationSyntax node)
-        {
-            string functionName = node.Identifier.ToFullString();
-            List<ParameterDefinition> parameters = new List<ParameterDefinition>();
-            foreach (ParameterSyntax ps in node.ParameterList.Parameters)
-            {
-                parameters.Add(ParameterDefinition.GetParameterDefinition(compilation, ps));
-            }
-
-            TypeReference returnType = new TypeReference(compilation.GetSemanticModel(node.SyntaxTree).GetFullTypeName(node.ReturnType));
-
-            bool isVertexShader, isFragmentShader = false;
-            isVertexShader = Utilities.GetMethodAttributes(node, "VertexShader").Any();
-            if (!isVertexShader)
-            {
-                isFragmentShader = Utilities.GetMethodAttributes(node, "FragmentShader").Any();
-            }
-
-            ShaderFunctionType type = isVertexShader
-                ? ShaderFunctionType.VertexEntryPoint : isFragmentShader
-                ? ShaderFunctionType.FragmentEntryPoint : ShaderFunctionType.Normal;
-
-            string nestedTypePrefix = Utilities.GetFullNestedTypePrefix(node, out bool nested);
-            return new ShaderFunction(nestedTypePrefix, functionName, returnType, parameters.ToArray(), type);
+            ComputeGroupCounts = computeGroupCounts;
         }
     }
 }

+ 0 - 12
src/ShaderGen/ShaderFunctionAndBlockSyntax.cs

@@ -13,17 +13,5 @@ namespace ShaderGen
             Function = function;
             Block = block;
         }
-
-        public ShaderFunctionAndBlockSyntax WithReturnType(TypeReference returnType)
-        {
-            ShaderFunction sf = Function.WithReturnType(returnType);
-            return new ShaderFunctionAndBlockSyntax(sf, Block);
-        }
-
-        public ShaderFunctionAndBlockSyntax WithParameter(int index, TypeReference type)
-        {
-            ShaderFunction sf = Function.WithParameter(index, type);
-            return new ShaderFunctionAndBlockSyntax(sf, Block);
-        }
     }
 }

+ 1 - 1
src/ShaderGen/ShaderGen.csproj

@@ -2,7 +2,7 @@
 
   <PropertyGroup>
     <TargetFramework>netstandard2.0</TargetFramework>
-    <AssemblyVersion>1.0.20</AssemblyVersion>
+    <AssemblyVersion>1.0.21</AssemblyVersion>
 
     <!-- Package stuff -->
     <PackageId>ShaderGen</PackageId>

+ 25 - 12
src/ShaderGen/ShaderGenerator.cs

@@ -173,6 +173,7 @@ namespace ShaderGen
         {
             TypeAndMethodName vertexFunctionName = ss.VertexShader;
             TypeAndMethodName fragmentFunctionName = ss.FragmentShader;
+            TypeAndMethodName computeFunctionName = ss.ComputeShader;
 
             HashSet<SyntaxTree> treesToVisit = new HashSet<SyntaxTree>();
             if (vertexFunctionName != null)
@@ -183,6 +184,10 @@ namespace ShaderGen
             {
                 GetTrees(treesToVisit, fragmentFunctionName.TypeName);
             }
+            if (computeFunctionName != null)
+            {
+                GetTrees(treesToVisit, computeFunctionName.TypeName);
+            }
 
             foreach (LanguageBackend language in _languages)
             {
@@ -204,20 +209,28 @@ namespace ShaderGen
                 ShaderFunction fsFunc = (ss.FragmentShader != null)
                     ? model.GetFunction(ss.FragmentShader.FullName)
                     : null;
+                ShaderFunction csFunc = (ss.ComputeShader != null)
+                    ? model.GetFunction(ss.ComputeShader.FullName)
+                    : null;
+                string vsCode = null;
+                string fsCode = null;
+                string csCode = null;
+                if (vsFunc != null)
+                {
+                    vsCode = language.GetCode(ss.Name, vsFunc);
+                }
+                if (fsFunc != null)
+                {
+                    fsCode = language.GetCode(ss.Name, fsFunc);
+                }
+                if (csFunc != null)
                 {
-                    string vsCode = null;
-                    string fsCode = null;
-                    if (vsFunc != null)
-                    {
-                        vsCode = language.GetCode(ss.Name, vsFunc);
-                    }
-                    if (fsFunc != null)
-                    {
-                        fsCode = language.GetCode(ss.Name, fsFunc);
-                    }
-
-                    result.AddShaderSet(language, new GeneratedShaderSet(ss.Name, vsCode, fsCode, vsFunc, fsFunc, model));
+                    csCode = language.GetCode(ss.Name, csFunc);
                 }
+
+                result.AddShaderSet(
+                    language,
+                    new GeneratedShaderSet(ss.Name, vsCode, fsCode, csCode, vsFunc, fsFunc, csFunc, model));
             }
         }
 

+ 19 - 1
src/ShaderGen/ShaderMethodVisitor.cs

@@ -37,6 +37,12 @@ namespace ShaderGen
             StringBuilder sb = new StringBuilder();
             string blockResult = VisitBlock(node); // Visit block first in order to discover builtin variables.
             string functionDeclStr = GetFunctionDeclStr();
+
+            if (_shaderFunction.Type == ShaderFunctionType.ComputeEntryPoint)
+            {
+                sb.AppendLine(_backend.GetComputeGroupCountsDeclaration(_shaderFunction.ComputeGroupCounts));
+            }
+
             sb.AppendLine(functionDeclStr);
             sb.AppendLine(blockResult);
             return sb.ToString();
@@ -107,6 +113,13 @@ namespace ShaderGen
             if (exprSymbol.Symbol.Kind == SymbolKind.NamedType)
             {
                 // Static member access
+                SymbolInfo symbolInfo = GetModel(node).GetSymbolInfo(node);
+                ISymbol symbol = symbolInfo.Symbol;
+                if (symbol.Kind == SymbolKind.Property)
+                {
+                    return Visit(node.Name);
+                }
+
                 string typeName = Utilities.GetFullMetadataName(exprSymbol.Symbol);
                 string targetName = Visit(node.Name);
                 return _backend.FormatInvocation(_setName, typeName, targetName, Array.Empty<InvocationParameterInfo>());
@@ -250,7 +263,7 @@ namespace ShaderGen
         {
             SymbolInfo symbolInfo = GetModel(node).GetSymbolInfo(node.Type);
             string fullName = Utilities.GetFullName(symbolInfo);
-             
+
             InvocationParameterInfo[] parameters = GetParameterInfos(node.ArgumentList);
             return _backend.FormatInvocation(_setName, fullName, "ctor", parameters);
         }
@@ -268,6 +281,11 @@ namespace ShaderGen
             {
                 return _backend.CorrectFieldAccess(symbolInfo);
             }
+            else if (symbol.Kind == SymbolKind.Property)
+            {
+                return _backend.FormatInvocation(_setName, containingTypeName, symbol.Name, Array.Empty<InvocationParameterInfo>());
+            }
+
             string mapped = _backend.CSharpToShaderIdentifierName(symbolInfo);
             return _backend.CorrectIdentifier(mapped);
         }

+ 2 - 0
src/ShaderGen/ShaderResourceKind.cs

@@ -7,5 +7,7 @@
         TextureCube,
         Texture2DMS,
         Sampler,
+        StructuredBuffer,
+        RWStructuredBuffer,
     }
 }

+ 13 - 1
src/ShaderGen/ShaderSetDiscoverer.cs

@@ -10,7 +10,19 @@ namespace ShaderGen
         private readonly List<ShaderSetInfo> _shaderSets = new List<ShaderSetInfo>();
         public override void VisitAttribute(AttributeSyntax node)
         {
-            if (node.Name.ToFullString().Contains("ShaderSet"))
+            // TODO: Only look at assembly-level attributes.
+            if (node.Name.ToFullString().Contains("ComputeShaderSet"))
+            {
+                string name = GetStringParam(node, 0);
+                string cs = GetStringParam(node, 1);
+                if (!TypeAndMethodName.Get(cs, out TypeAndMethodName csName))
+                {
+                    throw new ShaderGenerationException("ComputeShaderSetAttribute has an incomplete or invalid compute shader name.");
+                }
+
+                _shaderSets.Add(new ShaderSetInfo(name, csName));
+            }
+            else if (node.Name.ToFullString().Contains("ShaderSet"))
             {
                 string name = GetStringParam(node, 0);
 

+ 7 - 0
src/ShaderGen/ShaderSetInfo.cs

@@ -7,6 +7,7 @@ namespace ShaderGen
         public string Name { get; }
         public TypeAndMethodName VertexShader { get; }
         public TypeAndMethodName FragmentShader { get; }
+        public TypeAndMethodName ComputeShader { get; }
 
         public ShaderSetInfo(string name, TypeAndMethodName vs, TypeAndMethodName fs)
         {
@@ -19,5 +20,11 @@ namespace ShaderGen
             VertexShader = vs;
             FragmentShader = fs;
         }
+
+        public ShaderSetInfo(string name, TypeAndMethodName cs)
+        {
+            Name = name;
+            ComputeShader = cs;
+        }
     }
 }

+ 78 - 11
src/ShaderGen/ShaderSyntaxWalker.cs

@@ -38,19 +38,41 @@ namespace ShaderGen
 
             TypeReference returnType = new TypeReference(GetModel(node).GetFullTypeName(node.ReturnType));
 
-            bool isVertexShader, isFragmentShader = false;
-            isVertexShader = Utilities.GetMethodAttributes(node, "VertexShader").Any();
+            UInt3 computeGroupCounts = new UInt3();
+            bool isFragmentShader = false, isComputeShader = false;
+            bool isVertexShader = Utilities.GetMethodAttributes(node, "VertexShader").Any();
             if (!isVertexShader)
             {
                 isFragmentShader = Utilities.GetMethodAttributes(node, "FragmentShader").Any();
             }
+            if (!isVertexShader && !isFragmentShader)
+            {
+                AttributeSyntax computeShaderAttr = Utilities.GetMethodAttributes(node, "ComputeShader").FirstOrDefault();
+                if (computeShaderAttr != null)
+                {
+                    isComputeShader = true;
+                    computeGroupCounts.X = GetAttributeArgumentUIntValue(computeShaderAttr, 0);
+                    computeGroupCounts.Y = GetAttributeArgumentUIntValue(computeShaderAttr, 1);
+                    computeGroupCounts.Z = GetAttributeArgumentUIntValue(computeShaderAttr, 2);
+                }
+            }
 
             ShaderFunctionType type = isVertexShader
-                ? ShaderFunctionType.VertexEntryPoint : isFragmentShader
-                ? ShaderFunctionType.FragmentEntryPoint : ShaderFunctionType.Normal;
+                ? ShaderFunctionType.VertexEntryPoint
+                : isFragmentShader
+                    ? ShaderFunctionType.FragmentEntryPoint
+                    : isComputeShader
+                        ? ShaderFunctionType.ComputeEntryPoint
+                        : ShaderFunctionType.Normal;
 
             string nestedTypePrefix = Utilities.GetFullNestedTypePrefix(node, out bool nested);
-            ShaderFunction sf = new ShaderFunction(nestedTypePrefix, functionName, returnType, parameters.ToArray(), type);
+            ShaderFunction sf = new ShaderFunction(
+                nestedTypePrefix,
+                functionName,
+                returnType,
+                parameters.ToArray(),
+                type,
+                computeGroupCounts);
             ShaderFunctionAndBlockSyntax sfab = new ShaderFunctionAndBlockSyntax(sf, node.Body);
             foreach (LanguageBackend b in _backends) { b.AddFunction(_shaderSet.Name, sfab); }
         }
@@ -108,12 +130,17 @@ namespace ShaderGen
                     "Array fields in structs must have a constant size specified by an ArraySizeAttribute.");
             }
             AttributeSyntax arraySizeAttr = arraySizeAttrs[0];
-            return GetAttributeFirstArgumentIntValue(arraySizeAttr);
+            return GetAttributeArgumentIntValue(arraySizeAttr, 0);
         }
 
-        private static int GetAttributeFirstArgumentIntValue(AttributeSyntax attr)
+        private static int GetAttributeArgumentIntValue(AttributeSyntax attr, int index)
         {
-            string fullArg0 = attr.ArgumentList.Arguments[0].ToFullString();
+            if (attr.ArgumentList.Arguments.Count < index + 1)
+            {
+                throw new ShaderGenerationException(
+                    "Too few arguments in attribute " + attr.ToFullString() + ". Required + " + (index + 1));
+            }
+            string fullArg0 = attr.ArgumentList.Arguments[index].ToFullString();
             if (int.TryParse(fullArg0, out int ret))
             {
                 return ret;
@@ -124,6 +151,24 @@ namespace ShaderGen
             }
         }
 
+        private static uint GetAttributeArgumentUIntValue(AttributeSyntax attr, int index)
+        {
+            if (attr.ArgumentList.Arguments.Count < index + 1)
+            {
+                throw new ShaderGenerationException(
+                    "Too few arguments in attribute " + attr.ToFullString() + ". Required + " + (index + 1));
+            }
+            string fullArg0 = attr.ArgumentList.Arguments[index].ToFullString();
+            if (uint.TryParse(fullArg0, out uint ret))
+            {
+                return ret;
+            }
+            else
+            {
+                throw new ShaderGenerationException("Incorrectly formatted attribute: " + attr.ToFullString());
+            }
+        }
+
         private static SemanticType GetSemanticType(VariableDeclaratorSyntax vds)
         {
             AttributeSyntax[] attrs = Utilities.GetMemberAttributes(vds, "VertexSemantic");
@@ -199,18 +244,23 @@ namespace ShaderGen
             string resourceName = vds.Identifier.Text;
             TypeInfo typeInfo = GetModel(node).GetTypeInfo(node.Type);
             string fullTypeName = GetModel(node).GetFullTypeName(node.Type);
-            TypeReference tr = new TypeReference(fullTypeName);
+            TypeReference valueType = new TypeReference(fullTypeName);
             ShaderResourceKind kind = ClassifyResourceKind(fullTypeName);
 
+            if (kind == ShaderResourceKind.StructuredBuffer || kind == ShaderResourceKind.RWStructuredBuffer)
+            {
+                valueType = ParseStructuredBufferElementType(vds);
+            }
+
             int set = 0; // Default value if not otherwise specified.
             if (GetResourceDecl(node, out AttributeSyntax resourceSetDecl))
             {
-                set = GetAttributeFirstArgumentIntValue(resourceSetDecl);
+                set = GetAttributeArgumentIntValue(resourceSetDecl, 0);
             }
 
             int resourceBinding = GetAndIncrementBinding(set);
 
-            ResourceDefinition rd = new ResourceDefinition(resourceName, set, resourceBinding, tr, kind);
+            ResourceDefinition rd = new ResourceDefinition(resourceName, set, resourceBinding, valueType, kind);
             if (kind == ShaderResourceKind.Uniform)
             {
                 ValidateResourceType(typeInfo);
@@ -219,6 +269,15 @@ namespace ShaderGen
             foreach (LanguageBackend b in _backends) { b.AddResource(_shaderSet.Name, rd); }
         }
 
+        private TypeReference ParseStructuredBufferElementType(VariableDeclaratorSyntax vds)
+        {
+            FieldDeclarationSyntax fieldDecl = (FieldDeclarationSyntax)vds.Parent.Parent;
+            GenericNameSyntax gns = (GenericNameSyntax)fieldDecl.Declaration.Type;
+            TypeSyntax type = gns.TypeArgumentList.Arguments[0];
+            string fullName = Utilities.GetFullTypeName(GetModel(vds), type);
+            return new TypeReference(fullName);
+        }
+
         private int GetAndIncrementBinding(int set)
         {
             if (!_setCounts.TryGetValue(set, out int ret))
@@ -268,6 +327,14 @@ namespace ShaderGen
             {
                 return ShaderResourceKind.Sampler;
             }
+            else if (fullTypeName.Contains("ShaderGen.RWStructuredBuffer"))
+            {
+                return ShaderResourceKind.RWStructuredBuffer;
+            }
+            else if (fullTypeName.Contains("ShaderGen.StructuredBuffer"))
+            {
+                return ShaderResourceKind.StructuredBuffer;
+            }
             else
             {
                 return ShaderResourceKind.Uniform;

+ 1 - 0
src/ShaderGen/ShaderType.cs

@@ -5,5 +5,6 @@
         Normal = 0,
         VertexEntryPoint,
         FragmentEntryPoint,
+        ComputeEntryPoint,
     }
 }