using System.Collections.Immutable;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace TerminalGuiFluentTestingXunit.Generator;
/// <summary>
///     Incremental source generator that wraps selected <c>Xunit.Assert</c> methods as fluent extension methods on
///     <c>GuiTestContext</c>.
/// </summary>
/// <remarks>
///     For each configured method group, one file <c>XunitContextExtensions{MethodName}.g.cs</c> is added. It contains a
///     <c>public static partial class XunitContextExtensions</c> with one <c>Assert{MethodName}</c> wrapper per distinct
///     public overload. Each wrapper forwards its arguments to <c>Xunit.Assert</c>. If the assertion throws, the wrapper
///     calls <c>context.HardStop (ex)</c> and rethrows; otherwise it returns <c>context</c> so calls can be chained.
/// </remarks>
[Generator]
public class TheGenerator : IIncrementalGenerator
{
    /// <summary>
    ///     The <c>Xunit.Assert</c> method groups to wrap, in generation order. <c>InvokeTExplicitly</c> controls whether
    ///     the generated call spells out the wrapped method's type arguments (<c>Assert.X&lt;T&gt; (...)</c>) instead of
    ///     relying on type inference. It has no effect on non-generic overloads.
    /// </summary>
    private static readonly (string MethodName, bool InvokeTExplicitly) [] _assertMethods =
    {
        ("Equal", false),
        ("All", true),
        ("Collection", true),
        ("Contains", true),
        ("Distinct", true),
        ("DoesNotContain", true),
        ("DoesNotMatch", true),
        ("Empty", true),
        ("EndsWith", false),
        ("Equivalent", true),
        ("Fail", true),
        ("False", true),
        ("InRange", true),
        ("IsAssignableFrom", true),
        ("IsNotAssignableFrom", true),
        ("IsType", true),
        ("IsNotType", true),
        ("Matches", true),
        ("Multiple", true),
        ("NotEmpty", true),
        ("NotEqual", true),
        ("NotInRange", true),
        ("NotNull", false),
        ("NotSame", true),
        ("NotStrictEqual", true),
        ("Null", false),
        ("ProperSubset", true),
        ("ProperSuperset", true),
        ("Raises", true),
        ("RaisesAny", true),
        ("Same", true),
        ("Single", true),
        ("StartsWith", false),
        ("StrictEqual", true),
        ("Subset", true),
        ("Superset", true),
        // "Throws" and "ThrowsAny" are not wrapped.
        ("True", false)
    };

    /// <inheritdoc/>
    public void Initialize (IncrementalGeneratorInitializationContext context)
    {
        // Every class declaration named XunitContextExtensions in the compilation being processed.
        IncrementalValuesProvider<ClassDeclarationSyntax> provider = context.SyntaxProvider
                                                                            .CreateSyntaxProvider (
                                                                                                   static (node, _) => IsClass (node, "XunitContextExtensions"),
                                                                                                   static (ctx, _) => (ClassDeclarationSyntax)ctx.Node)
                                                                            .Where (m => m is { });

        // Execute reads the compilation (to resolve Xunit.Assert); the collected declarations are passed along but
        // are not currently inspected.
        IncrementalValueProvider<(Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right)> compilation =
            context.CompilationProvider.Combine (provider.Collect ());

        context.RegisterSourceOutput (compilation, Execute);
    }

    /// <summary>Returns <see langword="true"/> when <paramref name="node"/> is a class declared with the identifier <paramref name="named"/>.</summary>
    private static bool IsClass (SyntaxNode node, string named) { return node is ClassDeclarationSyntax c && c.Identifier.Text == named; }

    /// <summary>Generates one extension file per entry in <see cref="_assertMethods"/>.</summary>
    /// <exception cref="NotSupportedException">The compilation has no resolvable <c>Xunit.Assert</c> type.</exception>
    private void Execute (SourceProductionContext context, (Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right) arg2)
    {
        INamedTypeSymbol assertType = arg2.Left.GetTypeByMetadataName ("Xunit.Assert")
                                      ?? throw new NotSupportedException ("Referencing codebase does not include Xunit, could not find Xunit.Assert");

        foreach ((string methodName, bool invokeTExplicitly) in _assertMethods)
        {
            GenerateMethods (assertType, context, methodName, invokeTExplicitly);
        }
    }

