diff --git a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs index 4f33d1f..2a08ae3 100644 --- a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs +++ b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs @@ -10,9 +10,9 @@ namespace TestNs /// /// Partial class for the TestVM which contains ReactiveUI Reactive property initialization. /// - [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")] public partial class TestVM { + [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")] /// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] [System.Text.Json.Serialization.JsonInclude] diff --git a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs index b8c5c8b..fe34cd1 100644 --- a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs +++ b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs @@ -10,9 +10,9 @@ namespace TestNs /// /// Partial class for the TestVM which contains ReactiveUI Reactive property initialization. /// - [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")] public partial class TestVM { + [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")] /// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] public int Test1 diff --git a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs index 91b96a1..155c12f 100644 --- a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs +++ b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs @@ -10,9 +10,9 @@ namespace TestNs /// /// Partial class for the TestVM which contains ReactiveUI Reactive property initialization. /// - [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")] public partial class TestVM { + [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")] /// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] public int Test2 diff --git a/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs b/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs index efeb8d4..82eca67 100644 --- a/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs +++ b/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs @@ -47,6 +47,10 @@ public sealed class TestHelper(ITestOutputHelper testOutput) : IDisposable new("ReactiveUI", VersionRange.AllStableFloating, LibraryDependencyTarget.Package); #pragma warning restore CS0618 // Type or member is obsolete + private static readonly string mscorlibPath = Path.Combine( + System.Runtime.InteropServices.RuntimeEnvironment.GetRuntimeDirectory(), + "mscorlib.dll"); + private static readonly MetadataReference[] References = [ MetadataReference.CreateFromFile(typeof(object).Assembly.Location), @@ -54,6 +58,9 @@ public sealed class TestHelper(ITestOutputHelper testOutput) : IDisposable MetadataReference.CreateFromFile(typeof(T).Assembly.Location), MetadataReference.CreateFromFile(typeof(TestHelper).Assembly.Location), + // Create mscorlib Reference + MetadataReference.CreateFromFile(mscorlibPath) + // Wpf references ////MetadataReference.CreateFromFile(Assembly.Load("PresentationCore").Location), ////MetadataReference.CreateFromFile(Assembly.Load("PresentationFramework").Location), @@ -127,19 +134,22 @@ public void TestFail( /// Tests a generator expecting it to pass successfully. /// /// The source code to test. + /// if set to true [with pre diagnosics]. /// /// The driver. /// + /// Must have valid compiler instance. /// callerType. public GeneratorDriver TestPass( - string source) + string source, + bool withPreDiagnosics = false) { if (_eventCompiler is null) { throw new InvalidOperationException("Must have valid compiler instance."); } - return RunGeneratorAndCheck(source); + return RunGeneratorAndCheck(source, withPreDiagnosics); } /// @@ -149,13 +159,15 @@ public GeneratorDriver TestPass( /// Runs the specified source generator and validates the generated code. /// /// The code to be parsed and processed by the generator. + /// if set to true [with pre diagnosics]. /// Indicates whether to rerun the compilation after running the generator. - /// The generator driver used to run the generator. - /// - /// Thrown if the compiler instance is not valid or if the compilation fails. - /// + /// + /// The generator driver used to run the generator. + /// + /// Thrown if the compiler instance is not valid or if the compilation fails. public GeneratorDriver RunGeneratorAndCheck( string code, + bool withPreDiagnosics = false, bool rerunCompilation = true) { if (_eventCompiler is null) @@ -180,11 +192,14 @@ public GeneratorDriver RunGeneratorAndCheck( assemblies, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, deterministic: true)); - // Validate diagnostics before running the generator. - ////var prediagnostics = compilation.GetDiagnostics() - //// .Where(d => !d.Id.Contains("CS0518") && d.Severity > DiagnosticSeverity.Warning) - //// .ToList(); - ////prediagnostics.Should().BeEmpty(); + if (withPreDiagnosics) + { + // Validate diagnostics before running the generator. + var prediagnostics = compilation.GetDiagnostics() + .Where(d => d.Severity > DiagnosticSeverity.Warning) + .ToList(); + prediagnostics.Should().BeEmpty(); + } var generator = new T(); var driver = CSharpGeneratorDriver.Create(generator).WithUpdatedParseOptions((CSharpParseOptions)syntaxTree.Options); diff --git a/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs b/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs index 1af3855..20ff1fa 100644 --- a/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs +++ b/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs @@ -34,10 +34,7 @@ namespace TestNs; public partial class TestVM : ReactiveObject { [ReactiveCommand] - private void Test1() - { - var a = 10; - } + private int Test1() => 10; } """; @@ -68,10 +65,8 @@ namespace TestNs; public partial class TestVM : ReactiveObject { [ReactiveCommand] - private void Test3(string baseString) - { - var a = baseString; - } + [property: JsonInclude] + private int Test3(string baseString) => int.Parse(baseString); } """; diff --git a/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs b/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs index f048339..7fd8eef 100644 --- a/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs +++ b/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs @@ -234,6 +234,66 @@ public static string GetFullyQualifiedMetadataName(this ITypeSymbol symbol) return builder.ToString(); } + public static bool IsTaskReturnType(this ITypeSymbol? typeSymbol) + { + var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat; + do + { + var typeName = typeSymbol?.ToDisplayString(nameFormat); + if (typeName == "global::System.Threading.Tasks.Task") + { + return true; + } + + typeSymbol = typeSymbol?.BaseType; + } + while (typeSymbol != null); + + return false; + } + + public static bool IsObservableReturnType(this ITypeSymbol? typeSymbol) + { + var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat; + do + { + var typeName = typeSymbol?.ToDisplayString(nameFormat); + if (typeName?.Contains("global::System.IObservable") == true) + { + return true; + } + + typeSymbol = typeSymbol?.BaseType; + } + while (typeSymbol != null); + + return false; + } + + public static bool IsObservableBoolType(this ITypeSymbol? typeSymbol) + { + var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat; + do + { + var typeName = typeSymbol?.ToDisplayString(nameFormat); + if (typeName?.Contains("global::System.IObservable") == true) + { + return true; + } + + typeSymbol = typeSymbol?.BaseType; + } + while (typeSymbol != null); + + return false; + } + + public static ITypeSymbol GetTaskReturnType(this ITypeSymbol typeSymbol, Compilation compilation) => typeSymbol switch + { + INamedTypeSymbol { TypeArguments.Length: 1 } namedTypeSymbol => namedTypeSymbol.TypeArguments[0], + _ => compilation.GetSpecialType(SpecialType.System_Void) + }; + /// /// Appends the fully qualified metadata name for a given symbol to a target builder. /// diff --git a/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs b/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs index c449cd5..3b0fe81 100644 --- a/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs +++ b/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs @@ -131,4 +131,6 @@ public AttributeSyntax GetSyntax() return Attribute(IdentifierName(TypeName), AttributeArgumentList(SeparatedList(arguments.Concat(namedArguments)))); } + + public override string ToString() => $"[{GetSyntax()}]"; } diff --git a/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs b/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs index d7c7fb5..43fe70f 100644 --- a/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs +++ b/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs @@ -147,9 +147,9 @@ namespace {{containingNamespace}} {{AddTabs(1)}}/// {{AddTabs(1)}}/// Partial class for the {{containingTypeName}} which contains ReactiveUI Reactive property initialization. {{AddTabs(1)}}/// -{{AddTabs(1)}}[global::System.CodeDom.Compiler.GeneratedCode("{{GeneratorName}}", "{{GeneratorVersion}}")] {{AddTabs(1)}}{{containingClassVisibility}} partial {{containingType}} {{containingTypeName}} {{AddTabs(1)}}{ +{{AddTabs(2)}}[global::System.CodeDom.Compiler.GeneratedCode("{{GeneratorName}}", "{{GeneratorVersion}}")] {{propertyDeclarations}} {{AddTabs(1)}}} } diff --git a/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs b/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs index e3da06b..482ed98 100644 --- a/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs +++ b/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs @@ -22,7 +22,7 @@ internal record CommandInfo( bool IsObservable, string? CanExecuteObservableName, CanExecuteTypeInfo? CanExecuteTypeInfo, - EquatableArray ForwardedPropertyAttributes) + EquatableArray ForwardedPropertyAttributes) { private const string UnitTypeName = "global::System.Reactive.Unit"; diff --git a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs index f8daf7d..e1236cf 100644 --- a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs +++ b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs @@ -3,6 +3,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for full license information. +using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics.CodeAnalysis; using System.Globalization; @@ -14,7 +15,7 @@ using ReactiveUI.SourceGenerators.Extensions; using ReactiveUI.SourceGenerators.Helpers; using ReactiveUI.SourceGenerators.Input.Models; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using ReactiveUI.SourceGenerators.Models; namespace ReactiveUI.SourceGenerators; @@ -24,467 +25,461 @@ namespace ReactiveUI.SourceGenerators; /// public partial class ReactiveCommandGenerator { + internal static readonly string GeneratorName = typeof(ReactiveCommandGenerator).FullName!; + internal static readonly string GeneratorVersion = typeof(ReactiveCommandGenerator).Assembly.GetName().Version.ToString(); + private const string ReactiveUI = "ReactiveUI"; private const string ReactiveCommand = "ReactiveCommand"; private const string RxCmd = ReactiveUI + "." + ReactiveCommand; private const string Create = ".Create"; private const string CreateO = ".CreateFromObservable"; private const string CreateT = ".CreateFromTask"; - private const string ObsoleteReason = "Commands are initialized automatically. Method will be removed in future version."; + private const string CanExecute = "CanExecute"; + private static readonly string[] excludeFromCodeCoverage = ["[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]"]; - /// - /// A container for all the logic for . - /// - internal static class Execute + private static CommandInfo? GetMethodInfo(in GeneratorAttributeSyntaxContext context, CancellationToken token) { - internal static MethodDeclarationSyntax GetCommandInitiliser() => MethodDeclaration( - PredefinedType(Token(SyntaxKind.VoidKeyword)), - Identifier("InitializeCommands")) - .AddAttributeLists( - AttributeList(SingletonSeparatedList( - Attribute(IdentifierName(AttributeDefinitions.GeneratedCode)) - .AddArgumentListArguments( - AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveGenerator).FullName))), - AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveGenerator).Assembly.GetName().Version.ToString())))))), - AttributeList(SingletonSeparatedList(Attribute(IdentifierName(AttributeDefinitions.ExcludeFromCodeCoverage)))), - AttributeList(SingletonSeparatedList(Attribute(IdentifierName(AttributeDefinitions.Obsolete)) - .AddArgumentListArguments( - AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(ObsoleteReason))))))) - .WithModifiers(TokenList(Token(SyntaxKind.ProtectedKeyword))) - .WithBody(Block()); - - internal static MemberDeclarationSyntax[] GetCommandProperty(CommandInfo commandExtensionInfo) + var symbol = context.TargetSymbol; + if (!symbol.TryGetAttributeWithFullyQualifiedMetadataName(AttributeDefinitions.ReactiveCommandAttributeType, out var attributeData)) { - var outputType = commandExtensionInfo.GetOutputTypeText(); - var inputType = commandExtensionInfo.GetInputTypeText(); - var commandName = GetGeneratedCommandName(commandExtensionInfo.MethodName, commandExtensionInfo.IsTask); - var fieldName = GetGeneratedFieldName(commandName); + return null; + } - ExpressionSyntax initializer; - if (commandExtensionInfo.ArgumentType == null) - { - initializer = GenerateBasicCommand(commandExtensionInfo, fieldName); - } - else if (commandExtensionInfo.ArgumentType != null && commandExtensionInfo.IsReturnTypeVoid) - { - initializer = GenerateInCommand(commandExtensionInfo, fieldName, inputType); - } - else if (commandExtensionInfo.ArgumentType != null && !commandExtensionInfo.IsReturnTypeVoid) - { - initializer = GenerateInOutCommand(commandExtensionInfo, fieldName, outputType, inputType); - } - else + if (symbol is not IMethodSymbol methodSymbol) + { + return default; + } + + token.ThrowIfCancellationRequested(); + + var isTask = methodSymbol.ReturnType.IsTaskReturnType(); + var isObservable = methodSymbol.ReturnType.IsObservableReturnType(); + + var compilation = context.SemanticModel.Compilation; + var realReturnType = isTask || isObservable ? methodSymbol.ReturnType.GetTaskReturnType(compilation) : methodSymbol.ReturnType; + var isReturnTypeVoid = SymbolEqualityComparer.Default.Equals(realReturnType, compilation.GetSpecialType(SpecialType.System_Void)); + var hasCancellationToken = isTask && methodSymbol.Parameters.Any(x => x.Type.ToDisplayString() == "System.Threading.CancellationToken"); + + using var methodParameters = ImmutableArrayBuilder.Rent(); + if (hasCancellationToken && methodSymbol.Parameters.Length == 2) + { + methodParameters.Add(methodSymbol.Parameters[0]); + } + else if (!hasCancellationToken) + { + foreach (var parameter in methodSymbol.Parameters) { - return []; + methodParameters.Add(parameter); } + } - // Prepare any forwarded property attributes - var forwardedPropertyAttributes = - commandExtensionInfo.ForwardedPropertyAttributes - .Select(static a => AttributeList(SingletonSeparatedList(a.GetSyntax()))) - .ToImmutableArray(); - - var qualifiedName = QualifiedName( - IdentifierName(ReactiveUI), - GenericName( - Identifier(ReactiveCommand)) - .WithTypeArgumentList( - TypeArgumentList( - SeparatedList( - new SyntaxNodeOrToken[] - { - IdentifierName(inputType), - Token(SyntaxKind.CommaToken), - IdentifierName(outputType) - })))); - - var fieldDeclaration = FieldDeclaration( - VariableDeclaration(NullableType(qualifiedName))) - .AddDeclarationVariables(VariableDeclarator(fieldName)) - .AddAttributeLists(AttributeList(SingletonSeparatedList( - Attribute(IdentifierName(AttributeDefinitions.GeneratedCode)) - .AddArgumentListArguments( - AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveCommandGenerator).FullName))), - AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveCommandGenerator).Assembly.GetName().Version.ToString()))))))) - .AddModifiers( - Token(SyntaxKind.PrivateKeyword)) - .NormalizeWhitespace(); - - var commandDeclaration = PropertyDeclaration( - qualifiedName, - Identifier(commandName)) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddAccessorListAccessors( - AccessorDeclaration(SyntaxKind.GetAccessorDeclaration) - .WithExpressionBody(ArrowExpressionClause(initializer))) - .AddAttributeLists([.. forwardedPropertyAttributes]) - .NormalizeWhitespace(); - return [fieldDeclaration, commandDeclaration]; - - static ExpressionSyntax GenerateBasicCommand(CommandInfo commandExtensionInfo, string fieldName) - { - var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create; - if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName)) - { - return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName});"); - } + if (methodParameters.Count > 1) + { + return default; // Too many parameters, continue + } - return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});"); - } + token.ThrowIfCancellationRequested(); + + TryGetCanExecuteExpressionType(methodSymbol, attributeData, out var canExecuteObservableName, out var canExecuteTypeInfo); + + token.ThrowIfCancellationRequested(); + + var methodSyntax = (MethodDeclarationSyntax)context.TargetNode; + GatherForwardedAttributes(methodSymbol, context.SemanticModel, methodSyntax, token, out var attributes); + var forwardedPropertyAttributes = attributes.Select(static a => a.ToString()).ToImmutableArray(); + token.ThrowIfCancellationRequested(); + + // Get the containing type info + var targetInfo = TargetInfo.From(methodSymbol.ContainingType); + + token.ThrowIfCancellationRequested(); + + return new CommandInfo( + targetInfo.FileHintName, + targetInfo.TargetName, + targetInfo.TargetNamespace, + targetInfo.TargetNamespaceWithNamespace, + targetInfo.TargetVisibility, + targetInfo.TargetType, + symbol.Name, + realReturnType.GetFullyQualifiedNameWithNullabilityAnnotations(), + methodParameters.ToImmutable().SingleOrDefault()?.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + isTask, + isReturnTypeVoid, + isObservable, + canExecuteObservableName, + canExecuteTypeInfo, + forwardedPropertyAttributes); + } - static ExpressionSyntax GenerateInOutCommand(CommandInfo commandExtensionInfo, string fieldName, string outputType, string inputType) - { - var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create; - if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName)) - { - return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName});"); - } + private static string GenerateSource(string containingTypeName, string containingNamespace, string containingClassVisibility, string containingType, CommandInfo[] commands) + { + // Generate all member declarations for the current type + var propertyDeclarations = string.Join("\n\r", commands.Select(GetCommandSyntax)); + return +$$""" +// - return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});"); - } +#pragma warning disable +#nullable enable - static ExpressionSyntax GenerateInCommand(CommandInfo commandExtensionInfo, string fieldName, string inputType) - { - var commandType = commandExtensionInfo.IsTask ? CreateT : Create; - if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName)) - { - return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName});"); - } +namespace {{containingNamespace}} +{ +{{AddTabs(1)}}/// +{{AddTabs(1)}}/// Partial class for the {{containingTypeName}} which contains ReactiveUI ReactiveCommand initialization. +{{AddTabs(1)}}/// +{{AddTabs(1)}}{{containingClassVisibility}} partial {{containingType}} {{containingTypeName}} +{{AddTabs(1)}}{ +{{AddTabs(2)}}[global::System.CodeDom.Compiler.GeneratedCode("{{GeneratorName}}", "{{GeneratorVersion}}")] +{{propertyDeclarations}} +{{AddTabs(1)}}} +} +#nullable restore +#pragma warning restore +"""; + } - return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});"); - } + private static string GetCommandSyntax(CommandInfo commandExtensionInfo) + { + var outputType = commandExtensionInfo.GetOutputTypeText(); + var inputType = commandExtensionInfo.GetInputTypeText(); + var commandName = GetGeneratedCommandName(commandExtensionInfo.MethodName, commandExtensionInfo.IsTask); + var fieldName = GetGeneratedFieldName(commandName); + + string? initializer; + if (commandExtensionInfo.ArgumentType == null) + { + initializer = GenerateBasicCommand(commandExtensionInfo, fieldName); + } + else if (commandExtensionInfo.ArgumentType != null && commandExtensionInfo.IsReturnTypeVoid) + { + initializer = GenerateInCommand(commandExtensionInfo, fieldName, inputType); + } + else if (commandExtensionInfo.ArgumentType != null && !commandExtensionInfo.IsReturnTypeVoid) + { + initializer = GenerateInOutCommand(commandExtensionInfo, fieldName, outputType, inputType); + } + else + { + return string.Empty; } - internal static bool IsTaskReturnType(ITypeSymbol? typeSymbol) + // Prepare any forwarded property attributes + var forwardedPropertyAttributesString = string.Join("\n\t\t", excludeFromCodeCoverage.Concat(commandExtensionInfo.ForwardedPropertyAttributes)); + + return +$$""" +{{AddTabs(2)}}private ReactiveUI.ReactiveCommand<{{inputType}}, {{outputType}}>? {{fieldName}}; + +{{AddTabs(2)}}{{forwardedPropertyAttributesString}} +{{AddTabs(2)}}public ReactiveUI.ReactiveCommand<{{inputType}}, {{outputType}}> {{commandName}} { get => {{initializer}} } +"""; + + static string GenerateBasicCommand(CommandInfo commandExtensionInfo, string fieldName) { - var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat; - do + var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create; + if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName)) { - var typeName = typeSymbol?.ToDisplayString(nameFormat); - if (typeName == "global::System.Threading.Tasks.Task") - { - return true; - } - - typeSymbol = typeSymbol?.BaseType; + return $"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName});"; } - while (typeSymbol != null); - return false; + return $"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});"; } - internal static bool IsObservableReturnType(ITypeSymbol? typeSymbol) + static string GenerateInOutCommand(CommandInfo commandExtensionInfo, string fieldName, string outputType, string inputType) { - var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat; - do + var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create; + if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName)) { - var typeName = typeSymbol?.ToDisplayString(nameFormat); - if (typeName?.Contains("global::System.IObservable") == true) - { - return true; - } - - typeSymbol = typeSymbol?.BaseType; + return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName});"; } - while (typeSymbol != null); - return false; + return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});"; } - internal static bool IsObservableBoolType(ITypeSymbol? typeSymbol) + static string GenerateInCommand(CommandInfo commandExtensionInfo, string fieldName, string inputType) { - var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat; - do + var commandType = commandExtensionInfo.IsTask ? CreateT : Create; + if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName)) { - var typeName = typeSymbol?.ToDisplayString(nameFormat); - if (typeName?.Contains("global::System.IObservable") == true) - { - return true; - } - - typeSymbol = typeSymbol?.BaseType; + return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName});"; } - while (typeSymbol != null); - return false; + return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});"; } + } - internal static ITypeSymbol GetTaskReturnType(Compilation compilation, ITypeSymbol typeSymbol) => typeSymbol switch - { - INamedTypeSymbol { TypeArguments.Length: 1 } namedTypeSymbol => namedTypeSymbol.TypeArguments[0], - _ => compilation.GetSpecialType(SpecialType.System_Void) - }; - - /// - /// Tries to get the expression type for the "CanExecute" property, if available. - /// - /// The input instance to process. - /// The instance for . - /// The resulting can execute member name, if available. - /// The resulting expression type, if available. - internal static void TryGetCanExecuteExpressionType( - IMethodSymbol methodSymbol, - AttributeData attributeData, - out string? canExecuteMemberName, - out CanExecuteTypeInfo? canExecuteTypeInfo) + /// + /// Tries to get the expression type for the "CanExecute" property, if available. + /// + /// The input instance to process. + /// The instance for . + /// The resulting can execute member name, if available. + /// The resulting expression type, if available. + private static void TryGetCanExecuteExpressionType( + IMethodSymbol methodSymbol, + AttributeData attributeData, + out string? canExecuteMemberName, + out CanExecuteTypeInfo? canExecuteTypeInfo) + { + // Get the can execute member, if any + if (!attributeData.TryGetNamedArgument(CanExecute, out string? memberName)) { - // Get the can execute member, if any - if (!attributeData.TryGetNamedArgument("CanExecute", out string? memberName)) - { - canExecuteMemberName = null; - canExecuteTypeInfo = null; - - return; - } + canExecuteMemberName = null; + canExecuteTypeInfo = null; - if (memberName is null) - { - goto Failure; - } + return; + } - var canExecuteSymbols = methodSymbol.ContainingType!.GetAllMembers(memberName).ToImmutableArray(); + if (memberName is null) + { + goto Failure; + } - if (canExecuteSymbols.IsEmpty) - { - // Special case for when the target member is a generated property from [ObservableProperty] - if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, out canExecuteTypeInfo)) - { - canExecuteMemberName = memberName; + var canExecuteSymbols = methodSymbol.ContainingType!.GetAllMembers(memberName).ToImmutableArray(); - return; - } - } - else if (canExecuteSymbols.Length > 1) - { - goto Failure; - } - else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], out canExecuteTypeInfo)) + if (canExecuteSymbols.IsEmpty) + { + // Special case for when the target member is a generated property from [ObservableProperty] + if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, out canExecuteTypeInfo)) { canExecuteMemberName = memberName; return; } - - Failure: - canExecuteMemberName = null; - canExecuteTypeInfo = null; + } + else if (canExecuteSymbols.Length > 1) + { + goto Failure; + } + else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], out canExecuteTypeInfo)) + { + canExecuteMemberName = memberName; return; } - /// - /// Gets the expression type for the can execute logic, if possible. - /// - /// The can execute member symbol (either a method or a property). - /// The resulting can execute expression type, if available. - /// Whether or not was set and the input symbol was valid. - internal static bool TryGetCanExecuteExpressionFromSymbol( - ISymbol canExecuteSymbol, - [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo) - { - if (canExecuteSymbol is IMethodSymbol canExecuteMethodSymbol) - { - // The return type must always be a bool - if (!IsObservableBoolType(canExecuteMethodSymbol.ReturnType)) - { - goto Failure; - } - - // If the method has parameters, it has to have a single one matching the command type - if (canExecuteMethodSymbol.Parameters.Length == 1) - { - goto Failure; - } + Failure: + canExecuteMemberName = null; + canExecuteTypeInfo = null; - // Parameterless methods are always valid - if (canExecuteMethodSymbol.Parameters.IsEmpty) - { - canExecuteTypeInfo = CanExecuteTypeInfo.MethodObservable; + return; + } - return true; - } + /// + /// Gets the expression type for the can execute logic, if possible. + /// + /// The can execute member symbol (either a method or a property). + /// The resulting can execute expression type, if available. + /// Whether or not was set and the input symbol was valid. + private static bool TryGetCanExecuteExpressionFromSymbol( + ISymbol canExecuteSymbol, + [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo) + { + if (canExecuteSymbol is IMethodSymbol canExecuteMethodSymbol) + { + // The return type must always be a bool + if (!canExecuteMethodSymbol.ReturnType.IsObservableBoolType()) + { + goto Failure; } - else if (canExecuteSymbol is IPropertySymbol { GetMethod: not null } canExecutePropertySymbol) + + // If the method has parameters, it has to have a single one matching the command type + if (canExecuteMethodSymbol.Parameters.Length == 1) { - // The property type must always be a bool - if (!IsObservableBoolType(canExecutePropertySymbol.Type)) - { - goto Failure; - } + goto Failure; + } - canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable; + // Parameterless methods are always valid + if (canExecuteMethodSymbol.Parameters.IsEmpty) + { + canExecuteTypeInfo = CanExecuteTypeInfo.MethodObservable; return true; } - else if (canExecuteSymbol is IFieldSymbol canExecuteFieldSymbol) + } + else if (canExecuteSymbol is IPropertySymbol { GetMethod: not null } canExecutePropertySymbol) + { + // The property type must always be a bool + if (!canExecutePropertySymbol.Type.IsObservableBoolType()) { - // The property type must always be a bool - if (!IsObservableBoolType(canExecuteFieldSymbol.Type)) - { - goto Failure; - } + goto Failure; + } - canExecuteTypeInfo = CanExecuteTypeInfo.FieldObservable; + canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable; - return true; + return true; + } + else if (canExecuteSymbol is IFieldSymbol canExecuteFieldSymbol) + { + // The property type must always be a bool + if (!canExecuteFieldSymbol.Type.IsObservableBoolType()) + { + goto Failure; } - Failure: - canExecuteTypeInfo = null; + canExecuteTypeInfo = CanExecuteTypeInfo.FieldObservable; - return false; + return true; } - /// - /// Gets the expression type for the can execute logic, if possible. - /// - /// The member name passed to [ReactiveCommand(CanExecute = ...)]. - /// The containing type for the method annotated with [ReactiveCommand]. - /// The resulting can execute expression type, if available. - /// Whether or not was set and the input symbol was valid. - internal static bool TryGetCanExecuteMemberFromGeneratedProperty( - string memberName, - INamedTypeSymbol containingType, - [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo) + Failure: + canExecuteTypeInfo = null; + + return false; + } + + /// + /// Gets the expression type for the can execute logic, if possible. + /// + /// The member name passed to [ReactiveCommand(CanExecute = ...)]. + /// The containing type for the method annotated with [ReactiveCommand]. + /// The resulting can execute expression type, if available. + /// Whether or not was set and the input symbol was valid. + private static bool TryGetCanExecuteMemberFromGeneratedProperty( + string memberName, + INamedTypeSymbol containingType, + [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo) + { + foreach (var memberSymbol in containingType.GetAllMembers()) { - foreach (var memberSymbol in containingType.GetAllMembers()) + // Only look for instance fields of Observable bool type + if (!memberSymbol.ContainingType.IsObservableBoolType() || memberSymbol is not IFieldSymbol fieldSymbol) { - // Only look for instance fields of Observable bool type - if (!IsObservableBoolType(memberSymbol.ContainingType) || memberSymbol is not IFieldSymbol fieldSymbol) - { - continue; - } + continue; + } - var attributes = memberSymbol.GetAttributes(); + var attributes = memberSymbol.GetAttributes(); - // Only filter fields with the [Reactive] attribute - if (memberSymbol is IFieldSymbol && - !attributes.Any(static a => a.AttributeClass?.HasFullyQualifiedMetadataName( - "ReactiveUI.SourceGenerators.ReactiveAttribute") == true)) - { - continue; - } + // Only filter fields with the [Reactive] attribute + if (memberSymbol is IFieldSymbol && + !attributes.Any(static a => a.AttributeClass?.HasFullyQualifiedMetadataName( + AttributeDefinitions.ReactiveAttributeType) == true)) + { + continue; + } - // Get the target property name either directly or matching the generated one - var propertyName = fieldSymbol.GetGeneratedPropertyName(); + // Get the target property name either directly or matching the generated one + var propertyName = fieldSymbol.GetGeneratedPropertyName(); - // If the generated property name matches, get the right expression type - if (memberName == propertyName) - { - canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable; + // If the generated property name matches, get the right expression type + if (memberName == propertyName) + { + canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable; - return true; - } + return true; } + } - canExecuteTypeInfo = null; + canExecuteTypeInfo = null; - return false; - } + return false; + } + + /// + /// Gathers all forwarded attributes for the generated field and property. + /// + /// The input instance to process. + /// The instance for the current run. + /// The method declaration. + /// The cancellation token for the current operation. + /// The resulting property attributes to forward. + private static void GatherForwardedAttributes( + IMethodSymbol methodSymbol, + SemanticModel semanticModel, + MethodDeclarationSyntax methodDeclaration, + CancellationToken token, + out ImmutableArray propertyAttributes) + { + using var propertyAttributesInfo = ImmutableArrayBuilder.Rent(); - /// - /// Gathers all forwarded attributes for the generated field and property. - /// - /// The input instance to process. - /// The instance for the current run. - /// The method declaration. - /// The cancellation token for the current operation. - /// The resulting property attributes to forward. - internal static void GatherForwardedAttributes( + static void GatherForwardedAttributes( IMethodSymbol methodSymbol, SemanticModel semanticModel, MethodDeclarationSyntax methodDeclaration, CancellationToken token, - out ImmutableArray propertyAttributes) + ImmutableArrayBuilder propertyAttributesInfo) { - using var propertyAttributesInfo = ImmutableArrayBuilder.Rent(); - - static void GatherForwardedAttributes( - IMethodSymbol methodSymbol, - SemanticModel semanticModel, - MethodDeclarationSyntax methodDeclaration, - CancellationToken token, - ImmutableArrayBuilder propertyAttributesInfo) + // Get the single syntax reference for the input method symbol (there should be only one) + if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference]) + { + return; + } + + // Gather explicit forwarded attributes info + foreach (var attributeList in methodDeclaration.AttributeLists) { - // Get the single syntax reference for the input method symbol (there should be only one) - if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference]) + if (attributeList.Target?.Identifier is not SyntaxToken(SyntaxKind.PropertyKeyword)) { - return; + continue; } - // Gather explicit forwarded attributes info - foreach (var attributeList in methodDeclaration.AttributeLists) + foreach (var attribute in attributeList.Attributes) { - if (attributeList.Target?.Identifier is not SyntaxToken(SyntaxKind.PropertyKeyword)) + if (!semanticModel.GetSymbolInfo(attribute, token).TryGetAttributeTypeSymbol(out var attributeTypeSymbol)) { continue; } - foreach (var attribute in attributeList.Attributes) + var attributeArguments = attribute.ArgumentList?.Arguments ?? Enumerable.Empty(); + + // Try to extract the forwarded attribute + if (!AttributeInfo.TryCreate(attributeTypeSymbol, semanticModel, attributeArguments, token, out var attributeInfo)) + { + continue; + } + + // Add the new attribute info to the right builder + if (attributeList.Target?.Identifier is SyntaxToken(SyntaxKind.PropertyKeyword)) { - if (!semanticModel.GetSymbolInfo(attribute, token).TryGetAttributeTypeSymbol(out var attributeTypeSymbol)) - { - continue; - } - - var attributeArguments = attribute.ArgumentList?.Arguments ?? Enumerable.Empty(); - - // Try to extract the forwarded attribute - if (!AttributeInfo.TryCreate(attributeTypeSymbol, semanticModel, attributeArguments, token, out var attributeInfo)) - { - continue; - } - - // Add the new attribute info to the right builder - if (attributeList.Target?.Identifier is SyntaxToken(SyntaxKind.PropertyKeyword)) - { - propertyAttributesInfo.Add(attributeInfo); - } + propertyAttributesInfo.Add(attributeInfo); } } } + } - // If the method is a partial definition, also gather attributes from the implementation part - if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null }) - { - var partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol; - var partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol; - - // We always give priority to the partial definition, to ensure a predictable and testable ordering - GatherForwardedAttributes(partialDefinition, semanticModel, methodDeclaration, token, propertyAttributesInfo); - GatherForwardedAttributes(partialImplementation, semanticModel, methodDeclaration, token, propertyAttributesInfo); - } - else - { - // If the method is not a partial definition/implementation, just gather attributes from the method with no modifications - GatherForwardedAttributes(methodSymbol, semanticModel, methodDeclaration, token, propertyAttributesInfo); - } + // If the method is a partial definition, also gather attributes from the implementation part + if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null }) + { + var partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol; + var partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol; - propertyAttributes = propertyAttributesInfo.ToImmutable(); + // We always give priority to the partial definition, to ensure a predictable and testable ordering + GatherForwardedAttributes(partialDefinition, semanticModel, methodDeclaration, token, propertyAttributesInfo); + GatherForwardedAttributes(partialImplementation, semanticModel, methodDeclaration, token, propertyAttributesInfo); } - - internal static string GetGeneratedCommandName(string methodName, bool isAsync) + else { - var commandName = methodName; + // If the method is not a partial definition/implementation, just gather attributes from the method with no modifications + GatherForwardedAttributes(methodSymbol, semanticModel, methodDeclaration, token, propertyAttributesInfo); + } - if (commandName.StartsWith("m_")) - { - commandName = commandName.Substring(2); - } - else if (commandName.StartsWith("_")) - { - commandName = commandName.TrimStart('_'); - } + propertyAttributes = propertyAttributesInfo.ToImmutable(); + } - if (commandName.EndsWith("Async") && isAsync) - { - commandName = commandName.Substring(0, commandName.Length - "Async".Length); - } + private static string GetGeneratedCommandName(string methodName, bool isAsync) + { + var commandName = methodName; + + if (commandName.StartsWith("m_")) + { + commandName = commandName.Substring(2); + } + else if (commandName.StartsWith("_")) + { + commandName = commandName.TrimStart('_'); + } - return $"{char.ToUpper(commandName[0], CultureInfo.InvariantCulture)}{commandName.Substring(1)}Command"; + if (commandName.EndsWith("Async") && isAsync) + { + commandName = commandName.Substring(0, commandName.Length - "Async".Length); } - internal static string GetGeneratedFieldName(string generatedCommandName) => - $"_{char.ToLower(generatedCommandName[0], CultureInfo.InvariantCulture)}{generatedCommandName.Substring(1)}"; + return $"{char.ToUpper(commandName[0], CultureInfo.InvariantCulture)}{commandName.Substring(1)}Command"; } + + private static string GetGeneratedFieldName(string generatedCommandName) => + $"_{char.ToLower(generatedCommandName[0], CultureInfo.InvariantCulture)}{generatedCommandName.Substring(1)}"; + + private static string AddTabs(int tabCount) => new('\t', tabCount); } diff --git a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs index 10bcb24..47ca75e 100644 --- a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs +++ b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs @@ -3,20 +3,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for full license information. -using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; using ReactiveUI.SourceGenerators.Extensions; using ReactiveUI.SourceGenerators.Helpers; -using ReactiveUI.SourceGenerators.Input.Models; -using ReactiveUI.SourceGenerators.Models; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace ReactiveUI.SourceGenerators; @@ -33,137 +27,44 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ctx.AddSource($"{AttributeDefinitions.ReactiveCommandAttributeType}.g.cs", SourceText.From(AttributeDefinitions.ReactiveCommandAttribute, Encoding.UTF8))); // Gather info for all annotated command methods (starting from method declarations with at least one attribute) - IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> commandInfoWithErrors = + var commandInfo = context.SyntaxProvider .ForAttributeWithMetadataName( AttributeDefinitions.ReactiveCommandAttributeType, static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax or RecordDeclarationSyntax, AttributeLists.Count: > 0 }, - static (context, token) => - { - CommandInfo? commandExtensionInfos = default; - HierarchyInfo? hierarchy = default; - using var diagnostics = ImmutableArrayBuilder.Rent(); - - var methodSyntax = (MethodDeclarationSyntax)context.TargetNode; - var symbol = ModelExtensions.GetDeclaredSymbol(context.SemanticModel, methodSyntax, token)!; - token.ThrowIfCancellationRequested(); - - // Skip symbols without the target attribute - if (!symbol.TryGetAttributeWithFullyQualifiedMetadataName(AttributeDefinitions.ReactiveCommandAttributeType, out var attributeData)) - { - return default; - } - - token.ThrowIfCancellationRequested(); - if (attributeData != null) - { - var compilation = context.SemanticModel.Compilation; - var methodSymbol = (IMethodSymbol)symbol!; - var isTask = Execute.IsTaskReturnType(methodSymbol.ReturnType); - var isObservable = Execute.IsObservableReturnType(methodSymbol.ReturnType); - var realReturnType = isTask || isObservable ? Execute.GetTaskReturnType(compilation, methodSymbol.ReturnType) : methodSymbol.ReturnType; - var isReturnTypeVoid = SymbolEqualityComparer.Default.Equals(realReturnType, compilation.GetSpecialType(SpecialType.System_Void)); - var hasCancellationToken = isTask && methodSymbol.Parameters.Any(x => x.Type.ToDisplayString() == "System.Threading.CancellationToken"); - var methodParameters = new List(); - if (hasCancellationToken && methodSymbol.Parameters.Length == 2) - { - methodParameters.Add(methodSymbol.Parameters[0]); - } - else if (!hasCancellationToken) - { - methodParameters.AddRange(methodSymbol.Parameters); - } - - if (methodParameters.Count > 1) - { - return default; // Too many parameters, continue - } - - token.ThrowIfCancellationRequested(); - - // Get the hierarchy info for the target symbol, and try to gather the command info - hierarchy = HierarchyInfo.From(methodSymbol.ContainingType); - - // Get the CanExecute expression type, if any - Execute.TryGetCanExecuteExpressionType( - methodSymbol, - attributeData, - out var canExecuteMemberName, - out var canExecuteTypeInfo); - - token.ThrowIfCancellationRequested(); - - Execute.GatherForwardedAttributes( - methodSymbol, - context.SemanticModel, - methodSyntax, - token, - out var forwardedAttributes); + static (context, token) => GetMethodInfo(context, token)) + .Where(x => x != null) + .Select((x, _) => x!) + .Collect(); - token.ThrowIfCancellationRequested(); - - // Get the containing type info - var targetInfo = TargetInfo.From(methodSymbol.ContainingType); - - token.ThrowIfCancellationRequested(); - - commandExtensionInfos = new( - targetInfo.FileHintName, - targetInfo.TargetName, - targetInfo.TargetNamespace, - targetInfo.TargetNamespaceWithNamespace, - targetInfo.TargetVisibility, - targetInfo.TargetType, - methodSymbol.Name, - realReturnType.GetFullyQualifiedNameWithNullabilityAnnotations(), - methodParameters?.SingleOrDefault()?.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - isTask, - isReturnTypeVoid, - isObservable, - canExecuteMemberName, - canExecuteTypeInfo, - forwardedAttributes); - } - - token.ThrowIfCancellationRequested(); - return (Hierarchy: hierarchy, new Result(commandExtensionInfos, diagnostics.ToImmutable())); - }) - .Where(static item => item.Hierarchy is not null)!; + // Generate the requested properties and methods + context.RegisterSourceOutput(commandInfo, static (context, input) => + { + var groupedcommandInfo = input.GroupBy( + static info => (info.FileHintName, info.TargetName, info.TargetNamespace, info.TargetVisibility, info.TargetType), + static info => info) + .ToImmutableArray(); - // Get the filtered sequence to enable caching - var propertyInfo = - commandInfoWithErrors - .Where(static item => item.Info.Value is not null)!; + if (groupedcommandInfo.Length == 0) + { + return; + } - // Split and group by containing type - var groupedPropertyInfo = - propertyInfo - .GroupBy(static item => item.Left, static item => item.Right.Value); + foreach (var grouping in groupedcommandInfo) + { + var items = grouping.ToImmutableArray(); - // Generate the requested properties and methods - context.RegisterSourceOutput(groupedPropertyInfo, static (context, item) => - { - var commandInfos = item.Right.ToArray(); + if (items.Length == 0) + { + continue; + } - // Generate all member declarations for the current type - var propertyDeclarations = - commandInfos - .SelectMany(Execute.GetCommandProperty) - .ToList(); + var (fileHintName, targetName, targetNamespace, targetVisibility, targetType) = grouping.Key; - var c = Execute.GetCommandInitiliser(); - propertyDeclarations.Add(c); - var memberDeclarations = propertyDeclarations.ToImmutableArray(); + var source = GenerateSource(targetName, targetNamespace, targetVisibility, targetType, [.. grouping]); - // Insert all members into the same partial type declaration - var compilationUnit = item.Key.GetCompilationUnit(memberDeclarations) - .WithLeadingTrivia(TriviaList( - Comment("// "), - Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)), - Trivia(NullableDirectiveTrivia(Token(SyntaxKind.EnableKeyword), true)), - CarriageReturn)) - .NormalizeWhitespace(); - context.AddSource($"{item.Key.FilenameHint}.ReactiveCommands.g.cs", compilationUnit); + context.AddSource($"{fileHintName}.ReactiveCommands.g.cs", source); + } }); } }