diff --git a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs
index 4f33d1f..2a08ae3 100644
--- a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs
+++ b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperiesWithAttributes#TestVM.Properties.g.verified.cs
@@ -10,9 +10,9 @@ namespace TestNs
///
/// Partial class for the TestVM which contains ReactiveUI Reactive property initialization.
///
- [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")]
public partial class TestVM
{
+ [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")]
///
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
[System.Text.Json.Serialization.JsonInclude]
diff --git a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs
index b8c5c8b..fe34cd1 100644
--- a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs
+++ b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactiveProperties#TestVM.Properties.g.verified.cs
@@ -10,9 +10,9 @@ namespace TestNs
///
/// Partial class for the TestVM which contains ReactiveUI Reactive property initialization.
///
- [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")]
public partial class TestVM
{
+ [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")]
///
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public int Test1
diff --git a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs
index 91b96a1..155c12f 100644
--- a/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs
+++ b/src/ReactiveUI.SourceGenerator.Tests/REACTIVE/ReactiveGeneratorTests.FromReactivePropertiesWithAccess#TestVM.Properties.g.verified.cs
@@ -10,9 +10,9 @@ namespace TestNs
///
/// Partial class for the TestVM which contains ReactiveUI Reactive property initialization.
///
- [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")]
public partial class TestVM
{
+ [global::System.CodeDom.Compiler.GeneratedCode("ReactiveUI.SourceGenerators.ReactiveGenerator", "1.1.0.0")]
///
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public int Test2
diff --git a/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs b/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs
index efeb8d4..82eca67 100644
--- a/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs
+++ b/src/ReactiveUI.SourceGenerator.Tests/TestHelper.cs
@@ -47,6 +47,10 @@ public sealed class TestHelper(ITestOutputHelper testOutput) : IDisposable
new("ReactiveUI", VersionRange.AllStableFloating, LibraryDependencyTarget.Package);
#pragma warning restore CS0618 // Type or member is obsolete
+ private static readonly string mscorlibPath = Path.Combine(
+ System.Runtime.InteropServices.RuntimeEnvironment.GetRuntimeDirectory(),
+ "mscorlib.dll");
+
private static readonly MetadataReference[] References =
[
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
@@ -54,6 +58,9 @@ public sealed class TestHelper(ITestOutputHelper testOutput) : IDisposable
MetadataReference.CreateFromFile(typeof(T).Assembly.Location),
MetadataReference.CreateFromFile(typeof(TestHelper).Assembly.Location),
+ // Create mscorlib Reference
+ MetadataReference.CreateFromFile(mscorlibPath)
+
// Wpf references
////MetadataReference.CreateFromFile(Assembly.Load("PresentationCore").Location),
////MetadataReference.CreateFromFile(Assembly.Load("PresentationFramework").Location),
@@ -127,19 +134,22 @@ public void TestFail(
/// Tests a generator expecting it to pass successfully.
///
/// The source code to test.
+ /// if set to true [with pre diagnosics].
///
/// The driver.
///
+ /// Must have valid compiler instance.
/// callerType.
public GeneratorDriver TestPass(
- string source)
+ string source,
+ bool withPreDiagnosics = false)
{
if (_eventCompiler is null)
{
throw new InvalidOperationException("Must have valid compiler instance.");
}
- return RunGeneratorAndCheck(source);
+ return RunGeneratorAndCheck(source, withPreDiagnosics);
}
///
@@ -149,13 +159,15 @@ public GeneratorDriver TestPass(
/// Runs the specified source generator and validates the generated code.
///
/// The code to be parsed and processed by the generator.
+ /// if set to true [with pre diagnosics].
/// Indicates whether to rerun the compilation after running the generator.
- /// The generator driver used to run the generator.
- ///
- /// Thrown if the compiler instance is not valid or if the compilation fails.
- ///
+ ///
+ /// The generator driver used to run the generator.
+ ///
+ /// Thrown if the compiler instance is not valid or if the compilation fails.
public GeneratorDriver RunGeneratorAndCheck(
string code,
+ bool withPreDiagnosics = false,
bool rerunCompilation = true)
{
if (_eventCompiler is null)
@@ -180,11 +192,14 @@ public GeneratorDriver RunGeneratorAndCheck(
assemblies,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, deterministic: true));
- // Validate diagnostics before running the generator.
- ////var prediagnostics = compilation.GetDiagnostics()
- //// .Where(d => !d.Id.Contains("CS0518") && d.Severity > DiagnosticSeverity.Warning)
- //// .ToList();
- ////prediagnostics.Should().BeEmpty();
+ if (withPreDiagnosics)
+ {
+ // Validate diagnostics before running the generator.
+ var prediagnostics = compilation.GetDiagnostics()
+ .Where(d => d.Severity > DiagnosticSeverity.Warning)
+ .ToList();
+ prediagnostics.Should().BeEmpty();
+ }
var generator = new T();
var driver = CSharpGeneratorDriver.Create(generator).WithUpdatedParseOptions((CSharpParseOptions)syntaxTree.Options);
diff --git a/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs b/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs
index 1af3855..20ff1fa 100644
--- a/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs
+++ b/src/ReactiveUI.SourceGenerator.Tests/UnitTests/ReactiveCMDGeneratorTests.cs
@@ -34,10 +34,7 @@ namespace TestNs;
public partial class TestVM : ReactiveObject
{
[ReactiveCommand]
- private void Test1()
- {
- var a = 10;
- }
+ private int Test1() => 10;
}
""";
@@ -68,10 +65,8 @@ namespace TestNs;
public partial class TestVM : ReactiveObject
{
[ReactiveCommand]
- private void Test3(string baseString)
- {
- var a = baseString;
- }
+ [property: JsonInclude]
+ private int Test3(string baseString) => int.Parse(baseString);
}
""";
diff --git a/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs b/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs
index f048339..7fd8eef 100644
--- a/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs
+++ b/src/ReactiveUI.SourceGenerators/Core/Extensions/ITypeSymbolExtensions.cs
@@ -234,6 +234,66 @@ public static string GetFullyQualifiedMetadataName(this ITypeSymbol symbol)
return builder.ToString();
}
+ public static bool IsTaskReturnType(this ITypeSymbol? typeSymbol)
+ {
+ var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat;
+ do
+ {
+ var typeName = typeSymbol?.ToDisplayString(nameFormat);
+ if (typeName == "global::System.Threading.Tasks.Task")
+ {
+ return true;
+ }
+
+ typeSymbol = typeSymbol?.BaseType;
+ }
+ while (typeSymbol != null);
+
+ return false;
+ }
+
+ public static bool IsObservableReturnType(this ITypeSymbol? typeSymbol)
+ {
+ var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat;
+ do
+ {
+ var typeName = typeSymbol?.ToDisplayString(nameFormat);
+ if (typeName?.Contains("global::System.IObservable") == true)
+ {
+ return true;
+ }
+
+ typeSymbol = typeSymbol?.BaseType;
+ }
+ while (typeSymbol != null);
+
+ return false;
+ }
+
+ public static bool IsObservableBoolType(this ITypeSymbol? typeSymbol)
+ {
+ var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat;
+ do
+ {
+ var typeName = typeSymbol?.ToDisplayString(nameFormat);
+ if (typeName?.Contains("global::System.IObservable") == true)
+ {
+ return true;
+ }
+
+ typeSymbol = typeSymbol?.BaseType;
+ }
+ while (typeSymbol != null);
+
+ return false;
+ }
+
+ public static ITypeSymbol GetTaskReturnType(this ITypeSymbol typeSymbol, Compilation compilation) => typeSymbol switch
+ {
+ INamedTypeSymbol { TypeArguments.Length: 1 } namedTypeSymbol => namedTypeSymbol.TypeArguments[0],
+ _ => compilation.GetSpecialType(SpecialType.System_Void)
+ };
+
///
/// Appends the fully qualified metadata name for a given symbol to a target builder.
///
diff --git a/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs b/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs
index c449cd5..3b0fe81 100644
--- a/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs
+++ b/src/ReactiveUI.SourceGenerators/Core/Helpers/AttributeInfo.cs
@@ -131,4 +131,6 @@ public AttributeSyntax GetSyntax()
return Attribute(IdentifierName(TypeName), AttributeArgumentList(SeparatedList(arguments.Concat(namedArguments))));
}
+
+ public override string ToString() => $"[{GetSyntax()}]";
}
diff --git a/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs b/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs
index d7c7fb5..43fe70f 100644
--- a/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs
+++ b/src/ReactiveUI.SourceGenerators/Reactive/ReactiveGenerator.Execute.cs
@@ -147,9 +147,9 @@ namespace {{containingNamespace}}
{{AddTabs(1)}}///
{{AddTabs(1)}}/// Partial class for the {{containingTypeName}} which contains ReactiveUI Reactive property initialization.
{{AddTabs(1)}}///
-{{AddTabs(1)}}[global::System.CodeDom.Compiler.GeneratedCode("{{GeneratorName}}", "{{GeneratorVersion}}")]
{{AddTabs(1)}}{{containingClassVisibility}} partial {{containingType}} {{containingTypeName}}
{{AddTabs(1)}}{
+{{AddTabs(2)}}[global::System.CodeDom.Compiler.GeneratedCode("{{GeneratorName}}", "{{GeneratorVersion}}")]
{{propertyDeclarations}}
{{AddTabs(1)}}}
}
diff --git a/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs b/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs
index e3da06b..482ed98 100644
--- a/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs
+++ b/src/ReactiveUI.SourceGenerators/ReactiveCommand/Models/CommandInfo.cs
@@ -22,7 +22,7 @@ internal record CommandInfo(
bool IsObservable,
string? CanExecuteObservableName,
CanExecuteTypeInfo? CanExecuteTypeInfo,
- EquatableArray ForwardedPropertyAttributes)
+ EquatableArray ForwardedPropertyAttributes)
{
private const string UnitTypeName = "global::System.Reactive.Unit";
diff --git a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs
index f8daf7d..e1236cf 100644
--- a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs
+++ b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.Execute.cs
@@ -3,6 +3,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for full license information.
+using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
@@ -14,7 +15,7 @@
using ReactiveUI.SourceGenerators.Extensions;
using ReactiveUI.SourceGenerators.Helpers;
using ReactiveUI.SourceGenerators.Input.Models;
-using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using ReactiveUI.SourceGenerators.Models;
namespace ReactiveUI.SourceGenerators;
@@ -24,467 +25,461 @@ namespace ReactiveUI.SourceGenerators;
///
public partial class ReactiveCommandGenerator
{
+ internal static readonly string GeneratorName = typeof(ReactiveCommandGenerator).FullName!;
+ internal static readonly string GeneratorVersion = typeof(ReactiveCommandGenerator).Assembly.GetName().Version.ToString();
+
private const string ReactiveUI = "ReactiveUI";
private const string ReactiveCommand = "ReactiveCommand";
private const string RxCmd = ReactiveUI + "." + ReactiveCommand;
private const string Create = ".Create";
private const string CreateO = ".CreateFromObservable";
private const string CreateT = ".CreateFromTask";
- private const string ObsoleteReason = "Commands are initialized automatically. Method will be removed in future version.";
+ private const string CanExecute = "CanExecute";
+ private static readonly string[] excludeFromCodeCoverage = ["[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]"];
- ///
- /// A container for all the logic for .
- ///
- internal static class Execute
+ private static CommandInfo? GetMethodInfo(in GeneratorAttributeSyntaxContext context, CancellationToken token)
{
- internal static MethodDeclarationSyntax GetCommandInitiliser() => MethodDeclaration(
- PredefinedType(Token(SyntaxKind.VoidKeyword)),
- Identifier("InitializeCommands"))
- .AddAttributeLists(
- AttributeList(SingletonSeparatedList(
- Attribute(IdentifierName(AttributeDefinitions.GeneratedCode))
- .AddArgumentListArguments(
- AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveGenerator).FullName))),
- AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveGenerator).Assembly.GetName().Version.ToString())))))),
- AttributeList(SingletonSeparatedList(Attribute(IdentifierName(AttributeDefinitions.ExcludeFromCodeCoverage)))),
- AttributeList(SingletonSeparatedList(Attribute(IdentifierName(AttributeDefinitions.Obsolete))
- .AddArgumentListArguments(
- AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(ObsoleteReason)))))))
- .WithModifiers(TokenList(Token(SyntaxKind.ProtectedKeyword)))
- .WithBody(Block());
-
- internal static MemberDeclarationSyntax[] GetCommandProperty(CommandInfo commandExtensionInfo)
+ var symbol = context.TargetSymbol;
+ if (!symbol.TryGetAttributeWithFullyQualifiedMetadataName(AttributeDefinitions.ReactiveCommandAttributeType, out var attributeData))
{
- var outputType = commandExtensionInfo.GetOutputTypeText();
- var inputType = commandExtensionInfo.GetInputTypeText();
- var commandName = GetGeneratedCommandName(commandExtensionInfo.MethodName, commandExtensionInfo.IsTask);
- var fieldName = GetGeneratedFieldName(commandName);
+ return null;
+ }
- ExpressionSyntax initializer;
- if (commandExtensionInfo.ArgumentType == null)
- {
- initializer = GenerateBasicCommand(commandExtensionInfo, fieldName);
- }
- else if (commandExtensionInfo.ArgumentType != null && commandExtensionInfo.IsReturnTypeVoid)
- {
- initializer = GenerateInCommand(commandExtensionInfo, fieldName, inputType);
- }
- else if (commandExtensionInfo.ArgumentType != null && !commandExtensionInfo.IsReturnTypeVoid)
- {
- initializer = GenerateInOutCommand(commandExtensionInfo, fieldName, outputType, inputType);
- }
- else
+ if (symbol is not IMethodSymbol methodSymbol)
+ {
+ return default;
+ }
+
+ token.ThrowIfCancellationRequested();
+
+ var isTask = methodSymbol.ReturnType.IsTaskReturnType();
+ var isObservable = methodSymbol.ReturnType.IsObservableReturnType();
+
+ var compilation = context.SemanticModel.Compilation;
+ var realReturnType = isTask || isObservable ? methodSymbol.ReturnType.GetTaskReturnType(compilation) : methodSymbol.ReturnType;
+ var isReturnTypeVoid = SymbolEqualityComparer.Default.Equals(realReturnType, compilation.GetSpecialType(SpecialType.System_Void));
+ var hasCancellationToken = isTask && methodSymbol.Parameters.Any(x => x.Type.ToDisplayString() == "System.Threading.CancellationToken");
+
+ using var methodParameters = ImmutableArrayBuilder.Rent();
+ if (hasCancellationToken && methodSymbol.Parameters.Length == 2)
+ {
+ methodParameters.Add(methodSymbol.Parameters[0]);
+ }
+ else if (!hasCancellationToken)
+ {
+ foreach (var parameter in methodSymbol.Parameters)
{
- return [];
+ methodParameters.Add(parameter);
}
+ }
- // Prepare any forwarded property attributes
- var forwardedPropertyAttributes =
- commandExtensionInfo.ForwardedPropertyAttributes
- .Select(static a => AttributeList(SingletonSeparatedList(a.GetSyntax())))
- .ToImmutableArray();
-
- var qualifiedName = QualifiedName(
- IdentifierName(ReactiveUI),
- GenericName(
- Identifier(ReactiveCommand))
- .WithTypeArgumentList(
- TypeArgumentList(
- SeparatedList(
- new SyntaxNodeOrToken[]
- {
- IdentifierName(inputType),
- Token(SyntaxKind.CommaToken),
- IdentifierName(outputType)
- }))));
-
- var fieldDeclaration = FieldDeclaration(
- VariableDeclaration(NullableType(qualifiedName)))
- .AddDeclarationVariables(VariableDeclarator(fieldName))
- .AddAttributeLists(AttributeList(SingletonSeparatedList(
- Attribute(IdentifierName(AttributeDefinitions.GeneratedCode))
- .AddArgumentListArguments(
- AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveCommandGenerator).FullName))),
- AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ReactiveCommandGenerator).Assembly.GetName().Version.ToString())))))))
- .AddModifiers(
- Token(SyntaxKind.PrivateKeyword))
- .NormalizeWhitespace();
-
- var commandDeclaration = PropertyDeclaration(
- qualifiedName,
- Identifier(commandName))
- .AddModifiers(Token(SyntaxKind.PublicKeyword))
- .AddAccessorListAccessors(
- AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
- .WithExpressionBody(ArrowExpressionClause(initializer)))
- .AddAttributeLists([.. forwardedPropertyAttributes])
- .NormalizeWhitespace();
- return [fieldDeclaration, commandDeclaration];
-
- static ExpressionSyntax GenerateBasicCommand(CommandInfo commandExtensionInfo, string fieldName)
- {
- var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create;
- if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName))
- {
- return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName});");
- }
+ if (methodParameters.Count > 1)
+ {
+ return default; // Too many parameters, continue
+ }
- return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});");
- }
+ token.ThrowIfCancellationRequested();
+
+ TryGetCanExecuteExpressionType(methodSymbol, attributeData, out var canExecuteObservableName, out var canExecuteTypeInfo);
+
+ token.ThrowIfCancellationRequested();
+
+ var methodSyntax = (MethodDeclarationSyntax)context.TargetNode;
+ GatherForwardedAttributes(methodSymbol, context.SemanticModel, methodSyntax, token, out var attributes);
+ var forwardedPropertyAttributes = attributes.Select(static a => a.ToString()).ToImmutableArray();
+ token.ThrowIfCancellationRequested();
+
+ // Get the containing type info
+ var targetInfo = TargetInfo.From(methodSymbol.ContainingType);
+
+ token.ThrowIfCancellationRequested();
+
+ return new CommandInfo(
+ targetInfo.FileHintName,
+ targetInfo.TargetName,
+ targetInfo.TargetNamespace,
+ targetInfo.TargetNamespaceWithNamespace,
+ targetInfo.TargetVisibility,
+ targetInfo.TargetType,
+ symbol.Name,
+ realReturnType.GetFullyQualifiedNameWithNullabilityAnnotations(),
+ methodParameters.ToImmutable().SingleOrDefault()?.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
+ isTask,
+ isReturnTypeVoid,
+ isObservable,
+ canExecuteObservableName,
+ canExecuteTypeInfo,
+ forwardedPropertyAttributes);
+ }
- static ExpressionSyntax GenerateInOutCommand(CommandInfo commandExtensionInfo, string fieldName, string outputType, string inputType)
- {
- var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create;
- if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName))
- {
- return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName});");
- }
+ private static string GenerateSource(string containingTypeName, string containingNamespace, string containingClassVisibility, string containingType, CommandInfo[] commands)
+ {
+ // Generate all member declarations for the current type
+ var propertyDeclarations = string.Join("\n\r", commands.Select(GetCommandSyntax));
+ return
+$$"""
+//
- return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});");
- }
+#pragma warning disable
+#nullable enable
- static ExpressionSyntax GenerateInCommand(CommandInfo commandExtensionInfo, string fieldName, string inputType)
- {
- var commandType = commandExtensionInfo.IsTask ? CreateT : Create;
- if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName))
- {
- return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName});");
- }
+namespace {{containingNamespace}}
+{
+{{AddTabs(1)}}///
+{{AddTabs(1)}}/// Partial class for the {{containingTypeName}} which contains ReactiveUI ReactiveCommand initialization.
+{{AddTabs(1)}}///
+{{AddTabs(1)}}{{containingClassVisibility}} partial {{containingType}} {{containingTypeName}}
+{{AddTabs(1)}}{
+{{AddTabs(2)}}[global::System.CodeDom.Compiler.GeneratedCode("{{GeneratorName}}", "{{GeneratorVersion}}")]
+{{propertyDeclarations}}
+{{AddTabs(1)}}}
+}
+#nullable restore
+#pragma warning restore
+""";
+ }
- return ParseExpression($"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});");
- }
+ private static string GetCommandSyntax(CommandInfo commandExtensionInfo)
+ {
+ var outputType = commandExtensionInfo.GetOutputTypeText();
+ var inputType = commandExtensionInfo.GetInputTypeText();
+ var commandName = GetGeneratedCommandName(commandExtensionInfo.MethodName, commandExtensionInfo.IsTask);
+ var fieldName = GetGeneratedFieldName(commandName);
+
+ string? initializer;
+ if (commandExtensionInfo.ArgumentType == null)
+ {
+ initializer = GenerateBasicCommand(commandExtensionInfo, fieldName);
+ }
+ else if (commandExtensionInfo.ArgumentType != null && commandExtensionInfo.IsReturnTypeVoid)
+ {
+ initializer = GenerateInCommand(commandExtensionInfo, fieldName, inputType);
+ }
+ else if (commandExtensionInfo.ArgumentType != null && !commandExtensionInfo.IsReturnTypeVoid)
+ {
+ initializer = GenerateInOutCommand(commandExtensionInfo, fieldName, outputType, inputType);
+ }
+ else
+ {
+ return string.Empty;
}
- internal static bool IsTaskReturnType(ITypeSymbol? typeSymbol)
+ // Prepare any forwarded property attributes
+ var forwardedPropertyAttributesString = string.Join("\n\t\t", excludeFromCodeCoverage.Concat(commandExtensionInfo.ForwardedPropertyAttributes));
+
+ return
+$$"""
+{{AddTabs(2)}}private ReactiveUI.ReactiveCommand<{{inputType}}, {{outputType}}>? {{fieldName}};
+
+{{AddTabs(2)}}{{forwardedPropertyAttributesString}}
+{{AddTabs(2)}}public ReactiveUI.ReactiveCommand<{{inputType}}, {{outputType}}> {{commandName}} { get => {{initializer}} }
+""";
+
+ static string GenerateBasicCommand(CommandInfo commandExtensionInfo, string fieldName)
{
- var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat;
- do
+ var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create;
+ if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName))
{
- var typeName = typeSymbol?.ToDisplayString(nameFormat);
- if (typeName == "global::System.Threading.Tasks.Task")
- {
- return true;
- }
-
- typeSymbol = typeSymbol?.BaseType;
+ return $"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName});";
}
- while (typeSymbol != null);
- return false;
+ return $"{fieldName} ??= {RxCmd}{commandType}({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});";
}
- internal static bool IsObservableReturnType(ITypeSymbol? typeSymbol)
+ static string GenerateInOutCommand(CommandInfo commandExtensionInfo, string fieldName, string outputType, string inputType)
{
- var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat;
- do
+ var commandType = commandExtensionInfo.IsObservable ? CreateO : commandExtensionInfo.IsTask ? CreateT : Create;
+ if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName))
{
- var typeName = typeSymbol?.ToDisplayString(nameFormat);
- if (typeName?.Contains("global::System.IObservable") == true)
- {
- return true;
- }
-
- typeSymbol = typeSymbol?.BaseType;
+ return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName});";
}
- while (typeSymbol != null);
- return false;
+ return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}, {outputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});";
}
- internal static bool IsObservableBoolType(ITypeSymbol? typeSymbol)
+ static string GenerateInCommand(CommandInfo commandExtensionInfo, string fieldName, string inputType)
{
- var nameFormat = SymbolDisplayFormat.FullyQualifiedFormat;
- do
+ var commandType = commandExtensionInfo.IsTask ? CreateT : Create;
+ if (string.IsNullOrEmpty(commandExtensionInfo.CanExecuteObservableName))
{
- var typeName = typeSymbol?.ToDisplayString(nameFormat);
- if (typeName?.Contains("global::System.IObservable") == true)
- {
- return true;
- }
-
- typeSymbol = typeSymbol?.BaseType;
+ return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName});";
}
- while (typeSymbol != null);
- return false;
+ return $"{fieldName} ??= {RxCmd}{commandType}<{inputType}>({commandExtensionInfo.MethodName}, {commandExtensionInfo.CanExecuteObservableName}{(commandExtensionInfo.CanExecuteTypeInfo == CanExecuteTypeInfo.MethodObservable ? "()" : string.Empty)});";
}
+ }
- internal static ITypeSymbol GetTaskReturnType(Compilation compilation, ITypeSymbol typeSymbol) => typeSymbol switch
- {
- INamedTypeSymbol { TypeArguments.Length: 1 } namedTypeSymbol => namedTypeSymbol.TypeArguments[0],
- _ => compilation.GetSpecialType(SpecialType.System_Void)
- };
-
- ///
- /// Tries to get the expression type for the "CanExecute" property, if available.
- ///
- /// The input instance to process.
- /// The instance for .
- /// The resulting can execute member name, if available.
- /// The resulting expression type, if available.
- internal static void TryGetCanExecuteExpressionType(
- IMethodSymbol methodSymbol,
- AttributeData attributeData,
- out string? canExecuteMemberName,
- out CanExecuteTypeInfo? canExecuteTypeInfo)
+ ///
+ /// Tries to get the expression type for the "CanExecute" property, if available.
+ ///
+ /// The input instance to process.
+ /// The instance for .
+ /// The resulting can execute member name, if available.
+ /// The resulting expression type, if available.
+ private static void TryGetCanExecuteExpressionType(
+ IMethodSymbol methodSymbol,
+ AttributeData attributeData,
+ out string? canExecuteMemberName,
+ out CanExecuteTypeInfo? canExecuteTypeInfo)
+ {
+ // Get the can execute member, if any
+ if (!attributeData.TryGetNamedArgument(CanExecute, out string? memberName))
{
- // Get the can execute member, if any
- if (!attributeData.TryGetNamedArgument("CanExecute", out string? memberName))
- {
- canExecuteMemberName = null;
- canExecuteTypeInfo = null;
-
- return;
- }
+ canExecuteMemberName = null;
+ canExecuteTypeInfo = null;
- if (memberName is null)
- {
- goto Failure;
- }
+ return;
+ }
- var canExecuteSymbols = methodSymbol.ContainingType!.GetAllMembers(memberName).ToImmutableArray();
+ if (memberName is null)
+ {
+ goto Failure;
+ }
- if (canExecuteSymbols.IsEmpty)
- {
- // Special case for when the target member is a generated property from [ObservableProperty]
- if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, out canExecuteTypeInfo))
- {
- canExecuteMemberName = memberName;
+ var canExecuteSymbols = methodSymbol.ContainingType!.GetAllMembers(memberName).ToImmutableArray();
- return;
- }
- }
- else if (canExecuteSymbols.Length > 1)
- {
- goto Failure;
- }
- else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], out canExecuteTypeInfo))
+ if (canExecuteSymbols.IsEmpty)
+ {
+ // Special case for when the target member is a generated property from [ObservableProperty]
+ if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, out canExecuteTypeInfo))
{
canExecuteMemberName = memberName;
return;
}
-
- Failure:
- canExecuteMemberName = null;
- canExecuteTypeInfo = null;
+ }
+ else if (canExecuteSymbols.Length > 1)
+ {
+ goto Failure;
+ }
+ else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], out canExecuteTypeInfo))
+ {
+ canExecuteMemberName = memberName;
return;
}
- ///
- /// Gets the expression type for the can execute logic, if possible.
- ///
- /// The can execute member symbol (either a method or a property).
- /// The resulting can execute expression type, if available.
- /// Whether or not was set and the input symbol was valid.
- internal static bool TryGetCanExecuteExpressionFromSymbol(
- ISymbol canExecuteSymbol,
- [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo)
- {
- if (canExecuteSymbol is IMethodSymbol canExecuteMethodSymbol)
- {
- // The return type must always be a bool
- if (!IsObservableBoolType(canExecuteMethodSymbol.ReturnType))
- {
- goto Failure;
- }
-
- // If the method has parameters, it has to have a single one matching the command type
- if (canExecuteMethodSymbol.Parameters.Length == 1)
- {
- goto Failure;
- }
+ Failure:
+ canExecuteMemberName = null;
+ canExecuteTypeInfo = null;
- // Parameterless methods are always valid
- if (canExecuteMethodSymbol.Parameters.IsEmpty)
- {
- canExecuteTypeInfo = CanExecuteTypeInfo.MethodObservable;
+ return;
+ }
- return true;
- }
+ ///
+ /// Gets the expression type for the can execute logic, if possible.
+ ///
+ /// The can execute member symbol (either a method or a property).
+ /// The resulting can execute expression type, if available.
+ /// Whether or not was set and the input symbol was valid.
+ private static bool TryGetCanExecuteExpressionFromSymbol(
+ ISymbol canExecuteSymbol,
+ [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo)
+ {
+ if (canExecuteSymbol is IMethodSymbol canExecuteMethodSymbol)
+ {
+ // The return type must always be a bool
+ if (!canExecuteMethodSymbol.ReturnType.IsObservableBoolType())
+ {
+ goto Failure;
}
- else if (canExecuteSymbol is IPropertySymbol { GetMethod: not null } canExecutePropertySymbol)
+
+ // If the method has parameters, it has to have a single one matching the command type
+ if (canExecuteMethodSymbol.Parameters.Length == 1)
{
- // The property type must always be a bool
- if (!IsObservableBoolType(canExecutePropertySymbol.Type))
- {
- goto Failure;
- }
+ goto Failure;
+ }
- canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable;
+ // Parameterless methods are always valid
+ if (canExecuteMethodSymbol.Parameters.IsEmpty)
+ {
+ canExecuteTypeInfo = CanExecuteTypeInfo.MethodObservable;
return true;
}
- else if (canExecuteSymbol is IFieldSymbol canExecuteFieldSymbol)
+ }
+ else if (canExecuteSymbol is IPropertySymbol { GetMethod: not null } canExecutePropertySymbol)
+ {
+ // The property type must always be a bool
+ if (!canExecutePropertySymbol.Type.IsObservableBoolType())
{
- // The property type must always be a bool
- if (!IsObservableBoolType(canExecuteFieldSymbol.Type))
- {
- goto Failure;
- }
+ goto Failure;
+ }
- canExecuteTypeInfo = CanExecuteTypeInfo.FieldObservable;
+ canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable;
- return true;
+ return true;
+ }
+ else if (canExecuteSymbol is IFieldSymbol canExecuteFieldSymbol)
+ {
+ // The property type must always be a bool
+ if (!canExecuteFieldSymbol.Type.IsObservableBoolType())
+ {
+ goto Failure;
}
- Failure:
- canExecuteTypeInfo = null;
+ canExecuteTypeInfo = CanExecuteTypeInfo.FieldObservable;
- return false;
+ return true;
}
- ///
- /// Gets the expression type for the can execute logic, if possible.
- ///
- /// The member name passed to [ReactiveCommand(CanExecute = ...)].
- /// The containing type for the method annotated with [ReactiveCommand].
- /// The resulting can execute expression type, if available.
- /// Whether or not was set and the input symbol was valid.
- internal static bool TryGetCanExecuteMemberFromGeneratedProperty(
- string memberName,
- INamedTypeSymbol containingType,
- [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo)
+ Failure:
+ canExecuteTypeInfo = null;
+
+ return false;
+ }
+
+ ///
+ /// Gets the expression type for the can execute logic, if possible.
+ ///
+ /// The member name passed to [ReactiveCommand(CanExecute = ...)].
+ /// The containing type for the method annotated with [ReactiveCommand].
+ /// The resulting can execute expression type, if available.
+ /// Whether or not was set and the input symbol was valid.
+ private static bool TryGetCanExecuteMemberFromGeneratedProperty(
+ string memberName,
+ INamedTypeSymbol containingType,
+ [NotNullWhen(true)] out CanExecuteTypeInfo? canExecuteTypeInfo)
+ {
+ foreach (var memberSymbol in containingType.GetAllMembers())
{
- foreach (var memberSymbol in containingType.GetAllMembers())
+ // Only look for instance fields of Observable bool type
+ if (!memberSymbol.ContainingType.IsObservableBoolType() || memberSymbol is not IFieldSymbol fieldSymbol)
{
- // Only look for instance fields of Observable bool type
- if (!IsObservableBoolType(memberSymbol.ContainingType) || memberSymbol is not IFieldSymbol fieldSymbol)
- {
- continue;
- }
+ continue;
+ }
- var attributes = memberSymbol.GetAttributes();
+ var attributes = memberSymbol.GetAttributes();
- // Only filter fields with the [Reactive] attribute
- if (memberSymbol is IFieldSymbol &&
- !attributes.Any(static a => a.AttributeClass?.HasFullyQualifiedMetadataName(
- "ReactiveUI.SourceGenerators.ReactiveAttribute") == true))
- {
- continue;
- }
+ // Only filter fields with the [Reactive] attribute
+ if (memberSymbol is IFieldSymbol &&
+ !attributes.Any(static a => a.AttributeClass?.HasFullyQualifiedMetadataName(
+ AttributeDefinitions.ReactiveAttributeType) == true))
+ {
+ continue;
+ }
- // Get the target property name either directly or matching the generated one
- var propertyName = fieldSymbol.GetGeneratedPropertyName();
+ // Get the target property name either directly or matching the generated one
+ var propertyName = fieldSymbol.GetGeneratedPropertyName();
- // If the generated property name matches, get the right expression type
- if (memberName == propertyName)
- {
- canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable;
+ // If the generated property name matches, get the right expression type
+ if (memberName == propertyName)
+ {
+ canExecuteTypeInfo = CanExecuteTypeInfo.PropertyObservable;
- return true;
- }
+ return true;
}
+ }
- canExecuteTypeInfo = null;
+ canExecuteTypeInfo = null;
- return false;
- }
+ return false;
+ }
+
+ ///
+ /// Gathers all forwarded attributes for the generated field and property.
+ ///
+ /// The input instance to process.
+ /// The instance for the current run.
+ /// The method declaration.
+ /// The cancellation token for the current operation.
+ /// The resulting property attributes to forward.
+ private static void GatherForwardedAttributes(
+ IMethodSymbol methodSymbol,
+ SemanticModel semanticModel,
+ MethodDeclarationSyntax methodDeclaration,
+ CancellationToken token,
+ out ImmutableArray propertyAttributes)
+ {
+ using var propertyAttributesInfo = ImmutableArrayBuilder.Rent();
- ///
- /// Gathers all forwarded attributes for the generated field and property.
- ///
- /// The input instance to process.
- /// The instance for the current run.
- /// The method declaration.
- /// The cancellation token for the current operation.
- /// The resulting property attributes to forward.
- internal static void GatherForwardedAttributes(
+ static void GatherForwardedAttributes(
IMethodSymbol methodSymbol,
SemanticModel semanticModel,
MethodDeclarationSyntax methodDeclaration,
CancellationToken token,
- out ImmutableArray propertyAttributes)
+ ImmutableArrayBuilder propertyAttributesInfo)
{
- using var propertyAttributesInfo = ImmutableArrayBuilder.Rent();
-
- static void GatherForwardedAttributes(
- IMethodSymbol methodSymbol,
- SemanticModel semanticModel,
- MethodDeclarationSyntax methodDeclaration,
- CancellationToken token,
- ImmutableArrayBuilder propertyAttributesInfo)
+ // Get the single syntax reference for the input method symbol (there should be only one)
+ if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference])
+ {
+ return;
+ }
+
+ // Gather explicit forwarded attributes info
+ foreach (var attributeList in methodDeclaration.AttributeLists)
{
- // Get the single syntax reference for the input method symbol (there should be only one)
- if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference])
+ if (attributeList.Target?.Identifier is not SyntaxToken(SyntaxKind.PropertyKeyword))
{
- return;
+ continue;
}
- // Gather explicit forwarded attributes info
- foreach (var attributeList in methodDeclaration.AttributeLists)
+ foreach (var attribute in attributeList.Attributes)
{
- if (attributeList.Target?.Identifier is not SyntaxToken(SyntaxKind.PropertyKeyword))
+ if (!semanticModel.GetSymbolInfo(attribute, token).TryGetAttributeTypeSymbol(out var attributeTypeSymbol))
{
continue;
}
- foreach (var attribute in attributeList.Attributes)
+ var attributeArguments = attribute.ArgumentList?.Arguments ?? Enumerable.Empty();
+
+ // Try to extract the forwarded attribute
+ if (!AttributeInfo.TryCreate(attributeTypeSymbol, semanticModel, attributeArguments, token, out var attributeInfo))
+ {
+ continue;
+ }
+
+ // Add the new attribute info to the right builder
+ if (attributeList.Target?.Identifier is SyntaxToken(SyntaxKind.PropertyKeyword))
{
- if (!semanticModel.GetSymbolInfo(attribute, token).TryGetAttributeTypeSymbol(out var attributeTypeSymbol))
- {
- continue;
- }
-
- var attributeArguments = attribute.ArgumentList?.Arguments ?? Enumerable.Empty();
-
- // Try to extract the forwarded attribute
- if (!AttributeInfo.TryCreate(attributeTypeSymbol, semanticModel, attributeArguments, token, out var attributeInfo))
- {
- continue;
- }
-
- // Add the new attribute info to the right builder
- if (attributeList.Target?.Identifier is SyntaxToken(SyntaxKind.PropertyKeyword))
- {
- propertyAttributesInfo.Add(attributeInfo);
- }
+ propertyAttributesInfo.Add(attributeInfo);
}
}
}
+ }
- // If the method is a partial definition, also gather attributes from the implementation part
- if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null })
- {
- var partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol;
- var partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol;
-
- // We always give priority to the partial definition, to ensure a predictable and testable ordering
- GatherForwardedAttributes(partialDefinition, semanticModel, methodDeclaration, token, propertyAttributesInfo);
- GatherForwardedAttributes(partialImplementation, semanticModel, methodDeclaration, token, propertyAttributesInfo);
- }
- else
- {
- // If the method is not a partial definition/implementation, just gather attributes from the method with no modifications
- GatherForwardedAttributes(methodSymbol, semanticModel, methodDeclaration, token, propertyAttributesInfo);
- }
+ // If the method is a partial definition, also gather attributes from the implementation part
+ if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null })
+ {
+ var partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol;
+ var partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol;
- propertyAttributes = propertyAttributesInfo.ToImmutable();
+ // We always give priority to the partial definition, to ensure a predictable and testable ordering
+ GatherForwardedAttributes(partialDefinition, semanticModel, methodDeclaration, token, propertyAttributesInfo);
+ GatherForwardedAttributes(partialImplementation, semanticModel, methodDeclaration, token, propertyAttributesInfo);
}
-
- internal static string GetGeneratedCommandName(string methodName, bool isAsync)
+ else
{
- var commandName = methodName;
+ // If the method is not a partial definition/implementation, just gather attributes from the method with no modifications
+ GatherForwardedAttributes(methodSymbol, semanticModel, methodDeclaration, token, propertyAttributesInfo);
+ }
- if (commandName.StartsWith("m_"))
- {
- commandName = commandName.Substring(2);
- }
- else if (commandName.StartsWith("_"))
- {
- commandName = commandName.TrimStart('_');
- }
+ propertyAttributes = propertyAttributesInfo.ToImmutable();
+ }
- if (commandName.EndsWith("Async") && isAsync)
- {
- commandName = commandName.Substring(0, commandName.Length - "Async".Length);
- }
+ private static string GetGeneratedCommandName(string methodName, bool isAsync)
+ {
+ var commandName = methodName;
+
+ if (commandName.StartsWith("m_"))
+ {
+ commandName = commandName.Substring(2);
+ }
+ else if (commandName.StartsWith("_"))
+ {
+ commandName = commandName.TrimStart('_');
+ }
- return $"{char.ToUpper(commandName[0], CultureInfo.InvariantCulture)}{commandName.Substring(1)}Command";
+ if (commandName.EndsWith("Async") && isAsync)
+ {
+ commandName = commandName.Substring(0, commandName.Length - "Async".Length);
}
- internal static string GetGeneratedFieldName(string generatedCommandName) =>
- $"_{char.ToLower(generatedCommandName[0], CultureInfo.InvariantCulture)}{generatedCommandName.Substring(1)}";
+ return $"{char.ToUpper(commandName[0], CultureInfo.InvariantCulture)}{commandName.Substring(1)}Command";
}
+
+ private static string GetGeneratedFieldName(string generatedCommandName) =>
+ $"_{char.ToLower(generatedCommandName[0], CultureInfo.InvariantCulture)}{generatedCommandName.Substring(1)}";
+
+ private static string AddTabs(int tabCount) => new('\t', tabCount);
}
diff --git a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs
index 10bcb24..47ca75e 100644
--- a/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs
+++ b/src/ReactiveUI.SourceGenerators/ReactiveCommand/ReactiveCommandGenerator.cs
@@ -3,20 +3,14 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for full license information.
-using System;
-using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using ReactiveUI.SourceGenerators.Extensions;
using ReactiveUI.SourceGenerators.Helpers;
-using ReactiveUI.SourceGenerators.Input.Models;
-using ReactiveUI.SourceGenerators.Models;
-using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
namespace ReactiveUI.SourceGenerators;
@@ -33,137 +27,44 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
ctx.AddSource($"{AttributeDefinitions.ReactiveCommandAttributeType}.g.cs", SourceText.From(AttributeDefinitions.ReactiveCommandAttribute, Encoding.UTF8)));
// Gather info for all annotated command methods (starting from method declarations with at least one attribute)
- IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> commandInfoWithErrors =
+ var commandInfo =
context.SyntaxProvider
.ForAttributeWithMetadataName(
AttributeDefinitions.ReactiveCommandAttributeType,
static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax or RecordDeclarationSyntax, AttributeLists.Count: > 0 },
- static (context, token) =>
- {
- CommandInfo? commandExtensionInfos = default;
- HierarchyInfo? hierarchy = default;
- using var diagnostics = ImmutableArrayBuilder.Rent();
-
- var methodSyntax = (MethodDeclarationSyntax)context.TargetNode;
- var symbol = ModelExtensions.GetDeclaredSymbol(context.SemanticModel, methodSyntax, token)!;
- token.ThrowIfCancellationRequested();
-
- // Skip symbols without the target attribute
- if (!symbol.TryGetAttributeWithFullyQualifiedMetadataName(AttributeDefinitions.ReactiveCommandAttributeType, out var attributeData))
- {
- return default;
- }
-
- token.ThrowIfCancellationRequested();
- if (attributeData != null)
- {
- var compilation = context.SemanticModel.Compilation;
- var methodSymbol = (IMethodSymbol)symbol!;
- var isTask = Execute.IsTaskReturnType(methodSymbol.ReturnType);
- var isObservable = Execute.IsObservableReturnType(methodSymbol.ReturnType);
- var realReturnType = isTask || isObservable ? Execute.GetTaskReturnType(compilation, methodSymbol.ReturnType) : methodSymbol.ReturnType;
- var isReturnTypeVoid = SymbolEqualityComparer.Default.Equals(realReturnType, compilation.GetSpecialType(SpecialType.System_Void));
- var hasCancellationToken = isTask && methodSymbol.Parameters.Any(x => x.Type.ToDisplayString() == "System.Threading.CancellationToken");
- var methodParameters = new List();
- if (hasCancellationToken && methodSymbol.Parameters.Length == 2)
- {
- methodParameters.Add(methodSymbol.Parameters[0]);
- }
- else if (!hasCancellationToken)
- {
- methodParameters.AddRange(methodSymbol.Parameters);
- }
-
- if (methodParameters.Count > 1)
- {
- return default; // Too many parameters, continue
- }
-
- token.ThrowIfCancellationRequested();
-
- // Get the hierarchy info for the target symbol, and try to gather the command info
- hierarchy = HierarchyInfo.From(methodSymbol.ContainingType);
-
- // Get the CanExecute expression type, if any
- Execute.TryGetCanExecuteExpressionType(
- methodSymbol,
- attributeData,
- out var canExecuteMemberName,
- out var canExecuteTypeInfo);
-
- token.ThrowIfCancellationRequested();
-
- Execute.GatherForwardedAttributes(
- methodSymbol,
- context.SemanticModel,
- methodSyntax,
- token,
- out var forwardedAttributes);
+ static (context, token) => GetMethodInfo(context, token))
+ .Where(x => x != null)
+ .Select((x, _) => x!)
+ .Collect();
- token.ThrowIfCancellationRequested();
-
- // Get the containing type info
- var targetInfo = TargetInfo.From(methodSymbol.ContainingType);
-
- token.ThrowIfCancellationRequested();
-
- commandExtensionInfos = new(
- targetInfo.FileHintName,
- targetInfo.TargetName,
- targetInfo.TargetNamespace,
- targetInfo.TargetNamespaceWithNamespace,
- targetInfo.TargetVisibility,
- targetInfo.TargetType,
- methodSymbol.Name,
- realReturnType.GetFullyQualifiedNameWithNullabilityAnnotations(),
- methodParameters?.SingleOrDefault()?.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
- isTask,
- isReturnTypeVoid,
- isObservable,
- canExecuteMemberName,
- canExecuteTypeInfo,
- forwardedAttributes);
- }
-
- token.ThrowIfCancellationRequested();
- return (Hierarchy: hierarchy, new Result(commandExtensionInfos, diagnostics.ToImmutable()));
- })
- .Where(static item => item.Hierarchy is not null)!;
+ // Generate the requested properties and methods
+ context.RegisterSourceOutput(commandInfo, static (context, input) =>
+ {
+ var groupedcommandInfo = input.GroupBy(
+ static info => (info.FileHintName, info.TargetName, info.TargetNamespace, info.TargetVisibility, info.TargetType),
+ static info => info)
+ .ToImmutableArray();
- // Get the filtered sequence to enable caching
- var propertyInfo =
- commandInfoWithErrors
- .Where(static item => item.Info.Value is not null)!;
+ if (groupedcommandInfo.Length == 0)
+ {
+ return;
+ }
- // Split and group by containing type
- var groupedPropertyInfo =
- propertyInfo
- .GroupBy(static item => item.Left, static item => item.Right.Value);
+ foreach (var grouping in groupedcommandInfo)
+ {
+ var items = grouping.ToImmutableArray();
- // Generate the requested properties and methods
- context.RegisterSourceOutput(groupedPropertyInfo, static (context, item) =>
- {
- var commandInfos = item.Right.ToArray();
+ if (items.Length == 0)
+ {
+ continue;
+ }
- // Generate all member declarations for the current type
- var propertyDeclarations =
- commandInfos
- .SelectMany(Execute.GetCommandProperty)
- .ToList();
+ var (fileHintName, targetName, targetNamespace, targetVisibility, targetType) = grouping.Key;
- var c = Execute.GetCommandInitiliser();
- propertyDeclarations.Add(c);
- var memberDeclarations = propertyDeclarations.ToImmutableArray();
+ var source = GenerateSource(targetName, targetNamespace, targetVisibility, targetType, [.. grouping]);
- // Insert all members into the same partial type declaration
- var compilationUnit = item.Key.GetCompilationUnit(memberDeclarations)
- .WithLeadingTrivia(TriviaList(
- Comment("// "),
- Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)),
- Trivia(NullableDirectiveTrivia(Token(SyntaxKind.EnableKeyword), true)),
- CarriageReturn))
- .NormalizeWhitespace();
- context.AddSource($"{item.Key.FilenameHint}.ReactiveCommands.g.cs", compilationUnit);
+ context.AddSource($"{fileHintName}.ReactiveCommands.g.cs", source);
+ }
});
}
}