    /// <summary>
    ///     Adds <c>XunitContextExtensions{methodName}.g.cs</c>, containing an <c>Assert{methodName}</c> wrapper for every
    ///     public overload of <c>Xunit.Assert.{methodName}</c>. Overloads whose generated signatures render identically
    ///     are emitted only once.
    /// </summary>
    /// <param name="assertType">The resolved <c>Xunit.Assert</c> type.</param>
    /// <param name="context">Receives the generated source.</param>
    /// <param name="methodName">The <c>Xunit.Assert</c> method group to wrap.</param>
    /// <param name="invokeTExplicitly">Whether generic calls pass their type arguments explicitly.</param>
    private void GenerateMethods (INamedTypeSymbol assertType, SourceProductionContext context, string methodName, bool invokeTExplicitly)
    {
        var sb = new StringBuilder ();
        HashSet<string> signaturesDone = new ();

        // Only public overloads can be called from the generated wrappers.
        List<IMethodSymbol> methods = assertType
                                      .GetMembers (methodName)
                                      .OfType<IMethodSymbol> ()
                                      .Where (m => m.DeclaredAccessibility == Accessibility.Public)
                                      .ToList ();

        const string header = """
            #nullable enable
            using TerminalGuiFluentTesting;
            using Xunit;

            namespace TerminalGuiFluentTestingXunit;

            public static partial class XunitContextExtensions
            {
            """;

        sb.AppendLine (header);

        foreach (IMethodSymbol m in methods)
        {
            string signature = GetModifiedMethodSignature (m, methodName, invokeTExplicitly, out string [] arguments, out string typeParams);

            if (!signaturesDone.Add (signature))
            {
                continue;
            }

            // Exception is fully qualified so the generated file does not depend on "using System;" being in scope.
            string method = $$"""
                {{signature}}
                {
                    try
                    {
                        Assert.{{methodName}}{{typeParams}} ({{string.Join (", ", arguments)}});
                    }
                    catch (global::System.Exception ex)
                    {
                        context.HardStop (ex);
                        throw;
                    }

                    return context;
                }
                """;

            sb.AppendLine (method);
        }

        sb.AppendLine ("}");

        context.AddSource ($"XunitContextExtensions{methodName}.g.cs", sb.ToString ());
    }

    /// <summary>
    ///     Builds the declaration (without body) of the wrapper for <paramref name="methodSymbol"/>. The wrapper is
    ///     <c>public static GuiTestContext Assert{methodName}(this GuiTestContext context, ...)</c> and mirrors the wrapped
    ///     method's parameters, generic parameters and constraints.
    /// </summary>
    /// <param name="methodSymbol">The <c>Xunit.Assert</c> overload being wrapped.</param>
    /// <param name="methodName">The wrapped method's name; the wrapper is named <c>Assert{methodName}</c>.</param>
    /// <param name="invokeTExplicitly">When true and the method is generic, <paramref name="typeParams"/> receives <c>&lt;T1, T2&gt;</c>.</param>
    /// <param name="arguments">
    ///     The call-site arguments for the wrapped method: escaped parameter names, prefixed with <c>ref</c>/<c>out</c>/<c>in</c>
    ///     where the parameter declares one.
    /// </param>
    /// <param name="typeParams">The explicit type-argument list for the call, or an empty string.</param>
    /// <returns>The whitespace-normalized declaration text.</returns>
    private string GetModifiedMethodSignature (
        IMethodSymbol methodSymbol,
        string methodName,
        bool invokeTExplicitly,
        out string [] arguments,
        out string typeParams
    )
    {
        typeParams = string.Empty;
        arguments = methodSymbol.Parameters.Select (FormatArgument).ToArray ();

        // "this GuiTestContext context" goes first so the wrapper is an extension method on GuiTestContext.
        ParameterSyntax contextParam = SyntaxFactory.Parameter (SyntaxFactory.Identifier ("context"))
                                                    .WithType (SyntaxFactory.ParseTypeName ("GuiTestContext"))
                                                    .AddModifiers (SyntaxFactory.Token (SyntaxKind.ThisKeyword));

        List<ParameterSyntax> parameters = methodSymbol.Parameters.Select (CreateParameter).ToList ();
        parameters.Insert (0, contextParam);

        MethodDeclarationSyntax dec = SyntaxFactory.MethodDeclaration (
                                                                       SyntaxFactory.ParseTypeName ("GuiTestContext"),
                                                                       SyntaxFactory.Identifier ($"Assert{methodName}"))
                                                   .WithModifiers (
                                                                   SyntaxFactory.TokenList (
                                                                                            SyntaxFactory.Token (SyntaxKind.PublicKeyword),
                                                                                            SyntaxFactory.Token (SyntaxKind.StaticKeyword)))
                                                   .WithParameterList (SyntaxFactory.ParameterList (SyntaxFactory.SeparatedList (parameters)));

        if (methodSymbol.TypeParameters.Length > 0)
        {
            // Mirror the wrapped method's generic parameters (e.g. <T>) ...
            TypeParameterSyntax [] typeParameters = methodSymbol.TypeParameters
                                                                .Select (tp => SyntaxFactory.TypeParameter (SyntaxFactory.Identifier (tp.Name)))
                                                                .ToArray ();

            dec = dec.WithTypeParameterList (SyntaxFactory.TypeParameterList (SyntaxFactory.SeparatedList (typeParameters)));

            // ... and their constraints, so the forwarded call satisfies them.
            List<TypeParameterConstraintClauseSyntax> constraintClauses = methodSymbol.TypeParameters
                                                                                      .Select (CreateConstraintClause)
                                                                                      .OfType<TypeParameterConstraintClauseSyntax> ()
                                                                                      .ToList ();

            if (constraintClauses.Count > 0)
            {
                dec = dec.WithConstraintClauses (SyntaxFactory.List (constraintClauses));
            }

            if (invokeTExplicitly)
            {
                typeParams = "<" + string.Join (", ", methodSymbol.TypeParameters.Select (tp => tp.Name)) + ">";
            }
        }

        return dec.NormalizeWhitespace ().ToString ();
    }

