using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Lua.SourceGenerator;

partial class LuaObjectGenerator
{
    /// <summary>
    /// Returns the opening text of the generated expression that turns a value of
    /// <paramref name="typeSymbol"/> into a Lua value. Callers append the value expression
    /// followed by a closing ")".
    /// </summary>
    /// <returns>
    /// <c>"global::Lua.LuaValue.FromUserData("</c> when a conversion from <paramref name="typeSymbol"/>
    /// to <c>references.LuaUserData</c> exists; otherwise a bare <c>"("</c>, leaving the value as-is
    /// (presumably converted implicitly by the surrounding generated expression — confirm).
    /// </returns>
    static string GetLuaValuePrefix(ITypeSymbol typeSymbol, SymbolReferences references, Compilation compilation)
    {
        var convertsToUserData = compilation.ClassifyCommonConversion(typeSymbol, references.LuaUserData).Exists;
        return convertsToUserData ? "global::Lua.LuaValue.FromUserData(" : "(";
    }

    /// <summary>
    /// Validates <paramref name="typeMetadata"/> and, when it is valid, writes a partial declaration of
    /// the type implementing <c>global::Lua.ILuaUserData</c> into <paramref name="builder"/>: the
    /// member/metamethod functions, the index and newindex metamethods, the metatable property and an
    /// implicit conversion to <c>global::Lua.LuaValue</c>.
    /// </summary>
    /// <param name="metaDict">Named types whose presence as a key makes them acceptable member types
    /// (forwarded to <see cref="ValidateMembers"/>).</param>
    /// <returns>
    /// <c>true</c> when code was emitted; <c>false</c> when validation reported a diagnostic, an emit
    /// step returned <c>false</c>, or an exception was thrown (the builder may then hold partial output).
    /// </returns>
    static bool TryEmit(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation, in SourceProductionContext context, Dictionary metaDict)
    {
        try
        {
            var error = false;

            // must be partial
            if (!typeMetadata.IsPartial())
            {
                context.ReportDiagnostic(Diagnostic.Create(
                    DiagnosticDescriptors.MustBePartial,
                    typeMetadata.Syntax.Identifier.GetLocation(),
                    typeMetadata.Symbol.Name));
                error = true;
            }

            // nested is not allowed
            if (typeMetadata.IsNested())
            {
                context.ReportDiagnostic(Diagnostic.Create(
                    DiagnosticDescriptors.NestedNotAllowed,
                    typeMetadata.Syntax.Identifier.GetLocation(),
                    typeMetadata.Symbol.Name));
                error = true;
            }

            // verify abstract/interface
            if (typeMetadata.Symbol.IsAbstract)
            {
                context.ReportDiagnostic(Diagnostic.Create(
                    DiagnosticDescriptors.AbstractNotAllowed,
                    typeMetadata.Syntax.Identifier.GetLocation(),
                    typeMetadata.TypeName));
                error = true;
            }

            // Member validation runs even if a check above failed, so all problems are reported together.
            if (!ValidateMembers(typeMetadata, compilation, references, context, metaDict))
            {
                error = true;
            }

            if (error)
            {
                return false;
            }

            // NOTE(review): this header line looks truncated (presumably "// <auto-generated/>"); confirm.
            builder.AppendLine("// ");
            builder.AppendLine("#nullable enable");
            builder.AppendLine("#pragma warning disable CS0162 // Unreachable code");
            builder.AppendLine("#pragma warning disable CS0219 // Variable assigned but never used");
            builder.AppendLine("#pragma warning disable CS8600 // Converting null literal or possible null value to non-nullable type.");
            builder.AppendLine("#pragma warning disable CS8601 // Possible null reference assignment");
            builder.AppendLine("#pragma warning disable CS8602 // Possible null return");
            builder.AppendLine("#pragma warning disable CS8604 // Possible null reference argument for parameter");
            builder.AppendLine("#pragma warning disable CS8631 // The type cannot be used as type parameter in the generic type or method");
            builder.AppendLine();

            var ns = typeMetadata.Symbol.ContainingNamespace;
            if (!ns.IsGlobalNamespace)
            {
                builder.AppendLine($"namespace {ns}");
                builder.BeginBlock();
            }

            var typeDeclarationKeyword = (typeMetadata.Symbol.IsRecord, typeMetadata.Symbol.IsValueType) switch
            {
                (true, true) => "record struct",
                (true, false) => "record",
                (false, true) => "struct",
                (false, false) => "class"
            };

            // A using block (not a using declaration) closes the type body here, before the namespace
            // block and the trailing pragmas, so blocks are closed in reverse order of opening.
            using (builder.BeginBlockScope($"partial {typeDeclarationKeyword} {typeMetadata.TypeName} : global::Lua.ILuaUserData"))
            {
                // Filled by TryEmitMethods, then consumed by TryEmitMetatable.
                var metamethodSet = new HashSet();

                if (!TryEmitMethods(typeMetadata, builder, references, compilation, metamethodSet, context))
                {
                    return false;
                }

                if (!TryEmitIndexMetamethod(typeMetadata, builder, references, compilation, context))
                {
                    return false;
                }

                if (!TryEmitNewIndexMetamethod(typeMetadata, builder, references, context))
                {
                    return false;
                }

                if (!TryEmitMetatable(builder, metamethodSet, context))
                {
                    return false;
                }

                // implicit operator
                builder.AppendLine($"public static implicit operator global::Lua.LuaValue({typeMetadata.FullTypeName} value)");
                using (builder.BeginBlockScope())
                {
                    builder.AppendLine("return global::Lua.LuaValue.FromUserData(value);");
                }
            }

            if (!ns.IsGlobalNamespace)
            {
                builder.EndBlock();
            }

            builder.AppendLine("#pragma warning restore CS0162 // Unreachable code");
            builder.AppendLine("#pragma warning restore CS0219 // Variable assigned but never used");
            builder.AppendLine("#pragma warning restore CS8600 // Converting null literal or possible null value to non-nullable type.");
            builder.AppendLine("#pragma warning restore CS8601 // Possible null reference assignment");
            builder.AppendLine("#pragma warning restore CS8602 // Possible null return");
            builder.AppendLine("#pragma warning restore CS8604 // Possible null reference argument for parameter");
            builder.AppendLine("#pragma warning restore CS8631 // The type cannot be used as type parameter in the generic type or method");
            return true;
        }
        catch (Exception)
        {
            // NOTE(review): every exception is swallowed, so a generator bug silently produces no source
            // for this type; consider surfacing it as a diagnostic.
            return false;
        }
    }

