TheGenerator.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. using System.Collections.Immutable;
  2. using System.Text;
  3. using Microsoft.CodeAnalysis;
  4. using Microsoft.CodeAnalysis.CSharp;
  5. using Microsoft.CodeAnalysis.CSharp.Syntax;
  6. namespace TerminalGuiFluentTestingXunit.Generator;
  /// <summary>
  ///     Incremental source generator that emits fluent <c>Assert*</c> extension methods on <c>GuiTestContext</c>
  ///     into the <c>TerminalGuiFluentTestingXunit.XunitContextExtensions</c> partial class, one wrapper per public
  ///     static overload of the corresponding <c>Xunit.Assert</c> method.
  /// </summary>
  /// <remarks>
  ///     Each wrapper calls the Xunit assertion. If the assertion throws, the wrapper calls
  ///     <c>context.HardStop (ex)</c> and rethrows; otherwise it returns <c>context</c> so calls can be chained.
  /// </remarks>
  [Generator]
  public class TheGenerator : IIncrementalGenerator
  {
      /// <summary>
      ///     The <c>Xunit.Assert</c> methods to wrap, each paired with whether the generated call must spell out the
      ///     method's type arguments (e.g. <c>Assert.Single&lt;T&gt; (...)</c>) instead of relying on type inference.
      /// </summary>
      private static readonly (string MethodName, bool InvokeTExplicitly) [] _assertMethods =
      {
          ("Equal", false),
          ("All", true),
          ("Collection", true),
          ("Contains", true),
          ("Distinct", true),
          ("DoesNotContain", true),
          ("DoesNotMatch", true),
          ("Empty", true),
          ("EndsWith", false),
          ("Equivalent", true),
          ("Fail", true),
          ("False", true),
          ("InRange", true),
          ("IsAssignableFrom", true),
          ("IsNotAssignableFrom", true),
          ("IsType", true),
          ("IsNotType", true),
          ("Matches", true),
          ("Multiple", true),
          ("NotEmpty", true),
          ("NotEqual", true),
          ("NotInRange", true),
          ("NotNull", false),
          ("NotSame", true),
          ("NotStrictEqual", true),
          ("Null", false),
          ("ProperSubset", true),
          ("ProperSuperset", true),
          ("Raises", true),
          ("RaisesAny", true),
          ("Same", true),
          ("Single", true),
          ("StartsWith", false),
          ("StrictEqual", true),
          ("Subset", true),
          ("Superset", true),
          // Throws and ThrowsAny are currently not generated.
          // ("Throws", true),
          // ("ThrowsAny", true),
          ("True", false)
      };

      /// <inheritdoc/>
      public void Initialize (IncrementalGeneratorInitializationContext context)
      {
          // Syntactic filter for class declarations named XunitContextExtensions.
          IncrementalValuesProvider<ClassDeclarationSyntax> provider = context.SyntaxProvider.CreateSyntaxProvider (
               static (node, _) => IsClass (node, "XunitContextExtensions"),
               static (ctx, _) => (ClassDeclarationSyntax)ctx.Node);

          // NOTE(review): Execute only reads the compilation (Left); the collected declarations (Right) are never
          // consulted, so sources are generated for every compilation this generator runs on, whether or not it
          // declares XunitContextExtensions. Confirm this is intended.
          IncrementalValueProvider<(Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right)> compilation =
              context.CompilationProvider.Combine (provider.Collect ());

          context.RegisterSourceOutput (compilation, Execute);
      }

      /// <summary>
      ///     Returns <see langword="true"/> when <paramref name="node"/> is a class declaration whose identifier is
      ///     <paramref name="named"/>.
      /// </summary>
      private static bool IsClass (SyntaxNode node, string named) => node is ClassDeclarationSyntax c && c.Identifier.Text == named;

      /// <summary>
      ///     Resolves <c>Xunit.Assert</c> from the compilation and emits one source file per entry in
      ///     <see cref="_assertMethods"/>.
      /// </summary>
      /// <param name="context">Receives the generated sources.</param>
      /// <param name="arg2">The compilation and the collected <c>XunitContextExtensions</c> declarations.</param>
      /// <exception cref="NotSupportedException">The compilation does not reference Xunit.</exception>
      private static void Execute (SourceProductionContext context, (Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right) arg2)
      {
          // Throwing aborts this generator's output for the compilation (Roslyn surfaces it as warning CS8785).
          INamedTypeSymbol assertType = arg2.Left.GetTypeByMetadataName ("Xunit.Assert")
                                        ?? throw new NotSupportedException ("Referencing codebase does not include Xunit, could not find Xunit.Assert");

          foreach ((string methodName, bool invokeTExplicitly) in _assertMethods)
          {
              context.CancellationToken.ThrowIfCancellationRequested ();
              GenerateMethods (assertType, context, methodName, invokeTExplicitly);
          }
      }

      /// <summary>
      ///     Generates the source file <c>XunitContextExtensions{methodName}.g.cs</c>. It contains one
      ///     <c>Assert{methodName}</c> extension method on <c>GuiTestContext</c> for every public static overload of
      ///     <c>Xunit.Assert.{methodName}</c>.
      /// </summary>
      /// <param name="assertType">The <c>Xunit.Assert</c> type resolved from the compilation.</param>
      /// <param name="context">Receives the generated source.</param>
      /// <param name="methodName">Name of the <c>Xunit.Assert</c> method to wrap.</param>
      /// <param name="invokeTExplicitly">
      ///     When <see langword="true"/>, generic overloads are invoked with explicit type arguments.
      /// </param>
      private static void GenerateMethods (INamedTypeSymbol assertType, SourceProductionContext context, string methodName, bool invokeTExplicitly)
      {
          // Only public static overloads can be called from the generated extension class.
          IEnumerable<IMethodSymbol> methods = assertType
                                               .GetMembers (methodName)
                                               .OfType<IMethodSymbol> ()
                                               .Where (m => m.IsStatic && m.DeclaredAccessibility == Accessibility.Public);

          // Guards against emitting the same wrapper twice, which would not compile.
          HashSet<string> signaturesDone = new ();

          var sb = new StringBuilder ();

          sb.AppendLine (
              """
              #nullable enable
              using TerminalGuiFluentTesting;
              using Xunit;
              namespace TerminalGuiFluentTestingXunit;
              public static partial class XunitContextExtensions
              {
              """);

          foreach (IMethodSymbol m in methods)
          {
              string signature = GetModifiedMethodSignature (m, methodName, invokeTExplicitly, out string [] arguments, out string typeParams);

              if (!signaturesDone.Add (signature))
              {
                  continue;
              }

              // Exception is fully qualified so the generated file compiles even when the consuming project has
              // no 'using System;' (e.g. implicit usings disabled).
              sb.AppendLine (
                  $$"""
                        {{signature}}
                        {
                            try
                            {
                                Assert.{{methodName}}{{typeParams}} ({{string.Join (", ", arguments)}});
                            }
                            catch (global::System.Exception ex)
                            {
                                context.HardStop (ex);
                                throw;
                            }
                            return context;
                        }
                    """);
          }

          sb.AppendLine ("}");

          context.AddSource ($"XunitContextExtensions{methodName}.g.cs", sb.ToString ());
      }

      /// <summary>
      ///     Builds the declaration header of the wrapper for <paramref name="methodSymbol"/>. The header reads
      ///     <c>public static GuiTestContext Assert{methodName}(this GuiTestContext context, ...)</c>, followed by
      ///     copies of the Xunit parameters, type parameters and constraint clauses.
      /// </summary>
      /// <param name="methodSymbol">The <c>Xunit.Assert</c> overload being wrapped.</param>
      /// <param name="methodName">The Xunit method name; the wrapper is named <c>Assert</c> + this value.</param>
      /// <param name="invokeTExplicitly">Whether <paramref name="typeParams"/> should carry explicit type arguments.</param>
      /// <param name="arguments">
      ///     The arguments used to forward the call to Xunit. Each is a parameter name, escaped with <c>@</c> when it
      ///     is a reserved keyword and prefixed with <c>ref</c>/<c>out</c>/<c>in</c> when the parameter has that
      ///     modifier.
      /// </param>
      /// <param name="typeParams">
      ///     <c>&lt;T1, T2&gt;</c> when <paramref name="invokeTExplicitly"/> is set and the method is generic;
      ///     otherwise empty.
      /// </param>
      /// <returns>The whitespace-normalized method header, without a body.</returns>
      private static string GetModifiedMethodSignature (
          IMethodSymbol methodSymbol,
          string methodName,
          bool invokeTExplicitly,
          out string [] arguments,
          out string typeParams
      )
      {
          typeParams = string.Empty;
          arguments = methodSymbol.Parameters.Select (p => CreateArgument (p)).ToArray ();

          // "this GuiTestContext context" comes first, followed by copies of the Xunit parameters.
          ParameterSyntax contextParam = SyntaxFactory.Parameter (SyntaxFactory.Identifier ("context"))
                                                      .WithType (SyntaxFactory.ParseTypeName ("GuiTestContext"))
                                                      .AddModifiers (SyntaxFactory.Token (SyntaxKind.ThisKeyword));

          var parameters = new List<ParameterSyntax> { contextParam };
          parameters.AddRange (methodSymbol.Parameters.Select (p => CreateParameter (p)));

          MethodDeclarationSyntax dec = SyntaxFactory
                                        .MethodDeclaration (
                                            SyntaxFactory.ParseTypeName ("GuiTestContext"),
                                            SyntaxFactory.Identifier ($"Assert{methodName}"))
                                        .WithModifiers (
                                            SyntaxFactory.TokenList (
                                                SyntaxFactory.Token (SyntaxKind.PublicKeyword),
                                                SyntaxFactory.Token (SyntaxKind.StaticKeyword)))
                                        .WithParameterList (SyntaxFactory.ParameterList (SyntaxFactory.SeparatedList (parameters)));

          if (methodSymbol.TypeParameters.Length > 0)
          {
              // Mirror <T, ...> so the wrapper stays generic.
              dec = dec.WithTypeParameterList (
                  SyntaxFactory.TypeParameterList (
                      SyntaxFactory.SeparatedList (
                          methodSymbol.TypeParameters.Select (tp => SyntaxFactory.TypeParameter (SyntaxFactory.Identifier (tp.Name))))));

              // Mirror the constraints; without them the forwarded call would fail Xunit's constraint checks.
              List<TypeParameterConstraintClauseSyntax> constraintClauses = new ();

              foreach (ITypeParameterSymbol tp in methodSymbol.TypeParameters)
              {
                  TypeParameterConstraintClauseSyntax? clause = CreateConstraintClause (tp);

                  if (clause is { })
                  {
                      constraintClauses.Add (clause);
                  }
              }

              if (constraintClauses.Count > 0)
              {
                  dec = dec.WithConstraintClauses (SyntaxFactory.List (constraintClauses));
              }

              if (invokeTExplicitly)
              {
                  typeParams = "<" + string.Join (", ", methodSymbol.TypeParameters.Select (tp => tp.Name)) + ">";
              }
          }

          return dec.NormalizeWhitespace ().ToString ();
      }

      /// <summary>
      ///     Builds a <c>where</c> clause that mirrors the constraints declared on <paramref name="tp"/>, or returns
      ///     <see langword="null"/> when it has none.
      /// </summary>
      /// <param name="tp">A type parameter of the Xunit overload being wrapped.</param>
      /// <returns>The constraint clause, or <see langword="null"/> for an unconstrained type parameter.</returns>
      private static TypeParameterConstraintClauseSyntax? CreateConstraintClause (ITypeParameterSymbol tp)
      {
          List<TypeParameterConstraintSyntax> constraints = new ();

          // At most one primary constraint, and it must come first. 'unmanaged' implies a value-type constraint,
          // so it is tested before 'struct'; otherwise the weaker 'struct' would be emitted.
          if (tp.HasUnmanagedTypeConstraint)
          {
              constraints.Add (SyntaxFactory.TypeConstraint (SyntaxFactory.IdentifierName ("unmanaged")));
          }
          else if (tp.HasReferenceTypeConstraint)
          {
              ClassOrStructConstraintSyntax classConstraint = SyntaxFactory.ClassOrStructConstraint (SyntaxKind.ClassConstraint);

              // Preserve 'class?' so nullable analysis matches Xunit's declaration.
              if (tp.ReferenceTypeConstraintNullableAnnotation == NullableAnnotation.Annotated)
              {
                  classConstraint = classConstraint.WithQuestionToken (SyntaxFactory.Token (SyntaxKind.QuestionToken));
              }

              constraints.Add (classConstraint);
          }
          else if (tp.HasValueTypeConstraint)
          {
              constraints.Add (SyntaxFactory.ClassOrStructConstraint (SyntaxKind.StructConstraint));
          }
          else if (tp.HasNotNullConstraint)
          {
              constraints.Add (SyntaxFactory.TypeConstraint (SyntaxFactory.IdentifierName ("notnull")));
          }

          // Base-class / interface constraints.
          foreach (ITypeSymbol constraintType in tp.ConstraintTypes)
          {
              constraints.Add (SyntaxFactory.TypeConstraint (SyntaxFactory.ParseTypeName (constraintType.ToDisplayString ())));
          }

          // new() must be the last constraint and is not allowed together with struct/unmanaged.
          if (tp.HasConstructorConstraint && !tp.HasValueTypeConstraint)
          {
              constraints.Add (SyntaxFactory.ConstructorConstraint ());
          }

          return constraints.Count == 0
                     ? null
                     : SyntaxFactory.TypeParameterConstraintClause (tp.Name).WithConstraints (SyntaxFactory.SeparatedList (constraints));
      }

      /// <summary>
      ///     Builds the argument that forwards <paramref name="p"/> to <c>Xunit.Assert</c>. It repeats the
      ///     parameter's <c>ref</c>/<c>out</c>/<c>in</c> modifier, because C# requires ref/out at the call site.
      /// </summary>
      /// <param name="p">A parameter of the Xunit overload being wrapped.</param>
      /// <returns>The argument text, e.g. <c>expected</c>, <c>@object</c> or <c>out value</c>.</returns>
      private static string CreateArgument (IParameterSymbol p)
      {
          string name = EscapeIdentifier (p.Name);

          return p.RefKind switch
                 {
                     RefKind.Ref => "ref " + name,
                     RefKind.Out => "out " + name,
                     RefKind.In => "in " + name,
                     _ => name
                 };
      }

      /// <summary>
      ///     Creates a <see cref="ParameterSyntax"/> that copies the Xunit parameter <paramref name="p"/>. The copy
      ///     keeps its type, its name (keyword-escaped if needed), its <c>params</c>/<c>ref</c>/<c>out</c>/<c>in</c>
      ///     modifier and its explicit default value.
      /// </summary>
      /// <param name="p">A parameter of the Xunit overload being wrapped.</param>
      /// <returns>The parameter declaration for the generated wrapper.</returns>
      /// <exception cref="NotSupportedException">The parameter uses a ref kind other than ref, out or in.</exception>
      private static ParameterSyntax CreateParameter (IParameterSymbol p)
      {
          ParameterSyntax parameterSyntax = SyntaxFactory.Parameter (SyntaxFactory.Identifier (EscapeIdentifier (p.Name)))
                                                         .WithType (SyntaxFactory.ParseTypeName (p.Type.ToDisplayString ()));

          var modifiers = new List<SyntaxToken> ();

          if (p.IsParams)
          {
              modifiers.Add (SyntaxFactory.Token (SyntaxKind.ParamsKeyword));
          }

          if (p.RefKind != RefKind.None)
          {
              SyntaxKind modifierKind = p.RefKind switch
                                        {
                                            RefKind.Ref => SyntaxKind.RefKeyword,
                                            RefKind.Out => SyntaxKind.OutKeyword,
                                            RefKind.In => SyntaxKind.InKeyword,
                                            _ => throw new NotSupportedException ($"Unsupported RefKind: {p.RefKind}")
                                        };
              modifiers.Add (SyntaxFactory.Token (modifierKind));
          }

          if (modifiers.Count > 0)
          {
              parameterSyntax = parameterSyntax.WithModifiers (SyntaxFactory.TokenList (modifiers));
          }

          if (p.HasExplicitDefaultValue)
          {
              parameterSyntax = parameterSyntax.WithDefault (SyntaxFactory.EqualsValueClause (CreateDefaultValue (p)));
          }

          return parameterSyntax;
      }

      /// <summary>
      ///     Converts the compile-time default value of <paramref name="p"/> into a C# expression that is valid for
      ///     the parameter's declared type.
      /// </summary>
      /// <param name="p">A parameter with <see cref="IParameterSymbol.HasExplicitDefaultValue"/> set.</param>
      /// <returns>The expression placed after <c>=</c> in the generated parameter.</returns>
      private static ExpressionSyntax CreateDefaultValue (IParameterSymbol p)
      {
          object? value = p.ExplicitDefaultValue;
          bool isNullableValueType = p.Type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;

          if (value is null)
          {
              // 'null' only converts to reference types and Nullable<T>. Other value types and unconstrained type
              // parameters (e.g. 'T value = default') need the 'default' literal.
              return p.Type.IsReferenceType || isNullableValueType
                         ? SyntaxFactory.LiteralExpression (SyntaxKind.NullLiteralExpression)
                         : SyntaxFactory.LiteralExpression (SyntaxKind.DefaultLiteralExpression, SyntaxFactory.Token (SyntaxKind.DefaultKeyword));
          }

          // Each primitive uses its dedicated SyntaxFactory.Literal overload, rather than ToString (), so the text
          // is a valid C# literal of that type.
          ExpressionSyntax literal = value switch
                                     {
                                         bool b => SyntaxFactory.LiteralExpression (
                                                                                    b
                                                                                        ? SyntaxKind.TrueLiteralExpression
                                                                                        : SyntaxKind.FalseLiteralExpression),
                                         string s => SyntaxFactory.LiteralExpression (SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal (s)),
                                         char c => SyntaxFactory.LiteralExpression (SyntaxKind.CharacterLiteralExpression, SyntaxFactory.Literal (c)),
                                         int i => NumericLiteral (SyntaxFactory.Literal (i)),
                                         uint ui => NumericLiteral (SyntaxFactory.Literal (ui)),
                                         long l => NumericLiteral (SyntaxFactory.Literal (l)),
                                         ulong ul => NumericLiteral (SyntaxFactory.Literal (ul)),
                                         float f => NumericLiteral (SyntaxFactory.Literal (f)),
                                         double d => NumericLiteral (SyntaxFactory.Literal (d)),
                                         decimal m => NumericLiteral (SyntaxFactory.Literal (m)),

                                         // These types have no literal syntax of their own; an in-range int constant
                                         // converts to them implicitly.
                                         byte or sbyte or short or ushort => NumericLiteral (SyntaxFactory.Literal (Convert.ToInt32 (value))),
                                         _ => SyntaxFactory.ParseExpression (value.ToString ()) // Fallback
                                     };

          // Enum defaults arrive as their underlying numeric value, and only the constant 0 converts to an enum
          // implicitly. Cast to the enum type, unwrapping Nullable<TEnum> first.
          ITypeSymbol valueType = isNullableValueType && p.Type is INamedTypeSymbol nullable
                                      ? nullable.TypeArguments [0]
                                      : p.Type;

          if (valueType.TypeKind == TypeKind.Enum)
          {
              return SyntaxFactory.CastExpression (
                                                   SyntaxFactory.ParseTypeName (valueType.ToDisplayString ()),
                                                   SyntaxFactory.ParenthesizedExpression (literal));
          }

          return literal;
      }

      /// <summary>Wraps a numeric literal token in a literal expression.</summary>
      private static ExpressionSyntax NumericLiteral (SyntaxToken token) =>
          SyntaxFactory.LiteralExpression (SyntaxKind.NumericLiteralExpression, token);

      /// <summary>
      ///     Returns <paramref name="name"/> prefixed with <c>@</c> when it is a reserved C# keyword (Xunit uses
      ///     parameter names such as <c>object</c>), so it can be used as an identifier in generated code.
      /// </summary>
      private static string EscapeIdentifier (string name) => IsReservedKeyword (name) ? "@" + name : name;

      /// <summary>
      ///     Returns <see langword="true"/> when <paramref name="name"/> is a reserved C# keyword. Contextual keywords
      ///     (e.g. <c>value</c>, <c>var</c>) are valid identifiers and are not reported.
      /// </summary>
      private static bool IsReservedKeyword (string name) => SyntaxFacts.GetKeywordKind (name) != SyntaxKind.None;
  }