    /// <summary>
    ///     Builds the <c>where</c> clause for <paramref name="tp"/> from its declared constraints, or returns
    ///     <see langword="null"/> when it has none.
    /// </summary>
    private static TypeParameterConstraintClauseSyntax? CreateConstraintClause (ITypeParameterSymbol tp)
    {
        List<TypeParameterConstraintSyntax> constraints = new ();

        // The primary constraint must come first in a C# constraint list.
        if (tp.HasReferenceTypeConstraint)
        {
            constraints.Add (SyntaxFactory.ClassOrStructConstraint (SyntaxKind.ClassConstraint));
        }
        else if (tp.HasUnmanagedTypeConstraint)
        {
            // Checked before the struct case: unmanaged implies a value type, and emitting "struct" instead would not
            // satisfy the wrapped method's unmanaged constraint.
            constraints.Add (SyntaxFactory.TypeConstraint (SyntaxFactory.IdentifierName ("unmanaged")));
        }
        else if (tp.HasValueTypeConstraint)
        {
            constraints.Add (SyntaxFactory.ClassOrStructConstraint (SyntaxKind.StructConstraint));
        }
        else if (tp.HasNotNullConstraint)
        {
            constraints.Add (SyntaxFactory.TypeConstraint (SyntaxFactory.IdentifierName ("notnull")));
        }

        // Base-class / interface constraints.
        foreach (ITypeSymbol constraintType in tp.ConstraintTypes)
        {
            constraints.Add (SyntaxFactory.TypeConstraint (SyntaxFactory.ParseTypeName (constraintType.ToDisplayString ())));
        }

        // new() must be the last constraint.
        if (tp.HasConstructorConstraint)
        {
            constraints.Add (SyntaxFactory.ConstructorConstraint ());
        }

        return constraints.Count == 0
                   ? null
                   : SyntaxFactory.TypeParameterConstraintClause (tp.Name)
                                  .WithConstraints (SyntaxFactory.SeparatedList (constraints));
    }

    /// <summary>
    ///     Creates a <see cref="ParameterSyntax"/> that mirrors parameter <paramref name="p"/> of a discovered xunit method:
    ///     the same name (escaped with <c>@</c> if it is a keyword), type, <c>params</c>/<c>ref</c>/<c>out</c>/<c>in</c>
    ///     modifiers and default value.
    /// </summary>
    /// <param name="p">The parameter of the wrapped xunit method.</param>
    /// <returns>The parameter declaration for the wrapper.</returns>
    /// <exception cref="NotSupportedException">The parameter uses a <see cref="RefKind"/> other than ref/out/in.</exception>
    private ParameterSyntax CreateParameter (IParameterSymbol p)
    {
        ParameterSyntax parameterSyntax = SyntaxFactory.Parameter (SyntaxFactory.Identifier (EscapeIdentifier (p.Name)))
                                                       .WithType (SyntaxFactory.ParseTypeName (p.Type.ToDisplayString ()));

        List<SyntaxToken> modifiers = new ();

        if (p.IsParams)
        {
            modifiers.Add (SyntaxFactory.Token (SyntaxKind.ParamsKeyword));
        }

        if (p.RefKind != RefKind.None)
        {
            modifiers.Add (SyntaxFactory.Token (GetRefKindKeyword (p.RefKind)));
        }

        if (modifiers.Count > 0)
        {
            parameterSyntax = parameterSyntax.WithModifiers (SyntaxFactory.TokenList (modifiers));
        }

        if (p.HasExplicitDefaultValue)
        {
            parameterSyntax = parameterSyntax.WithDefault (SyntaxFactory.EqualsValueClause (CreateDefaultValueExpression (p)));
        }

        return parameterSyntax;
    }