    /// <summary>
    /// Checks every property type, method return type and method parameter type on
    /// <paramref name="typeMetadata"/>. Reports a diagnostic at the offending member for each type that
    /// cannot cross the Lua boundary.
    /// </summary>
    /// <returns><c>true</c> when no diagnostic was reported.</returns>
    static bool ValidateMembers(TypeMetadata typeMetadata, Compilation compilation, SymbolReferences references, in SourceProductionContext context, Dictionary metaDict)
    {
        var isValid = true;

        foreach (var property in typeMetadata.Properties)
        {
            if (IsSupportedLuaType(property.Type))
            {
                continue;
            }

            context.ReportDiagnostic(Diagnostic.Create(
                DiagnosticDescriptors.InvalidPropertyType,
                property.Symbol.Locations.FirstOrDefault(),
                property.Type.Name));
            isValid = false;
        }

        foreach (var method in typeMetadata.Methods)
        {
            if (!method.Symbol.ReturnsVoid)
            {
                var returnType = method.Symbol.ReturnType;
                var hasValueToCheck = true;

                if (method.IsAsync)
                {
                    // For async methods the awaited result is what reaches Lua; a return type without
                    // type arguments carries no result, so there is nothing to validate.
                    var namedType = (INamedTypeSymbol)returnType;
                    if (namedType.TypeArguments.Length == 0)
                    {
                        hasValueToCheck = false;
                    }
                    else
                    {
                        returnType = namedType.TypeArguments[0];
                    }
                }

                if (hasValueToCheck && !IsSupportedLuaType(returnType))
                {
                    // Report at the method itself: the return type's own declaration may be in another
                    // file, in metadata, or have no location at all.
                    context.ReportDiagnostic(Diagnostic.Create(
                        DiagnosticDescriptors.InvalidReturnType,
                        method.Symbol.Locations.FirstOrDefault(),
                        returnType.Name));
                    isValid = false;
                }
            }

            var parameters = method.Symbol.Parameters;
            for (var index = 0; index < parameters.Length; index++)
            {
                var parameterSymbol = parameters[index];
                var typeSymbol = parameterSymbol.Type;

                // A trailing CancellationToken is not read from Lua; EmitMethodFunction passes `ct` for it.
                if (index == parameters.Length - 1 && SymbolEqualityComparer.Default.Equals(typeSymbol, references.CancellationToken))
                {
                    continue;
                }

                if (IsSupportedLuaType(typeSymbol))
                {
                    continue;
                }

                context.ReportDiagnostic(Diagnostic.Create(
                    DiagnosticDescriptors.InvalidParameterType,
                    parameterSymbol.Locations.FirstOrDefault(),
                    typeSymbol.Name));
                isValid = false;
            }
        }

        return isValid;

        // A type is accepted when it is LuaValue, LuaUserData or the declaring type itself, converts to
        // LuaUserData or LuaValue, or is a named type registered in metaDict.
        bool IsSupportedLuaType(ITypeSymbol type)
        {
            if (SymbolEqualityComparer.Default.Equals(type, references.LuaValue)) return true;
            if (SymbolEqualityComparer.Default.Equals(type, references.LuaUserData)) return true;
            if (SymbolEqualityComparer.Default.Equals(type, typeMetadata.Symbol)) return true;
            if (compilation.ClassifyConversion(type, references.LuaUserData).Exists) return true;
            if (compilation.ClassifyConversion(type, references.LuaValue).Exists) return true;
            return type is INamedTypeSymbol namedTypeSymbol && metaDict.ContainsKey(namedTypeSymbol);
        }
    }