    /// <summary>
    ///     Renders the explicit default value of <paramref name="p"/> as a C# expression that compiles for the
    ///     parameter's type.
    /// </summary>
    private static ExpressionSyntax CreateDefaultValueExpression (IParameterSymbol p)
    {
        object? value = p.ExplicitDefaultValue;
        bool isNullableValueType = p.Type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;

        if (value is null)
        {
            // A "= default" on a struct or an unconstrained T? also surfaces as null, and a null literal would not
            // compile there; "default" is valid for every type.
            return p.Type.IsReferenceType || isNullableValueType
                       ? SyntaxFactory.LiteralExpression (SyntaxKind.NullLiteralExpression)
                       : SyntaxFactory.LiteralExpression (SyntaxKind.DefaultLiteralExpression);
        }

        // The typed Literal overloads format invariantly and add the type suffix (L, U, F, M, ...) where needed.
        ExpressionSyntax literal = value switch
                                   {
                                       bool b => SyntaxFactory.LiteralExpression (
                                                                                  b
                                                                                      ? SyntaxKind.TrueLiteralExpression
                                                                                      : SyntaxKind.FalseLiteralExpression),
                                       string s => SyntaxFactory.LiteralExpression (SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal (s)),
                                       char c => SyntaxFactory.LiteralExpression (SyntaxKind.CharacterLiteralExpression, SyntaxFactory.Literal (c)),
                                       int i => Numeric (SyntaxFactory.Literal (i)),
                                       uint ui => Numeric (SyntaxFactory.Literal (ui)),
                                       long l => Numeric (SyntaxFactory.Literal (l)),
                                       ulong ul => Numeric (SyntaxFactory.Literal (ul)),
                                       sbyte sb => Numeric (SyntaxFactory.Literal ((int)sb)),
                                       byte by => Numeric (SyntaxFactory.Literal ((int)by)),
                                       short sh => Numeric (SyntaxFactory.Literal ((int)sh)),
                                       ushort us => Numeric (SyntaxFactory.Literal ((int)us)),
                                       float f => Numeric (SyntaxFactory.Literal (f)),
                                       double d => Numeric (SyntaxFactory.Literal (d)),
                                       decimal m => Numeric (SyntaxFactory.Literal (m)),
                                       _ => SyntaxFactory.ParseExpression (value.ToString ()) // Fallback
                                   };

        // Enum defaults (including for E?) surface as the underlying integral value; cast back to the enum type so
        // non-zero values compile, e.g. "(System.StringComparison)(4)".
        ITypeSymbol valueType = isNullableValueType ? ((INamedTypeSymbol)p.Type).TypeArguments [0] : p.Type;

        if (valueType.TypeKind == TypeKind.Enum)
        {
            return SyntaxFactory.CastExpression (
                                                 SyntaxFactory.ParseTypeName (valueType.ToDisplayString ()),
                                                 SyntaxFactory.ParenthesizedExpression (literal));
        }

        return literal;

        static LiteralExpressionSyntax Numeric (SyntaxToken token)
        {
            return SyntaxFactory.LiteralExpression (SyntaxKind.NumericLiteralExpression, token);
        }
    }

    /// <summary>
    ///     Formats <paramref name="p"/> as a call-site argument: its escaped name, prefixed with <c>ref</c>/<c>out</c>/<c>in</c>
    ///     when the parameter declares one, so the forwarded call matches the wrapper's declaration.
    /// </summary>
    /// <exception cref="NotSupportedException">The parameter uses a <see cref="RefKind"/> other than ref/out/in.</exception>
    private static string FormatArgument (IParameterSymbol p)
    {
        string name = EscapeIdentifier (p.Name);

        return p.RefKind == RefKind.None
                   ? name
                   : $"{SyntaxFacts.GetText (GetRefKindKeyword (p.RefKind))} {name}";
    }

    /// <summary>Maps a parameter <see cref="RefKind"/> to the keyword that declares or passes it.</summary>
    /// <exception cref="NotSupportedException"><paramref name="refKind"/> is not ref, out or in.</exception>
    private static SyntaxKind GetRefKindKeyword (RefKind refKind)
    {
        return refKind switch
               {
                   RefKind.Ref => SyntaxKind.RefKeyword,
                   RefKind.Out => SyntaxKind.OutKeyword,
                   RefKind.In => SyntaxKind.InKeyword,
                   _ => throw new NotSupportedException ($"Unsupported RefKind: {refKind}")
               };
    }

    /// <summary>Prefixes <paramref name="name"/> with <c>@</c> when it is a reserved C# keyword (e.g. <c>object</c> → <c>@object</c>).</summary>
    private static string EscapeIdentifier (string name) { return IsReservedKeyword (name) ? "@" + name : name; }

    /// <summary>Returns <see langword="true"/> when <paramref name="name"/> is a reserved (non-contextual) C# keyword.</summary>
    private static bool IsReservedKeyword (string name) { return SyntaxFacts.IsReservedKeyword (SyntaxFacts.GetKeywordKind (name)); }
}