    /// <summary>
    /// Emits the static <c>__metamethod_index</c> LuaFunction. It reads the receiver (argument 0) and the
    /// key (argument 1), and returns the matching property value or member function, or
    /// <c>LuaValue.Nil</c> for unknown keys.
    /// </summary>
    /// <returns>Always <c>true</c>.</returns>
    static bool TryEmitIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation, in SourceProductionContext context)
    {
        builder.AppendLine(@"static readonly global::Lua.LuaFunction __metamethod_index = new global::Lua.LuaFunction(""index"", (context, ct) =>");
        using (builder.BeginBlockScope())
        {
            builder.AppendLine($"var userData = context.GetArgument<{typeMetadata.FullTypeName}>(0);");
            builder.AppendLine($"var key = context.GetArgument(1);");
            builder.AppendLine("var result = key switch");
            using (builder.BeginBlockScope())
            {
                foreach (var propertyMetadata in typeMetadata.Properties)
                {
                    var conversionPrefix = GetLuaValuePrefix(propertyMetadata.Type, references, compilation);
                    // Static properties are read through the type name, instance ones through the receiver.
                    var owner = propertyMetadata.IsStatic ? typeMetadata.FullTypeName : "userData";
                    builder.AppendLine(@$"""{propertyMetadata.LuaMemberName}"" => {conversionPrefix}{owner}.{propertyMetadata.Symbol.Name}),");
                }

                foreach (var methodMetadata in typeMetadata.Methods.Where(x => x.HasMemberAttribute))
                {
                    builder.AppendLine(@$"""{methodMetadata.LuaMemberName}"" => new global::Lua.LuaValue(__function_{methodMetadata.LuaMemberName}),");
                }

                builder.AppendLine(@$"_ => global::Lua.LuaValue.Nil,");
            }
            builder.AppendLine(";");
            builder.AppendLine("return new global::System.Threading.Tasks.ValueTask(context.Return(result));");
        }
        builder.AppendLine(");");
        return true;
    }

    /// <summary>
    /// Emits the static <c>__metamethod_newindex</c> LuaFunction. It assigns argument 2 to the writable
    /// property named by the key (argument 1). For read-only properties and member functions it throws
    /// a "cannot overwrite" <c>LuaRuntimeException</c>; for unknown keys it throws "not found".
    /// </summary>
    /// <returns>Always <c>true</c>.</returns>
    static bool TryEmitNewIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, in SourceProductionContext context)
    {
        builder.AppendLine(@"static readonly global::Lua.LuaFunction __metamethod_newindex = new global::Lua.LuaFunction(""newindex"", (context, ct) =>");
        using (builder.BeginBlockScope())
        {
            builder.AppendLine($"var userData = context.GetArgument<{typeMetadata.FullTypeName}>(0);");
            builder.AppendLine($"var key = context.GetArgument(1);");
            builder.AppendLine("switch (key)");
            using (builder.BeginBlockScope())
            {
                foreach (var propertyMetadata in typeMetadata.Properties)
                {
                    builder.AppendLine(@$"case ""{propertyMetadata.LuaMemberName}"":");
                    using (builder.BeginIndentScope())
                    {
                        if (propertyMetadata.IsReadOnly)
                        {
                            builder.AppendLine($@"throw new global::Lua.LuaRuntimeException(context.State, $""'{{key}}' cannot overwrite."");");
                            continue;
                        }

                        var owner = propertyMetadata.IsStatic ? typeMetadata.FullTypeName : "userData";
                        // LuaValue-typed properties use the non-generic GetArgument; others the typed overload.
                        var newValue = SymbolEqualityComparer.Default.Equals(propertyMetadata.Type, references.LuaValue)
                            ? "context.GetArgument(2)"
                            : $"context.GetArgument<{propertyMetadata.TypeFullName}>(2)";
                        builder.AppendLine($"{owner}.{propertyMetadata.Symbol.Name} = {newValue};");
                        builder.AppendLine("break;");
                    }
                }

                foreach (var methodMetadata in typeMetadata.Methods.Where(x => x.HasMemberAttribute))
                {
                    builder.AppendLine(@$"case ""{methodMetadata.LuaMemberName}"":");
                    using (builder.BeginIndentScope())
                    {
                        builder.AppendLine($@"throw new global::Lua.LuaRuntimeException(context.State, $""'{{key}}' cannot overwrite."");");
                    }
                }

                builder.AppendLine(@$"default:");
                using (builder.BeginIndentScope())
                {
                    builder.AppendLine(@$"throw new global::Lua.LuaRuntimeException(context.State, $""'{{key}}' not found."");");
                }
            }
            builder.AppendLine("return new global::System.Threading.Tasks.ValueTask(context.Return());");
        }
        builder.AppendLine(");");
        return true;
    }

    /// <summary>
    /// Emits one LuaFunction per method. It emits <c>__function_{LuaMemberName}</c> for methods with a
    /// member attribute and <c>__metamethod_{Metamethod}</c> for methods with a metamethod attribute.
    /// When a method has both, the metamethod property aliases the member function. Each metamethod is
    /// recorded in <paramref name="metamethodSet"/>. A duplicate is reported as
    /// <c>DuplicateMetamethod</c> and skipped.
    /// </summary>
    /// <returns>Always <c>true</c>.</returns>
    static bool TryEmitMethods(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation, HashSet metamethodSet, in SourceProductionContext context)
    {
        builder.AppendLine();

        foreach (var methodMetadata in typeMetadata.Methods)
        {
            string? functionName = null;

            if (methodMetadata.HasMemberAttribute)
            {
                functionName = $"__function_{methodMetadata.LuaMemberName}";
                EmitMethodFunction(functionName, methodMetadata.LuaMemberName, typeMetadata, methodMetadata, builder, references, compilation);
            }

            if (!methodMetadata.HasMetamethodAttribute)
            {
                continue;
            }

            if (!metamethodSet.Add(methodMetadata.Metamethod))
            {
                context.ReportDiagnostic(Diagnostic.Create(
                    DiagnosticDescriptors.DuplicateMetamethod,
                    methodMetadata.Symbol.Locations.FirstOrDefault(),
                    typeMetadata.TypeName,
                    methodMetadata.Metamethod));
                continue;
            }

            if (functionName == null)
            {
                // Invariant casing keeps the emitted chunk name independent of the build machine's
                // culture (e.g. 'I' lowercases to dotless 'ı' under tr-TR).
                var chunkName = methodMetadata.Metamethod.ToString().ToLowerInvariant();
                EmitMethodFunction($"__metamethod_{methodMetadata.Metamethod}", chunkName, typeMetadata, methodMetadata, builder, references, compilation);
            }
            else
            {
                builder.AppendLine($"static global::Lua.LuaFunction __metamethod_{methodMetadata.Metamethod} => {functionName};");
            }
        }

        return true;
    }

    /// <summary>
    /// Emits a <c>static readonly</c> LuaFunction field named <paramref name="functionName"/> that wraps
    /// the method described by <paramref name="methodMetadata"/>.
    /// </summary>
    /// <remarks>
    /// The generated body does the following:
    /// <list type="bullet">
    /// <item>Instance methods read the receiver from argument 0.</item>
    /// <item>Each parameter is read into <c>arg{i}</c>. Parameters with an explicit default value use
    /// their source default when the argument is absent.</item>
    /// <item>A trailing CancellationToken parameter receives <c>ct</c>.</item>
    /// <item>The method is called (awaited when async) and its result is returned via
    /// <c>context.Return</c>.</item>
    /// </list>
    /// </remarks>
    /// <param name="chunkName">Name string passed to the generated LuaFunction constructor.</param>
    static void EmitMethodFunction(string functionName, string chunkName, TypeMetadata typeMetadata, MethodMetadata methodMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation)
    {
        builder.AppendLine($@"static readonly global::Lua.LuaFunction {functionName} = new global::Lua.LuaFunction(""{chunkName}"", {(methodMetadata.IsAsync ? "async" : "")} (context, ct) =>");
        using (builder.BeginBlockScope())
        {
            // Lua argument slot of the first declared parameter: 0 for static methods, 1 when slot 0
            // holds the receiver.
            var firstArgumentIndex = methodMetadata.IsStatic ? 0 : 1;
            var index = firstArgumentIndex;

            if (!methodMetadata.IsStatic)
            {
                builder.AppendLine($"var userData = context.GetArgument<{typeMetadata.FullTypeName}>(0);");
            }

            var hasCancellationToken = false;
            var parameters = methodMetadata.Symbol.Parameters;
            for (var i = 0; i < parameters.Length; i++)
            {
                var parameter = parameters[i];
                var parameterType = parameter.Type;

                if (i == parameters.Length - 1 && SymbolEqualityComparer.Default.Equals(parameterType, references.CancellationToken))
                {
                    hasCancellationToken = true;
                    break;
                }

                var readArgument = SymbolEqualityComparer.Default.Equals(parameterType, references.LuaValue)
                    ? $"context.GetArgument({index})"
                    : $"context.GetArgument<{parameterType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>({index})";

                if (parameter.HasExplicitDefaultValue)
                {
                    var syntax = (ParameterSyntax)parameter.DeclaringSyntaxReferences[0].GetSyntax();
                    // ToString (not ToFullString) drops surrounding trivia: a trailing "// comment" after
                    // the default value would otherwise comment out the rest of the generated line.
                    var defaultValue = syntax.Default!.Value.ToString();
                    builder.AppendLine($"var arg{index} = context.HasArgument({index}) ? {readArgument} : {defaultValue};");
                }
                else
                {
                    builder.AppendLine($"var arg{index} = {readArgument};");
                }

                index++;
            }

            if (methodMetadata.HasReturnValue)
            {
                builder.Append("var result = ");
            }

            if (methodMetadata.IsAsync)
            {
                builder.Append("await ", !methodMetadata.HasReturnValue);
            }

            var receiver = methodMetadata.IsStatic ? typeMetadata.FullTypeName : "userData";
            builder.Append($"{receiver}.{methodMetadata.Symbol.Name}(", !(methodMetadata.HasReturnValue || methodMetadata.IsAsync));
            builder.Append(string.Join(",", Enumerable.Range(firstArgumentIndex, index - firstArgumentIndex).Select(x => $"arg{x}")), false);
            if (hasCancellationToken)
            {
                builder.Append(index > firstArgumentIndex ? ",ct" : "ct", false);
            }
            builder.AppendLine(");", false);

            builder.Append("return ");
            if (methodMetadata.HasReturnValue)
            {
                var returnType = methodMetadata.Symbol.ReturnType;
                if (methodMetadata.IsAsync)
                {
                    // Convert the awaited result, not the task-like wrapper.
                    var namedType = (INamedTypeSymbol)returnType;
                    if (namedType.TypeArguments.Length == 1)
                    {
                        returnType = namedType.TypeArguments[0];
                    }
                }

                var conversionPrefix = GetLuaValuePrefix(returnType, references, compilation);
                builder.AppendLine(methodMetadata.IsAsync
                    ? $"context.Return({conversionPrefix}result));"
                    : $"new global::System.Threading.Tasks.ValueTask(context.Return({conversionPrefix}result)));", false);
            }
            else
            {
                builder.AppendLine(methodMetadata.IsAsync
                    ? "context.Return();"
                    : "new global::System.Threading.Tasks.ValueTask(context.Return());", false);
            }
        }
        builder.AppendLine(");");
        builder.AppendLine();
    }

    /// <summary>
    /// Emits the explicit <c>global::Lua.ILuaUserData.Metatable</c> property. Its getter lazily builds a
    /// LuaTable holding the index/newindex metamethods plus one entry per item of
    /// <paramref name="metamethods"/>. Its setter replaces the table.
    /// </summary>
    /// <remarks>
    /// The emitted backing field <c>__metatable</c> is <c>static</c>. The metatable, including any value
    /// assigned through the setter, is therefore shared by all instances of the generated type.
    /// </remarks>
    /// <returns>Always <c>true</c>.</returns>
    static bool TryEmitMetatable(CodeBuilder builder, IEnumerable metamethods, in SourceProductionContext context)
    {
        builder.AppendLine("global::Lua.LuaTable? global::Lua.ILuaUserData.Metatable");
        using (builder.BeginBlockScope())
        {
            builder.AppendLine("get");
            using (builder.BeginBlockScope())
            {
                builder.AppendLine("if (__metatable != null) return __metatable;");
                builder.AppendLine();
                builder.AppendLine("__metatable = new();");
                builder.AppendLine("__metatable[global::Lua.Runtime.Metamethods.Index] = __metamethod_index;");
                builder.AppendLine("__metatable[global::Lua.Runtime.Metamethods.NewIndex] = __metamethod_newindex;");
                foreach (var metamethod in metamethods)
                {
                    builder.AppendLine($"__metatable[global::Lua.Runtime.Metamethods.{metamethod}] = __metamethod_{metamethod};");
                }
                builder.AppendLine("return __metatable;");
            }

            builder.AppendLine("set");
            using (builder.BeginBlockScope())
            {
                builder.AppendLine("__metatable = value;");
            }
        }
        builder.AppendLine("static global::Lua.LuaTable? __metatable;");
        builder.AppendLine();
        return true;
    }
}