diff --git a/src/Analysis/Ast/Impl/Analyzer/AnalysisWalker.cs b/src/Analysis/Ast/Impl/Analyzer/AnalysisWalker.cs index 336094021..d7f1d810a 100644 --- a/src/Analysis/Ast/Impl/Analyzer/AnalysisWalker.cs +++ b/src/Analysis/Ast/Impl/Analyzer/AnalysisWalker.cs @@ -74,8 +74,8 @@ public override async Task WalkAsync(ForStatement node, CancellationToken return await base.WalkAsync(node, cancellationToken); } - public override Task WalkAsync(FromImportStatement node, CancellationToken cancellationToken = default) - => ImportHandler.HandleFromImportAsync(node, cancellationToken); + public override async Task WalkAsync(FromImportStatement node, CancellationToken cancellationToken = default) + => ImportHandler.HandleFromImport(node, cancellationToken); public override Task WalkAsync(GlobalStatement node, CancellationToken cancellationToken = default) => NonLocalHandler.HandleGlobalAsync(node, cancellationToken); @@ -83,8 +83,8 @@ public override Task WalkAsync(GlobalStatement node, CancellationToken can public override Task WalkAsync(IfStatement node, CancellationToken cancellationToken = default) => ConditionalHandler.HandleIfAsync(node, cancellationToken); - public override Task WalkAsync(ImportStatement node, CancellationToken cancellationToken = default) - => ImportHandler.HandleImportAsync(node, cancellationToken); + public override async Task WalkAsync(ImportStatement node, CancellationToken cancellationToken = default) + => ImportHandler.HandleImport(node, cancellationToken); public override Task WalkAsync(NonlocalStatement node, CancellationToken cancellationToken = default) => NonLocalHandler.HandleNonLocalAsync(node, cancellationToken); diff --git a/src/Analysis/Ast/Impl/Analyzer/Definitions/IAnalyzable.cs b/src/Analysis/Ast/Impl/Analyzer/Definitions/IAnalyzable.cs index 2d11db886..1c0c53c0c 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Definitions/IAnalyzable.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Definitions/IAnalyzable.cs @@ -20,39 +20,10 @@ namespace Microsoft.Python.Analysis.Analyzer { /// Represents document that can be analyzed asynchronously. /// internal interface IAnalyzable { - /// - /// Expected version of the analysis when asynchronous operations complete. - /// Typically every change to the document or documents that depend on it - /// increment the expected version. At the end of the analysis if the expected - /// version is still the same, the analysis is applied to the document and - /// becomes available to consumers. - /// - int ExpectedAnalysisVersion { get; } - - /// - /// Notifies document that analysis is now pending. Typically document increments - /// the expected analysis version. The method can be called repeatedly without - /// calling `CompleteAnalysis` first. The method is invoked for every dependency - /// in the chain to ensure that objects know that their dependencies have been - /// modified and the current analysis is no longer up to date. - /// - void NotifyAnalysisPending(); - /// /// Notifies document that its analysis is now complete. /// /// Document analysis - /// True if analysis was accepted, false if is is out of date. - bool NotifyAnalysisComplete(IDocumentAnalysis analysis); - - /// - /// Notifies module that analysis has been canceled. - /// - void NotifyAnalysisCanceled(); - - /// - /// Notifies module that analysis has thrown an exception. - /// - void NotifyAnalysisFailed(Exception ex); + void NotifyAnalysisComplete(IDocumentAnalysis analysis); } } diff --git a/src/Analysis/Ast/Impl/Analyzer/Definitions/IPythonAnalyzer.cs b/src/Analysis/Ast/Impl/Analyzer/Definitions/IPythonAnalyzer.cs index cbec51b2f..4a2fb65f5 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Definitions/IPythonAnalyzer.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Definitions/IPythonAnalyzer.cs @@ -16,17 +16,21 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Analysis.Documents; +using Microsoft.Python.Analysis.Types; +using Microsoft.Python.Parsing.Ast; namespace Microsoft.Python.Analysis.Analyzer { public interface IPythonAnalyzer { + Task WaitForCompleteAnalysisAsync(CancellationToken cancellationToken = default); + /// - /// Analyze single document. + /// Schedules module for re-analysis /// - Task AnalyzeDocumentAsync(IDocument document, CancellationToken cancellationToken); + void EnqueueDocumentForAnalysis(IPythonModule module, PythonAst ast, int version, CancellationToken cancellationToken = default); /// - /// Analyze document with dependents. + /// /// - Task AnalyzeDocumentDependencyChainAsync(IDocument document, CancellationToken cancellationToken); + Task GetAnalysisAsync(IPythonModule module, int waitTime = 200, CancellationToken cancellationToken = default); } } diff --git a/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Callables.cs b/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Callables.cs index 2f45ad985..e11fb83d8 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Callables.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Callables.cs @@ -160,9 +160,9 @@ public async Task GetValueFromFunctionTypeAsync(IPythonFunctionType fn, fn.IsStub || !string.IsNullOrEmpty(fn.Overloads[args.OverloadIndex].GetReturnDocumentation(null))) { if (fn.IsSpecialized && fn is PythonFunctionType ft) { - foreach (var module in ft.Dependencies) { + foreach (var moduleName in ft.Dependencies) { cancellationToken.ThrowIfCancellationRequested(); - await Interpreter.ModuleResolution.ImportModuleAsync(module, cancellationToken); + Interpreter.ModuleResolution.GetOrLoadModule(moduleName); } } diff --git a/src/Analysis/Ast/Impl/Analyzer/Expressions/ExpressionFinder.cs b/src/Analysis/Ast/Impl/Analyzer/Expressions/ExpressionFinder.cs index 3f4ac60e3..41df70c75 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Expressions/ExpressionFinder.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Expressions/ExpressionFinder.cs @@ -270,12 +270,12 @@ public override bool Walk(ImportStatement node) { SaveStmt(node, true); if (_options.ImportNames) { - foreach (var n in node.Names.MaybeEnumerate()) { + foreach (var n in node.Names) { n?.Walk(this); } } if (_options.ImportAsNames) { - foreach (var n in node.AsNames.MaybeEnumerate()) { + foreach (var n in node.AsNames) { n?.Walk(this); } } @@ -291,13 +291,14 @@ public override bool Walk(FromImportStatement node) { SaveStmt(node, true); if (_options.ImportNames) { - node.Root?.Walk(this); + node.Root.Walk(this); } - foreach (var n in node.Names.MaybeEnumerate()) { + foreach (var n in node.Names) { n?.Walk(this); } - foreach (var n in node.AsNames.MaybeEnumerate()) { + + foreach (var n in node.AsNames) { n?.Walk(this); } diff --git a/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs b/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs index e0a524199..97f503807 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Handlers/FromImportHandler.cs @@ -16,7 +16,6 @@ using System.Diagnostics; using System.Linq; using System.Threading; -using System.Threading.Tasks; using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Modules; using Microsoft.Python.Analysis.Types; @@ -26,9 +25,9 @@ namespace Microsoft.Python.Analysis.Analyzer.Handlers { internal sealed partial class ImportHandler { - public async Task HandleFromImportAsync(FromImportStatement node, CancellationToken cancellationToken = default) { + public bool HandleFromImport(FromImportStatement node, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - if (node.Root == null || node.Names == null || Module.ModuleType == ModuleType.Specialized) { + if (Module.ModuleType == ModuleType.Specialized) { return false; } @@ -46,13 +45,13 @@ public async Task HandleFromImportAsync(FromImportStatement node, Cancella await ImportMembersFromSelfAsync(node, cancellationToken); break; case ModuleImport moduleImport: - await ImportMembersFromModuleAsync(node, moduleImport.FullName, cancellationToken); + ImportMembersFromModule(node, moduleImport.FullName, cancellationToken); break; case PossibleModuleImport possibleModuleImport: - await HandlePossibleImportAsync(possibleModuleImport, possibleModuleImport.PossibleModuleFullName, Eval.GetLoc(node.Root), cancellationToken); + HandlePossibleImport(possibleModuleImport, possibleModuleImport.PossibleModuleFullName, Eval.GetLoc(node.Root)); break; case PackageImport packageImports: - await ImportMembersFromPackageAsync(node, packageImports, cancellationToken); + ImportMembersFromPackage(node, packageImports); break; case ImportNotFound notFound: MakeUnresolvedImport(null, notFound.FullName, Eval.GetLoc(node.Root)); @@ -85,17 +84,17 @@ private async Task ImportMembersFromSelfAsync(FromImportStatement node, Cancella // Consider 'from . import path as path' in os.pyi in typeshed. var import = ModuleResolution.CurrentPathResolver.GetModuleImportFromModuleName($"{Module.Name}.{importName}"); if (!string.IsNullOrEmpty(import?.FullName)) { - member = await ModuleResolution.ImportModuleAsync(import.FullName, cancellationToken); + member = ModuleResolution.GetOrLoadModule(import.FullName); } } Eval.DeclareVariable(memberName, member ?? Eval.UnknownType, VariableSource.Declaration, Eval.GetLoc(names[i])); } } - private async Task ImportMembersFromModuleAsync(FromImportStatement node, string moduleName, CancellationToken cancellationToken = default) { + private void ImportMembersFromModule(FromImportStatement node, string moduleName, CancellationToken cancellationToken = default) { var names = node.Names; var asNames = node.AsNames; - var module = await ModuleResolution.ImportModuleAsync(moduleName, cancellationToken); + var module = ModuleResolution.GetOrLoadModule(moduleName); if (module == null) { return; } @@ -104,7 +103,7 @@ private async Task ImportMembersFromModuleAsync(FromImportStatement node, string // TODO: warn this is not a good style per // TODO: https://docs.python.org/3/faq/programming.html#what-are-the-best-practices-for-using-import-in-a-module // TODO: warn this is invalid if not in the global scope. - await HandleModuleImportStarAsync(module, cancellationToken); + HandleModuleImportStar(module, cancellationToken); return; } @@ -122,7 +121,7 @@ private async Task ImportMembersFromModuleAsync(FromImportStatement node, string } } - private async Task HandleModuleImportStarAsync(IPythonModule module, CancellationToken cancellationToken = default) { + private void HandleModuleImportStar(IPythonModule module, CancellationToken cancellationToken = default) { foreach (var memberName in module.GetMemberNames()) { cancellationToken.ThrowIfCancellationRequested(); @@ -135,14 +134,14 @@ private async Task HandleModuleImportStarAsync(IPythonModule module, Cancellatio member = member ?? Eval.UnknownType; if (member is IPythonModule m) { - await ModuleResolution.ImportModuleAsync(m.Name, cancellationToken); + ModuleResolution.GetOrLoadModule(m.Name); } Eval.DeclareVariable(memberName, member, VariableSource.Import, module.Location); } } - private async Task ImportMembersFromPackageAsync(FromImportStatement node, PackageImport packageImport, CancellationToken cancellationToken = default) { + private void ImportMembersFromPackage(FromImportStatement node, PackageImport packageImport) { var names = node.Names; var asNames = node.AsNames; @@ -158,13 +157,10 @@ private async Task ImportMembersFromPackageAsync(FromImportStatement node, Packa var memberName = memberReference.Name; var location = Eval.GetLoc(memberReference); - ModuleImport moduleImport; - IPythonType member; - if ((moduleImport = packageImport.Modules.FirstOrDefault(mi => mi.Name.EqualsOrdinal(importName))) != null) { - member = await ModuleResolution.ImportModuleAsync(moduleImport.FullName, cancellationToken); - } else { - member = Eval.UnknownType; - } + var moduleImport = packageImport.Modules.FirstOrDefault(mi => mi.Name.EqualsOrdinal(importName)); + var member = moduleImport != null + ? ModuleResolution.GetOrLoadModule(moduleImport.FullName) + : Eval.UnknownType; Eval.DeclareVariable(memberName, member, VariableSource.Import, location); } diff --git a/src/Analysis/Ast/Impl/Analyzer/Handlers/ImportHandler.cs b/src/Analysis/Ast/Impl/Analyzer/Handlers/ImportHandler.cs index c42040dbe..a3b369aa1 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Handlers/ImportHandler.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Handlers/ImportHandler.cs @@ -16,7 +16,6 @@ using System; using System.Linq; using System.Threading; -using System.Threading.Tasks; using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Diagnostics; using Microsoft.Python.Analysis.Modules; @@ -31,9 +30,9 @@ namespace Microsoft.Python.Analysis.Analyzer.Handlers { internal sealed partial class ImportHandler : StatementHandler { public ImportHandler(AnalysisWalker walker) : base(walker) { } - public async Task HandleImportAsync(ImportStatement node, CancellationToken cancellationToken = default) { + public bool HandleImport(ImportStatement node, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - if (node.Names == null || Module.ModuleType == ModuleType.Specialized) { + if (Module.ModuleType == ModuleType.Specialized) { return false; } @@ -58,10 +57,10 @@ public async Task HandleImportAsync(ImportStatement node, CancellationToke Eval.DeclareVariable(memberName, Module, VariableSource.Declaration, location); break; case ModuleImport moduleImport: - module = await HandleImportAsync(moduleImport, location, cancellationToken); + module = HandleImport(moduleImport, location); break; case PossibleModuleImport possibleModuleImport: - module = await HandlePossibleImportAsync(possibleModuleImport, possibleModuleImport.PossibleModuleFullName, location, cancellationToken); + module = HandlePossibleImport(possibleModuleImport, possibleModuleImport.PossibleModuleFullName, location); break; default: // TODO: Package import? @@ -73,22 +72,23 @@ public async Task HandleImportAsync(ImportStatement node, CancellationToke AssignImportedVariables(module, moduleImportExpression, asNameExpression); } } + return false; } - private async Task HandleImportAsync(ModuleImport moduleImport, LocationInfo location, CancellationToken cancellationToken) { - var module = await ModuleResolution.ImportModuleAsync(moduleImport.FullName, cancellationToken); - if (module == null) { - MakeUnresolvedImport(moduleImport.FullName, moduleImport.FullName, location); - return null; + private IPythonModule HandleImport(ModuleImport moduleImport, LocationInfo location) { + var module = ModuleResolution.GetOrLoadModule(moduleImport.FullName); + if (module != null) { + return module; } - return module; + + MakeUnresolvedImport(moduleImport.FullName, moduleImport.FullName, location); + return null; } - private async Task HandlePossibleImportAsync( - PossibleModuleImport possibleModuleImport, string moduleName, LocationInfo location, CancellationToken cancellationToken) { + private IPythonModule HandlePossibleImport(PossibleModuleImport possibleModuleImport, string moduleName, LocationInfo location) { var fullName = possibleModuleImport.PrecedingModuleFullName; - var module = await ModuleResolution.ImportModuleAsync(possibleModuleImport.PrecedingModuleFullName, cancellationToken); + var module = ModuleResolution.GetOrLoadModule(possibleModuleImport.PrecedingModuleFullName); if (module == null) { MakeUnresolvedImport(possibleModuleImport.PrecedingModuleFullName, moduleName, location); return null; diff --git a/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzer.cs b/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzer.cs index 5c90c2038..6fa39d60b 100644 --- a/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzer.cs +++ b/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzer.cs @@ -14,145 +14,308 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Diagnostics; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Dependencies; using Microsoft.Python.Analysis.Documents; +using Microsoft.Python.Analysis.Modules; +using Microsoft.Python.Analysis.Types; using Microsoft.Python.Core; -using Microsoft.Python.Core.Diagnostics; +using Microsoft.Python.Core.Collections; +using Microsoft.Python.Core.Disposables; +using Microsoft.Python.Core.IO; using Microsoft.Python.Core.Logging; using Microsoft.Python.Core.Services; +using Microsoft.Python.Parsing.Ast; namespace Microsoft.Python.Analysis.Analyzer { public sealed class PythonAnalyzer : IPythonAnalyzer, IDisposable { private readonly IServiceManager _services; - private readonly IDependencyResolver _dependencyResolver; - private readonly CancellationTokenSource _globalCts = new CancellationTokenSource(); + private readonly IDependencyResolver _dependencyResolver; + private readonly Dictionary _analysisEntries = new Dictionary(); + private readonly DisposeToken _disposeToken = DisposeToken.Create(); + private readonly object _syncObj = new object(); + private readonly AsyncManualResetEvent _analysisCompleteEvent = new AsyncManualResetEvent(); private readonly ILogger _log; + private readonly int _maxTaskRunning = 8; + private int _runningTasks; + private int _version; - public PythonAnalyzer(IServiceManager services, string root) { + public PythonAnalyzer(IServiceManager services) { _services = services; _log = services.GetService(); + _dependencyResolver = new DependencyResolver(new ModuleDependencyFinder(services.GetService())); + _analysisCompleteEvent.Set(); + } - _dependencyResolver = services.GetService(); - if (_dependencyResolver == null) { - _dependencyResolver = new DependencyResolver(_services); - _services.AddService(_dependencyResolver); - } + public void Dispose() => _disposeToken.TryMarkDisposed(); - //var rdt = services.GetService(); - //if (rdt == null) { - // services.AddService(new RunningDocumentTable(root, services)); - //} - } + public Task WaitForCompleteAnalysisAsync(CancellationToken cancellationToken = default) + => _analysisCompleteEvent.WaitAsync(cancellationToken); - public void Dispose() => _globalCts.Cancel(); + public void EnqueueDocumentForAnalysis(IPythonModule module, PythonAst ast, int version, CancellationToken cancellationToken) + => AnalyzeDocumentAsync(module, ast, version, cancellationToken).DoNotWait(); - /// - /// Analyze single document. - /// - public async Task AnalyzeDocumentAsync(IDocument document, CancellationToken cancellationToken) { - var node = new DependencyChainNode(document); - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(_globalCts.Token, cancellationToken)) { - try { - var analysis = await AnalyzeAsync(node, cts.Token); - node.Analyzable.NotifyAnalysisComplete(analysis); - } catch(OperationCanceledException) { - node.Analyzable.NotifyAnalysisCanceled(); - throw; - } catch(Exception ex) when(!ex.IsCriticalException()) { - node.Analyzable.NotifyAnalysisFailed(ex); - throw; + public async Task GetAnalysisAsync(IPythonModule module, int waitTime, CancellationToken cancellationToken) { + var key = new ModuleKey(module); + PythonAnalyzerEntry entry; + lock (_syncObj) { + if (!_analysisEntries.TryGetValue(key, out entry)) { + var emptyAnalysis = new EmptyAnalysis(_services, (IDocument)module); + entry = new PythonAnalyzerEntry(module, emptyAnalysis.Ast, emptyAnalysis, 0); + _analysisEntries[key] = entry; } } - } - /// - /// Analyze document with dependents. - /// - public async Task AnalyzeDocumentDependencyChainAsync(IDocument document, CancellationToken cancellationToken) { - Check.InvalidOperation(() => _dependencyResolver != null, "Dependency resolver must be provided for the group analysis."); - - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(_globalCts.Token, cancellationToken)) { - var dependencyRoot = await _dependencyResolver.GetDependencyChainAsync(document, cts.Token); - // Notify each dependency that the analysis is now pending - NotifyAnalysisPending(document, dependencyRoot); - cts.Token.ThrowIfCancellationRequested(); - await AnalyzeChainAsync(dependencyRoot, cts.Token); + if (waitTime == 0 || Debugger.IsAttached) { + return await GetAnalysisAsync(entry, default, cancellationToken); + } + + using (var timeoutCts = new CancellationTokenSource(waitTime)) + using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken)) { + cts.CancelAfter(waitTime); + var timeoutToken = timeoutCts.Token; + return await GetAnalysisAsync(entry, timeoutToken, cts.Token); } } - private void NotifyAnalysisPending(IDocument document, IDependencyChainNode node) { - // Notify each dependency that the analysis is now pending except the source - // since if document has changed, it already incremented its expected analysis. - if (node.Analyzable != document) { - node.Analyzable.NotifyAnalysisPending(); - } - foreach (var c in node.Children) { - NotifyAnalysisPending(document, c); + private async Task GetAnalysisAsync(PythonAnalyzerEntry entry, CancellationToken timeoutCt, CancellationToken cancellationToken) { + while (!timeoutCt.IsCancellationRequested) { + try { + var analysis = await entry.GetAnalysisAsync(cancellationToken); + lock (_syncObj) { + if (entry.Version == analysis.Version) { + return analysis; + } + } + } catch (OperationCanceledException) when (timeoutCt.IsCancellationRequested) { + return entry.PreviousAnalysis; + } } + + return entry.PreviousAnalysis; } - private async Task AnalyzeChainAsync(IDependencyChainNode node, CancellationToken cancellationToken) { - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(_globalCts.Token, cancellationToken)) { - try { - var analysis = await AnalyzeAsync(node, cts.Token); - NotifyAnalysisComplete(node, analysis); - } catch (OperationCanceledException) { - node.Analyzable.NotifyAnalysisCanceled(); - throw; - } catch (Exception ex) when (!ex.IsCriticalException()) { - node.Analyzable.NotifyAnalysisFailed(ex); - throw; + private async Task AnalyzeDocumentAsync(IPythonModule module, PythonAst ast, int bufferVersion, CancellationToken cancellationToken) { + var key = new ModuleKey(module); + PythonAnalyzerEntry entry; + lock (_syncObj) { + if (_analysisEntries.TryGetValue(key, out entry)) { + entry.Invalidate(_version + 1, ast); + } else { + entry = new PythonAnalyzerEntry(module, ast, new EmptyAnalysis(_services, (IDocument)module), _version); + _analysisEntries[key] = entry; + } + } + + _analysisCompleteEvent.Reset(); + _log?.Log(TraceEventType.Verbose, $"Analysis of {module.Name}({module.ModuleType}) queued"); + + using (var cts = CancellationTokenSource.CreateLinkedTokenSource(_disposeToken.CancellationToken, cancellationToken)) { + var analysisToken = cts.Token; + + var walker = await _dependencyResolver.AddChangesAsync(new ModuleKey(module), entry, bufferVersion, cts.Token); + var abortAnalysisOnVersionChange = true; + lock (_syncObj) { + if (_version < walker.Version) { + _version = walker.Version; + foreach (var affectedEntry in walker.AffectedValues) { + affectedEntry.Invalidate(_version, affectedEntry.Ast); + if (affectedEntry.UserNotAnalyzed) { + abortAnalysisOnVersionChange = false; + } + } + } + } + + if (walker.MissingKeys.Count > 0) { + LoadMissingDocuments(module.Interpreter, walker.MissingKeys); } - cts.Token.ThrowIfCancellationRequested(); + var stopWatch = Stopwatch.StartNew(); + IDependencyChainNode node; + while ((node = await walker.GetNextAsync(analysisToken)) != null) { + lock (_syncObj) { + if (_version > walker.Version) { + if (abortAnalysisOnVersionChange) { + stopWatch.Stop(); + return; + } - foreach (var c in node.Children) { - await AnalyzeChainAsync(c, cts.Token); + if (!node.Value.UserNotAnalyzed) { + node.MarkCompleted(); + continue; + } + } + } + + if (Interlocked.Increment(ref _runningTasks) >= _maxTaskRunning) { + await AnalyzeAsync(node, walker.Version, stopWatch, analysisToken); + } else { + StartAnalysis(node, walker.Version, stopWatch, analysisToken); + } + } + + + if (walker.MissingKeys.Where(k => !k.IsTypeshed).Count == 0) { + Interlocked.Exchange(ref _runningTasks, 0); + _analysisCompleteEvent.Set(); } } } - private static void NotifyAnalysisComplete(IDependencyChainNode node, IDocumentAnalysis analysis) { - if (!node.Analyzable.NotifyAnalysisComplete(analysis)) { - // If snapshot does not match, there is no reason to continue analysis along the chain - // since subsequent change that incremented the expected version will start - // another analysis run. - throw new OperationCanceledException(); + private static void LoadMissingDocuments(IPythonInterpreter interpreter, ImmutableArray missingKeys) { + foreach (var (moduleName, _, isTypeshed) in missingKeys) { + var moduleResolution = isTypeshed ? interpreter.TypeshedResolution : interpreter.ModuleResolution; + moduleResolution.GetOrLoadModule(moduleName); } } + private void StartAnalysis(IDependencyChainNode node, int version, Stopwatch stopWatch, CancellationToken cancellationToken) + => Task.Run(() => AnalyzeAsync(node, version, stopWatch, cancellationToken), cancellationToken).DoNotWait(); + /// /// Performs analysis of the document. Returns document global scope /// with declared variables and inner scopes. Does not analyze chain /// of dependencies, it is intended for the single file analysis. /// - private async Task AnalyzeAsync(IDependencyChainNode node, CancellationToken cancellationToken) { - var startTime = DateTime.Now; + private async Task AnalyzeAsync(IDependencyChainNode node, int version, Stopwatch stopWatch, CancellationToken cancellationToken) { + try { + var startTime = stopWatch.ElapsedMilliseconds; + var module = node.Value.Module; + var ast = node.Value.Ast; + + // Now run the analysis. + var walker = new ModuleWalker(_services, module, ast); - //_log?.Log(TraceEventType.Verbose, $"Analysis begins: {node.Document.Name}({node.Document.ModuleType})"); - // Store current expected version so we can see if it still - // the same at the time the analysis completes. - var analysisVersion = node.Analyzable.ExpectedAnalysisVersion; + await ast.WalkAsync(walker, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); - // Make sure the file is parsed ans the AST is up to date. - var ast = await node.Document.GetAstAsync(cancellationToken); - //_log?.Log(TraceEventType.Verbose, $"Parse of {node.Document.Name}({node.Document.ModuleType}) complete in {(DateTime.Now - startTime).TotalMilliseconds} ms."); + // Note that we do not set the new analysis here and rather let + // Python analyzer to call NotifyAnalysisComplete. + await walker.CompleteAsync(cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + var analysis = new DocumentAnalysis((IDocument)module, version, walker.GlobalScope, walker.Eval); - // Now run the analysis. - var walker = new ModuleWalker(_services, node.Document, ast); + (module as IAnalyzable)?.NotifyAnalysisComplete(analysis); + node.Value.TrySetAnalysis(analysis, version, _syncObj); + + _log?.Log(TraceEventType.Verbose, $"Analysis of {module.Name}({module.ModuleType}) complete in {stopWatch.ElapsedMilliseconds - startTime} ms."); + } catch (OperationCanceledException oce) { + node.Value.TryCancel(oce, version, _syncObj); + } catch (Exception exception) { + node.Value.TrySetException(exception, version, _syncObj); + } finally { + Interlocked.Decrement(ref _runningTasks); + node.MarkCompleted(); + } + } - await ast.WalkAsync(walker, cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); + [DebuggerDisplay("{Name} : {FilePath}")] + private struct ModuleKey : IEquatable { + public string Name { get; } + public string FilePath { get; } + public bool IsTypeshed { get; } + + public ModuleKey(IPythonModule module) { + Name = module.Name; + FilePath = module.ModuleType == ModuleType.CompiledBuiltin ? null : module.FilePath; + IsTypeshed = module is StubPythonModule stub && stub.IsTypeshed; + } + + public ModuleKey(string name, string filePath, bool isTypeshed) { + Name = name; + FilePath = filePath; + IsTypeshed = isTypeshed; + } + + public bool Equals(ModuleKey other) + => Name.EqualsOrdinal(other.Name) && FilePath.PathEquals(other.FilePath) && IsTypeshed == other.IsTypeshed; + + public override bool Equals(object obj) => obj is ModuleKey other && Equals(other); + + public override int GetHashCode() { + unchecked { + var hashCode = (Name != null ? Name.GetHashCode() : 0); + hashCode = (hashCode * 397) ^ (FilePath != null ? FilePath.GetPathHashCode() : 0); + hashCode = (hashCode * 397) ^ IsTypeshed.GetHashCode(); + return hashCode; + } + } + + public static bool operator ==(ModuleKey left, ModuleKey right) => left.Equals(right); + + public static bool operator !=(ModuleKey left, ModuleKey right) => !left.Equals(right); + + public void Deconstruct(out string moduleName, out string filePath, out bool isTypeshed) { + moduleName = Name; + filePath = FilePath; + isTypeshed = IsTypeshed; + } + + public override string ToString() => $"{Name}({FilePath})"; + } + + private sealed class ModuleDependencyFinder : IDependencyFinder { + private readonly IFileSystem _fileSystem; + + public ModuleDependencyFinder(IFileSystem fileSystem) { + _fileSystem = fileSystem; + } + + public Task> FindDependenciesAsync(PythonAnalyzerEntry value, CancellationToken cancellationToken) { + var dependencies = new HashSet(); + var module = value.Module; + var isTypeshed = module is StubPythonModule stub && stub.IsTypeshed; + var moduleResolution = module.Interpreter.ModuleResolution; + var pathResolver = isTypeshed + ? module.Interpreter.TypeshedResolution.CurrentPathResolver + : moduleResolution.CurrentPathResolver; + + if (module.Stub != null) { + dependencies.Add(new ModuleKey(module.Stub)); + } + + foreach (var node in value.Ast.TraverseDepthFirst(n => n.GetChildNodes())) { + if (cancellationToken.IsCancellationRequested) { + return Task.FromCanceled>(cancellationToken); + } + + switch (node) { + case ImportStatement import: + foreach (var moduleName in import.Names) { + HandleSearchResults(isTypeshed, dependencies, moduleResolution, pathResolver.FindImports(module.FilePath, moduleName, import.ForceAbsolute)); + } + break; + case FromImportStatement fromImport: + HandleSearchResults(isTypeshed, dependencies, moduleResolution, pathResolver.FindImports(module.FilePath, fromImport)); + break; + } + } + + dependencies.Remove(new ModuleKey(value.Module)); + return Task.FromResult(ImmutableArray.Create(dependencies)); + } + + private static void HandleSearchResults(bool isTypeshed, HashSet dependencies, IModuleManagement moduleResolution, IImportSearchResult searchResult) { + switch (searchResult) { + case ModuleImport moduleImport when !Ignore(moduleResolution, moduleImport.FullName): + dependencies.Add(new ModuleKey(moduleImport.FullName, moduleImport.ModulePath, isTypeshed)); + return; + case PossibleModuleImport possibleModuleImport when !Ignore(moduleResolution, possibleModuleImport.PrecedingModuleFullName): + dependencies.Add(new ModuleKey(possibleModuleImport.PrecedingModuleFullName, possibleModuleImport.PrecedingModulePath, isTypeshed)); + return; + default: + return; + } + } - // Note that we do not set the new analysis here and rather let - // Python analyzer to call NotifyAnalysisComplete. - await walker.CompleteAsync(cancellationToken); - _log?.Log(TraceEventType.Verbose, $"Analysis of {node.Document.Name}({node.Document.ModuleType}) complete in {(DateTime.Now - startTime).TotalMilliseconds} ms."); - return new DocumentAnalysis(node.Document, analysisVersion, walker.GlobalScope, walker.Eval); + private static bool Ignore(IModuleManagement moduleResolution, string name) + => moduleResolution.BuiltinModuleName.EqualsOrdinal(name) || moduleResolution.GetSpecializedModule(name) != null; } } } diff --git a/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzerEntry.cs b/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzerEntry.cs new file mode 100644 index 000000000..43cf78e6c --- /dev/null +++ b/src/Analysis/Ast/Impl/Analyzer/PythonAnalyzerEntry.cs @@ -0,0 +1,101 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Analysis.Modules; +using Microsoft.Python.Analysis.Types; +using Microsoft.Python.Parsing.Ast; + +namespace Microsoft.Python.Analysis.Analyzer { + internal sealed class PythonAnalyzerEntry { + private TaskCompletionSource _analysisTcs; + + public IPythonModule Module { get; } + public PythonAst Ast { get; private set; } + public IDocumentAnalysis PreviousAnalysis { get; private set; } + public int Version { get; private set; } + public bool UserNotAnalyzed => PreviousAnalysis is EmptyAnalysis && Module.ModuleType == ModuleType.User; + + public PythonAnalyzerEntry(IPythonModule module, PythonAst ast, IDocumentAnalysis previousAnalysis, int version) { + Module = module; + Ast = ast; + PreviousAnalysis = previousAnalysis; + + Version = version; + _analysisTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + + public Task GetAnalysisAsync(CancellationToken cancellationToken) + => _analysisTcs.Task.ContinueWith(t => t.GetAwaiter().GetResult(), cancellationToken); + + public void TrySetAnalysis(IDocumentAnalysis analysis, int version, object syncObj) { + lock (syncObj) { + if (UserNotAnalyzed) { + PreviousAnalysis = analysis; + } + + if (Version > version) { + return; + } + + Version = version; + } + + _analysisTcs.TrySetResult(analysis); + } + + public void TrySetException(Exception ex, int version, object syncObj) { + lock (syncObj) { + if (Version > version) { + return; + } + + Version = version; + } + + _analysisTcs.TrySetException(ex); + } + + public void TryCancel(OperationCanceledException oce, int version, object syncObj) { + lock (syncObj) { + if (Version > version) { + return; + } + + Version = version; + } + + _analysisTcs.TrySetCanceled(oce.CancellationToken); + } + + public void Invalidate(int version, PythonAst ast) { + if (Version >= version) { + return; + } + + Version = version; + Ast = ast; + if (_analysisTcs.Task.Status == TaskStatus.RanToCompletion) { + PreviousAnalysis = _analysisTcs.Task.Result; + } + + if (_analysisTcs.Task.IsCompleted) { + _analysisTcs = new TaskCompletionSource(); + } + } + } +} diff --git a/src/Analysis/Ast/Impl/Analyzer/Symbols/ClassEvaluator.cs b/src/Analysis/Ast/Impl/Analyzer/Symbols/ClassEvaluator.cs index 500f1b1d8..51fc72539 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Symbols/ClassEvaluator.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Symbols/ClassEvaluator.cs @@ -84,11 +84,13 @@ private async Task ProcessClassBody(CancellationToken cancellationToken = defaul // Process imports foreach (var s in GetStatements(_classDef)) { - await ImportHandler.HandleFromImportAsync(s, cancellationToken); + ImportHandler.HandleFromImport(s, cancellationToken); } + foreach (var s in GetStatements(_classDef)) { - await ImportHandler.HandleImportAsync(s, cancellationToken); + ImportHandler.HandleImport(s, cancellationToken); } + UpdateClassMembers(); // Process assignments so we get class variables declared. diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyChainNode.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyChainNode.cs deleted file mode 100644 index 53d8a6495..000000000 --- a/src/Analysis/Ast/Impl/Dependencies/DependencyChainNode.cs +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright(c) Microsoft Corporation -// All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the License); you may not use -// this file except in compliance with the License. You may obtain a copy of the -// License at http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS -// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY -// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABILITY OR NON-INFRINGEMENT. -// -// See the Apache Version 2.0 License for specific language governing -// permissions and limitations under the License. - -using System.Collections.Generic; -using System.Linq; -using Microsoft.Python.Analysis.Analyzer; -using Microsoft.Python.Analysis.Documents; -using Microsoft.Python.Core.Diagnostics; - -namespace Microsoft.Python.Analysis.Dependencies { - internal sealed class DependencyChainNode : IDependencyChainNode { - public DependencyChainNode(IDocument document, IEnumerable children = null) { - Check.InvalidOperation(() => document is IAnalyzable, "Document must be analyzable entity"); - - Document = document; - SnapshotVersion = Analyzable.ExpectedAnalysisVersion; - Children = children ?? Enumerable.Empty(); - } - - /// - /// Analyzable object (usually the document itself). - /// - public IAnalyzable Analyzable => (IAnalyzable)Document; - - /// - /// Document to analyze. - /// - public IDocument Document { get; } - - /// - /// Object snapshot version at the time of the dependency chain creation. - /// Used to track if completed analysis version matches the current snapshot. - /// - public int SnapshotVersion { get; } - - /// - /// Dependent documents to analyze after this one. Child chains - /// can be analyzed concurrently. - /// - public IEnumerable Children { get; } - } -} diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyGraph.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyGraph.cs new file mode 100644 index 000000000..62fa36ada --- /dev/null +++ b/src/Analysis/Ast/Impl/Dependencies/DependencyGraph.cs @@ -0,0 +1,110 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.Python.Core.Collections; + +namespace Microsoft.Python.Analysis.Dependencies { + /// + /// Graph that represents dependencies between modules + /// NOT THREAD SAFE. All operations should happen under lock + /// + internal sealed class DependencyGraph { + private readonly Dictionary> _verticesByKey = new Dictionary>(); + private readonly List> _verticesByIndex = new List>(); + + public int Version { get; private set; } + + public bool TryAddOrUpdate(TKey key, TValue value, int version, out DependencyVertex vertex) { + if (_verticesByKey.TryGetValue(key, out var currentVertex)) { + if (version <= currentVertex.Version) { + // We aren't interested in older versions + vertex = null; + return false; + } + + vertex = new DependencyVertex(currentVertex, value, version); + _verticesByIndex[vertex.Index] = vertex; + } else { + vertex = new DependencyVertex(key, value, _verticesByIndex.Count, version); + _verticesByIndex.Add(vertex); + } + + Version++; + _verticesByKey[key] = vertex; + return true; + } + + public void ResolveDependencies(out ImmutableArray> snapshot, out ImmutableArray missingKeys) { + var missingKeysHashSet = new HashSet(); + var vertices = _verticesByIndex + .Where(v => !v.IsSealed || v.HasMissingKeys) + .Select(v => GetOrCreateNonSealedVertex(v.Index)) + .ToArray(); + + if (vertices.Length == 0) { + snapshot = ImmutableArray>.Create(_verticesByIndex); + missingKeys = ImmutableArray.Empty; + return; + } + + foreach (var vertex in vertices) { + var newIncoming = ImmutableArray.Empty; + var oldIncoming = vertex.Incoming; + + foreach (var dependencyKey in vertex.IncomingKeys) { + if (_verticesByKey.TryGetValue(dependencyKey, out var dependency)) { + newIncoming = newIncoming.Add(dependency.Index); + } else { + missingKeysHashSet.Add(dependencyKey); + vertex.SetHasMissingKeys(); + } + } + + foreach (var index in oldIncoming.Except(newIncoming)) { + var incomingVertex = GetOrCreateNonSealedVertex(index); + incomingVertex.RemoveOutgoing(vertex.Index); + } + + foreach (var index in newIncoming.Except(oldIncoming)) { + var incomingVertex = GetOrCreateNonSealedVertex(index); + incomingVertex.AddOutgoing(vertex.Index); + } + + vertex.SetIncoming(newIncoming); + } + + foreach (var vertex in _verticesByIndex) { + vertex.Seal(); + } + + snapshot = ImmutableArray>.Create(_verticesByIndex); + missingKeys = ImmutableArray.Create(missingKeysHashSet); + + DependencyVertex GetOrCreateNonSealedVertex(int index) { + var vertex = _verticesByIndex[index]; + if (!vertex.IsSealed) { + return vertex; + } + + vertex = new DependencyVertex(vertex, vertex.Value, vertex.Version); + _verticesByIndex[index] = vertex; + _verticesByKey[vertex.Key] = vertex; + return vertex; + } + } + } +} diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs index f3a9a846a..2723fc8ba 100644 --- a/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs +++ b/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs @@ -13,18 +13,358 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Microsoft.Python.Analysis.Documents; -using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; +using Microsoft.Python.Core.Threading; namespace Microsoft.Python.Analysis.Dependencies { - internal sealed class DependencyResolver : IDependencyResolver { - public DependencyResolver(IServiceContainer services) { + internal sealed class DependencyResolver : IDependencyResolver { + private readonly IDependencyFinder _dependencyFinder; + private readonly DependencyGraph _vertices = new DependencyGraph(); + private readonly Dictionary> _changedVertices = new Dictionary>(); + private readonly object _syncObj = new object(); + + public ImmutableArray MissingKeys { get; private set; } + + public DependencyResolver(IDependencyFinder dependencyFinder) { + _dependencyFinder = dependencyFinder; + } + + public async Task> AddChangesAsync(TKey key, TValue value, int valueVersion, CancellationToken cancellationToken) { + int version; + ImmutableArray> changedVertices; + + lock (_syncObj) { + cancellationToken.ThrowIfCancellationRequested(); + if (!_vertices.TryAddOrUpdate(key, value, valueVersion, out var dependencyVertex)) { + throw new OperationCanceledException(); + } + + version = _vertices.Version; + _changedVertices[key] = dependencyVertex; + changedVertices = ImmutableArray>.Create(_changedVertices.Values); + } + + if (changedVertices.Count == 1) { + await changedVertices[0].EnsureDependenciesAsync(_dependencyFinder); + } else { + // If any of the dependency analysis cancelled, this method will be canceled as well, + // so no need to terminate it explicitly + var tasks = changedVertices.Select(e => e.EnsureDependenciesAsync(_dependencyFinder)).ToArray(); + await ChangedVerticesAwaitable.Create(tasks); + } + + ImmutableArray> snapshot; + ImmutableArray missingKeys; + lock (_syncObj) { + cancellationToken.ThrowIfCancellationRequested(); + if (version < _vertices.Version) { + throw new OperationCanceledException(); + } + + _vertices.ResolveDependencies(out snapshot, out missingKeys); + } + + var walkingGraph = CreateWalkingGraph(snapshot, changedVertices); + var affectedValues = walkingGraph.Select(v => v.DependencyVertex.Value); + var loopsCount = FindLoops(walkingGraph); + var (startingVertices, totalNodesCount) = ResolveLoops(walkingGraph, loopsCount); + foreach (var vertex in walkingGraph) { + vertex.Seal(); + vertex.SecondPass?.Seal(); + } + + lock (_syncObj) { + cancellationToken.ThrowIfCancellationRequested(); + if (version < _vertices.Version) { + throw new OperationCanceledException(); + } + + MissingKeys = missingKeys; + } + + return new DependencyChainWalker(this, startingVertices, affectedValues, missingKeys, totalNodesCount, version); + } + + private void CommitChanges(int version) { + lock (_syncObj) { + if (version == _vertices.Version) { + _changedVertices.Clear(); + } + } + } + + private ImmutableArray> CreateWalkingGraph(ImmutableArray> snapshot, ImmutableArray> changedVertices) { + var analysisGraph = ImmutableArray>.Empty; + var nodesByVertexIndex = new Dictionary>(); + + foreach (var vertex in changedVertices) { + var node = new WalkingVertex(snapshot[vertex.Index]); + analysisGraph = analysisGraph.Add(node); + nodesByVertexIndex[vertex.Index] = node; + } + + var queue = new Queue>(analysisGraph); + while (queue.Count > 0) { + var node = queue.Dequeue(); + foreach (var outgoingIndex in node.DependencyVertex.Outgoing) { + if (!nodesByVertexIndex.TryGetValue(outgoingIndex, out var outgoingNode)) { + outgoingNode = new WalkingVertex(snapshot[outgoingIndex]); + analysisGraph = analysisGraph.Add(outgoingNode); + nodesByVertexIndex[outgoingIndex] = outgoingNode; + + queue.Enqueue(outgoingNode); + } + + node.AddOutgoing(outgoingNode); + } + } + + return analysisGraph; + } + + private int FindLoops(ImmutableArray> graph) { + var index = 0; + var loopNumber = 0; + var stackP = new Stack>(); + var stackS = new Stack>(); + + foreach (var vertex in graph) { + if (vertex.Index == -1) { + CheckForLoop(vertex, stackP, stackS, ref index, ref loopNumber); + } + } + + return loopNumber; + } + + private void CheckForLoop(WalkingVertex vertex, Stack> stackP, Stack> stackS, ref int counter, ref int loopNumber) { + vertex.Index = counter++; + stackP.Push(vertex); + stackS.Push(vertex); + + foreach (var child in vertex.Outgoing) { + if (child.Index == -1) { + CheckForLoop(child, stackP, stackS, ref counter, ref loopNumber); + } else if (child.LoopNumber == -1) { + while (stackP.Peek().Index > child.Index) { + stackP.Pop(); + } + } + } + + if (stackP.Count > 0 && vertex == stackP.Peek()) { + if (SetLoopNumber(vertex, stackS, loopNumber)) { + loopNumber++; + } + stackP.Pop(); + } + } + + private static bool SetLoopNumber(WalkingVertex vertex, Stack> stackS, int loopIndex) { + var count = 0; + WalkingVertex loopVertex; + do { + loopVertex = stackS.Pop(); + loopVertex.LoopNumber = loopIndex; + count++; + } while (loopVertex != vertex); + + if (count != 1) { + return true; + } + + vertex.LoopNumber = -2; + return false; } - public Task GetDependencyChainAsync(IDocument document, CancellationToken cancellationToken) { - // TODO: implement - return Task.FromResult(new DependencyChainNode(document)); + + private static (ImmutableArray>, int) ResolveLoops(ImmutableArray> graph, int loopsCount) { + // Create vertices for second pass + var inLoopsCount = 0; + var secondPassLoops = new List>[loopsCount]; + foreach (var vertex in graph) { + if (vertex.IsInLoop) { + var secondPassVertex = vertex.CreateSecondPassVertex(); + var loopNumber = vertex.LoopNumber; + if (secondPassLoops[loopNumber] == null) { + secondPassLoops[loopNumber] = new List> { secondPassVertex }; + } else { + secondPassLoops[loopNumber].Add(secondPassVertex); + } + + inLoopsCount++; + } + vertex.Index = -1; // Reset index, will use later + } + + // Break the loops so that its items can be iterated + foreach (var loop in secondPassLoops) { + // Sort loop items by amount of incoming connections + loop.Sort(WalkingVertex.FirstPassIncomingComparison); + + var counter = 0; + foreach (var secondPassVertex in loop) { + var vertex = secondPassVertex.FirstPass; + if (vertex.Index == -1) { + RemoveLoopEdges(vertex, ref counter); + } + } + } + + // Make all vertices from second pass loop have incoming edges from vertices from first pass loop and set unique loop numbers + foreach (var loop in secondPassLoops) { + foreach (var secondPassVertex in loop) { + var firstPassVertex = secondPassVertex.FirstPass; + secondPassVertex.LoopNumber = loopsCount; + firstPassVertex.AddOutgoing(loop); + + // Copy outgoing edges to the second pass vertex + foreach (var outgoingVertex in firstPassVertex.Outgoing) { + if (outgoingVertex.LoopNumber == firstPassVertex.LoopNumber) { + secondPassVertex.AddOutgoing(outgoingVertex.SecondPass); + } + } + } + + loopsCount++; + } + + // Iterate original graph to get starting vertices + return (graph.Where(v => v.IncomingCount == 0), graph.Count + inLoopsCount); + } + + private static void RemoveLoopEdges(WalkingVertex vertex, ref int counter) { + vertex.Index = counter++; + for (var i = vertex.Outgoing.Count - 1; i >= 0; i--) { + var outgoing = vertex.Outgoing[i]; + if (outgoing.Index == -1) { + RemoveLoopEdges(outgoing, ref counter); + } else if (outgoing.Index < vertex.Index) { + vertex.RemoveOutgoingAt(i); + } + } + } + + private sealed class DependencyChainWalker : IDependencyChainWalker { + private readonly DependencyResolver _dependencyResolver; + private readonly PriorityProducerConsumer> _ppc; + private readonly object _syncObj; + private int _remaining; + + public ImmutableArray MissingKeys { get; } + public ImmutableArray AffectedValues { get; } + public int Version { get; } + public bool IsCompleted { get; private set; } + + public DependencyChainWalker(in DependencyResolver dependencyResolver, in ImmutableArray> startingVertices, in ImmutableArray affectedValues, in ImmutableArray missingKeys, in int totalNodesCount, in int version) { + _syncObj = new object(); + _dependencyResolver = dependencyResolver; + _ppc = new PriorityProducerConsumer>(); + AffectedValues = affectedValues; + Version = version; + MissingKeys = missingKeys; + + _remaining = totalNodesCount; + IsCompleted = _remaining == 0; + foreach (var vertex in startingVertices) { + _ppc.Produce(new DependencyChainNode(this, vertex)); + } + } + + public Task> GetNextAsync(CancellationToken cancellationToken) => + _ppc.ConsumeAsync(cancellationToken); + + public void MarkCompleted(WalkingVertex vertex) { + var verticesToProduce = new List>(); + var isCompleted = false; + lock (_syncObj) { + _remaining--; + foreach (var outgoing in vertex.Outgoing) { + if (outgoing.IncomingCount == 0) { + continue; + } + + outgoing.DecrementIncoming(); + if (outgoing.IncomingCount > 0) { + continue; + } + + verticesToProduce.Add(outgoing); + } + + if (_remaining == 0) { + IsCompleted = isCompleted = true; + } + } + + if (isCompleted) { + _ppc.Produce(null); + _ppc.Dispose(); + _dependencyResolver.CommitChanges(Version); + } else { + foreach (var toProduce in verticesToProduce) { + _ppc.Produce(new DependencyChainNode(this, toProduce)); + } + } + } + } + + private sealed class DependencyChainNode : IDependencyChainNode { + private readonly WalkingVertex _vertex; + private DependencyChainWalker _walker; + public TValue Value => _vertex.DependencyVertex.Value; + + public DependencyChainNode(DependencyChainWalker walker, WalkingVertex vertex) { + _walker = walker; + _vertex = vertex; + } + + public void MarkCompleted() => Interlocked.Exchange(ref _walker, null)?.MarkCompleted(_vertex); + } + + private sealed class ChangedVerticesAwaitable { + private readonly TaskCompletionSourceEx _tcs; + private int _count; + + private ChangedVerticesAwaitable(Task[] tasks) { + _tcs = new TaskCompletionSourceEx(); + _count = tasks.Length; + + foreach (var task in tasks) { + task.ContinueWith(CountdownOnCompletion, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + } + } + + public static ChangedVerticesAwaitable Create(Task[] tasks) => new ChangedVerticesAwaitable(tasks); + + public TaskAwaiter GetAwaiter() => ((Task)_tcs.Task).GetAwaiter(); + + private void CountdownOnCompletion(Task task) { + switch (task.Status) { + case TaskStatus.RanToCompletion: + if (Interlocked.Decrement(ref _count) == 0) { + _tcs.TrySetResult(0); + } + return; + case TaskStatus.Canceled: + try { + task.GetAwaiter().GetResult(); + } catch (OperationCanceledException ex) { + _tcs.TrySetCanceled(ex); + } + return; + case TaskStatus.Faulted: + _tcs.TrySetException(task.Exception); + return; + default: + throw new ArgumentOutOfRangeException(); + } + } } } } diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs new file mode 100644 index 000000000..d747897e4 --- /dev/null +++ b/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs @@ -0,0 +1,119 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Core.Collections; + +namespace Microsoft.Python.Analysis.Dependencies { + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + internal sealed class DependencyVertex { + private readonly CancellationTokenSource _incomingKeysCts; + private TaskCompletionSource _incomingKeysTcs; + + public TKey Key { get; } + public TValue Value { get; } + public int Version { get; } + public int Index { get; } + public string DebuggerDisplay => $"{Key}:{Value}"; + + public bool IsSealed { get; private set; } + public bool HasMissingKeys { get; private set; } + + public ImmutableArray IncomingKeys { get; private set; } + public ImmutableArray Incoming { get; private set; } + public ImmutableArray Outgoing { get; private set; } + + public DependencyVertex(DependencyVertex oldVertex, TValue value, int version) { + Key = oldVertex.Key; + Value = value; + Version = version; + Index = oldVertex.Index; + + _incomingKeysCts = new CancellationTokenSource(); + oldVertex._incomingKeysCts.Cancel(); + + IncomingKeys = oldVertex.IncomingKeys; + Incoming = oldVertex.Incoming; + Outgoing = oldVertex.Outgoing; + } + + public DependencyVertex(TKey key, TValue value, int index, int version) { + Key = key; + Value = value; + Version = version; + Index = index; + + _incomingKeysCts = new CancellationTokenSource(); + + IncomingKeys = ImmutableArray.Empty; + Incoming = ImmutableArray.Empty; + Outgoing = ImmutableArray.Empty; + } + + public Task EnsureDependenciesAsync(IDependencyFinder dependencyFinder) { + if (_incomingKeysTcs != null) { + return _incomingKeysTcs.Task; + } + + var tcs = new TaskCompletionSource(); + if (Interlocked.CompareExchange(ref _incomingKeysTcs, tcs, null) == null) { + return FindDependenciesAsync(dependencyFinder); + } + + return _incomingKeysTcs.Task; + } + + private async Task FindDependenciesAsync(IDependencyFinder dependencyFinder) { + try { + IncomingKeys = await dependencyFinder.FindDependenciesAsync(Value, _incomingKeysCts.Token); + } catch (OperationCanceledException e) { + _incomingKeysTcs.TrySetCanceled(e.CancellationToken); + throw; + } catch (Exception ex) { + _incomingKeysTcs.TrySetException(ex); + throw; + } + _incomingKeysTcs.TrySetResult(0); + } + + public void AddOutgoing(int index) { + AssertIsNotSealed(); + Outgoing = Outgoing.Add(index); + } + + public void RemoveOutgoing(int index) { + AssertIsNotSealed(); + Outgoing = Outgoing.Remove(index); + } + + public void SetIncoming(ImmutableArray incoming) { + AssertIsNotSealed(); + Incoming = incoming; + } + + public void SetHasMissingKeys() { + AssertIsNotSealed(); + HasMissingKeys = true; + } + + public void Seal() => IsSealed = true; + + [Conditional("DEBUG")] + private void AssertIsNotSealed() => Debug.Assert(!IsSealed); + } +} diff --git a/src/Analysis/Ast/Impl/Dependencies/IDependencyChainNode.cs b/src/Analysis/Ast/Impl/Dependencies/IDependencyChainNode.cs index 7de845db1..8d1addee2 100644 --- a/src/Analysis/Ast/Impl/Dependencies/IDependencyChainNode.cs +++ b/src/Analysis/Ast/Impl/Dependencies/IDependencyChainNode.cs @@ -13,35 +13,9 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. -using System.Collections.Generic; -using Microsoft.Python.Analysis.Analyzer; -using Microsoft.Python.Analysis.Documents; - namespace Microsoft.Python.Analysis.Dependencies { - /// - /// Represents a node in a chain of a document dependencies. - /// - internal interface IDependencyChainNode { - /// - /// Analyzable object (usually the document itself). - /// - IAnalyzable Analyzable { get; } - - /// - /// Document to analyze. - /// - IDocument Document { get; } - - /// - /// Version of the document at the time of the dependency chain creation. - /// Used to track if completed analysis matches current document snapshot. - /// - int SnapshotVersion { get; } - - /// - /// Dependent documents to analyze after this one. Child chains - /// can be analyzed concurrently. - /// - IEnumerable Children { get; } + internal interface IDependencyChainNode { + TValue Value { get; } + void MarkCompleted(); } } diff --git a/src/Analysis/Ast/Impl/Dependencies/IDependencyChainWalker.cs b/src/Analysis/Ast/Impl/Dependencies/IDependencyChainWalker.cs new file mode 100644 index 000000000..03841cf06 --- /dev/null +++ b/src/Analysis/Ast/Impl/Dependencies/IDependencyChainWalker.cs @@ -0,0 +1,28 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Core.Collections; + +namespace Microsoft.Python.Analysis.Dependencies { + internal interface IDependencyChainWalker { + Task> GetNextAsync(CancellationToken cancellationToken); + ImmutableArray MissingKeys { get; } + ImmutableArray AffectedValues { get; } + int Version { get; } + bool IsCompleted { get; } + } +} diff --git a/src/Analysis/Ast/Impl/Dependencies/IDependencyFinder.cs b/src/Analysis/Ast/Impl/Dependencies/IDependencyFinder.cs new file mode 100644 index 000000000..52dabfff1 --- /dev/null +++ b/src/Analysis/Ast/Impl/Dependencies/IDependencyFinder.cs @@ -0,0 +1,24 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Core.Collections; + +namespace Microsoft.Python.Analysis.Dependencies { + internal interface IDependencyFinder { + Task> FindDependenciesAsync(TValue value, CancellationToken cancellationToken); + } +} diff --git a/src/Analysis/Ast/Impl/Dependencies/IDependencyResolver.cs b/src/Analysis/Ast/Impl/Dependencies/IDependencyResolver.cs index 8885436a3..4bd1ae13b 100644 --- a/src/Analysis/Ast/Impl/Dependencies/IDependencyResolver.cs +++ b/src/Analysis/Ast/Impl/Dependencies/IDependencyResolver.cs @@ -16,6 +16,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Analysis.Documents; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Analysis.Dependencies { /// @@ -24,7 +25,8 @@ namespace Microsoft.Python.Analysis.Dependencies { /// for the analysis. The chain is a tree where child branches can be analyzed /// concurrently. /// - internal interface IDependencyResolver { - Task GetDependencyChainAsync(IDocument document, CancellationToken cancellationToken); + internal interface IDependencyResolver { + ImmutableArray MissingKeys { get; } + Task> AddChangesAsync(TKey key, TValue value, int valueVersion, CancellationToken cancellationToken); } } diff --git a/src/Analysis/Ast/Impl/Dependencies/WalkingVertex.cs b/src/Analysis/Ast/Impl/Dependencies/WalkingVertex.cs new file mode 100644 index 000000000..b32dd8b1c --- /dev/null +++ b/src/Analysis/Ast/Impl/Dependencies/WalkingVertex.cs @@ -0,0 +1,101 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.Python.Core.Diagnostics; + +namespace Microsoft.Python.Analysis.Dependencies { + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + internal sealed class WalkingVertex { + public static Comparison> FirstPassIncomingComparison { get; } = (v1, v2) => v1.FirstPass.IncomingCount.CompareTo(v2.FirstPass.IncomingCount); + + private readonly List> _outgoing; + private bool _isSealed; + public DependencyVertex DependencyVertex { get; } + public IReadOnlyList> Outgoing => _outgoing; + + public int Index { get; set; } + public int LoopNumber { get; set; } + public int IncomingCount { get; private set; } + + public WalkingVertex FirstPass { get; } + public WalkingVertex SecondPass { get; private set; } + + public bool IsInLoop => LoopNumber >= 0; + + public string DebuggerDisplay => DependencyVertex.DebuggerDisplay; + + public WalkingVertex(DependencyVertex vertex, WalkingVertex firstPass = null) { + DependencyVertex = vertex; + FirstPass = firstPass; + Index = -1; + LoopNumber = -1; + _outgoing = new List>(); + } + + public void AddOutgoing(WalkingVertex outgoingVertex) { + CheckNotSealed(); + + _outgoing.Add(outgoingVertex); + outgoingVertex.IncomingCount++; + } + + public void AddOutgoing(List> loop) { + CheckNotSealed(); + + _outgoing.AddRange(loop); + foreach (var outgoingVertex in loop) { + outgoingVertex.IncomingCount++; + } + } + + public void RemoveOutgoingAt(int index) { + CheckNotSealed(); + + var outgoingVertex = _outgoing[index]; + _outgoing.RemoveAt(index); + outgoingVertex.IncomingCount--; + } + + public WalkingVertex CreateSecondPassVertex() { + CheckNotSealed(); + + SecondPass = new WalkingVertex(DependencyVertex, this); + + for (var i = _outgoing.Count - 1; i >= 0; i--) { + var outgoingVertex = _outgoing[i]; + if (LoopNumber == outgoingVertex.LoopNumber) { + continue; + } + + SecondPass._outgoing.Add(outgoingVertex); + _outgoing.RemoveAt(i); + } + + return SecondPass; + } + + public void Seal() => _isSealed = true; + public void DecrementIncoming() { + CheckSealed(); + IncomingCount--; + } + + private void CheckSealed() => Check.InvalidOperation(_isSealed); + private void CheckNotSealed() => Check.InvalidOperation(!_isSealed); + } +} diff --git a/src/Analysis/Ast/Impl/Documents/Definitions/IDocument.cs b/src/Analysis/Ast/Impl/Documents/Definitions/IDocument.cs index 5429b855d..ad07de2b3 100644 --- a/src/Analysis/Ast/Impl/Documents/Definitions/IDocument.cs +++ b/src/Analysis/Ast/Impl/Documents/Definitions/IDocument.cs @@ -49,7 +49,7 @@ public interface IDocument: IPythonModule, IDisposable { /// /// Returns document analysis. /// - Task GetAnalysisAsync(CancellationToken cancellationToken = default); + Task GetAnalysisAsync(int waitTime, CancellationToken cancellationToken = default); /// /// Returns last known document AST. The AST may be out of date or null. diff --git a/src/Analysis/Ast/Impl/Documents/RunningDocumentTable.cs b/src/Analysis/Ast/Impl/Documents/RunningDocumentTable.cs index 0bda5802e..389bf6c32 100644 --- a/src/Analysis/Ast/Impl/Documents/RunningDocumentTable.cs +++ b/src/Analysis/Ast/Impl/Documents/RunningDocumentTable.cs @@ -21,6 +21,7 @@ using System.Linq; using Microsoft.Python.Analysis.Modules; using Microsoft.Python.Core; +using Microsoft.Python.Core.Diagnostics; namespace Microsoft.Python.Analysis.Documents { /// @@ -95,7 +96,7 @@ public IDocument AddModule(ModuleCreationOptions mco) { mco.Uri = uri; } - var entry = FindDocument(mco.FilePath, mco.Uri) ?? CreateDocument(mco); + var entry = FindDocument(mco.ModuleName, mco.Uri) ?? CreateDocument(mco); entry.LockCount++; return entry.Document; } @@ -165,17 +166,14 @@ private DocumentEntry FindDocument(string moduleName, Uri uri) { private DocumentEntry CreateDocument(ModuleCreationOptions mco) { IDocument document; switch (mco.ModuleType) { - case ModuleType.Stub: - document = new StubPythonModule(mco.ModuleName, mco.FilePath, _services); - break; - case ModuleType.Compiled: + case ModuleType.Compiled when TryAddModulePath(mco): document = new CompiledPythonModule(mco.ModuleName, ModuleType.Compiled, mco.FilePath, mco.Stub, _services); break; case ModuleType.CompiledBuiltin: document = new CompiledBuiltinPythonModule(mco.ModuleName, mco.Stub, _services); break; - case ModuleType.User: - case ModuleType.Library: + case ModuleType.User when TryAddModulePath(mco): + case ModuleType.Library when TryAddModulePath(mco): document = new PythonModule(mco, _services); break; default: @@ -185,11 +183,24 @@ private DocumentEntry CreateDocument(ModuleCreationOptions mco) { var entry = new DocumentEntry { Document = document, LockCount = 0 }; _documentsByUri[document.Uri] = entry; _documentsByName[mco.ModuleName] = entry; - - ModuleManagement.AddModulePath(document.FilePath); return entry; } + private bool TryAddModulePath(ModuleCreationOptions mco) { + var filePath = mco.FilePath ?? mco.Uri?.ToAbsolutePath(); + if (filePath == null) { + throw new InvalidOperationException("Can't create document with no file path or URI specified"); + } + + if (!ModuleManagement.TryAddModulePath(filePath, out var fullName)) { + return false; + } + + mco.FilePath = filePath; + mco.ModuleName = fullName; + return true; + } + private bool TryOpenDocument(DocumentEntry entry, string content) { if (!entry.Document.IsOpen) { entry.Document.Reset(content); diff --git a/src/Analysis/Ast/Impl/Modules/BuiltinsPythonModule.cs b/src/Analysis/Ast/Impl/Modules/BuiltinsPythonModule.cs index 7c2e6d0dd..55ba2ad16 100644 --- a/src/Analysis/Ast/Impl/Modules/BuiltinsPythonModule.cs +++ b/src/Analysis/Ast/Impl/Modules/BuiltinsPythonModule.cs @@ -46,13 +46,13 @@ protected override IEnumerable GetScrapeArguments(IPythonInterpreter int => !InstallPath.TryGetFile("scrape_module.py", out var sb) ? null : new List { "-B", "-E", sb }; protected override void OnAnalysisComplete() { - lock (AnalysisLock) { - SpecializeTypes(); - SpecializeFunctions(); - foreach (var n in GetMemberNames()) { - GetMember(n).GetPythonType()?.MakeReadOnly(); - } + SpecializeTypes(); + SpecializeFunctions(); + foreach (var n in GetMemberNames()) { + GetMember(n).GetPythonType()?.MakeReadOnly(); } + + base.OnAnalysisComplete(); } private void SpecializeTypes() { diff --git a/src/Analysis/Ast/Impl/Modules/CompiledBuiltinPythonModule.cs b/src/Analysis/Ast/Impl/Modules/CompiledBuiltinPythonModule.cs index aceff65f8..34e1242d4 100644 --- a/src/Analysis/Ast/Impl/Modules/CompiledBuiltinPythonModule.cs +++ b/src/Analysis/Ast/Impl/Modules/CompiledBuiltinPythonModule.cs @@ -25,7 +25,7 @@ namespace Microsoft.Python.Analysis.Modules { /// internal sealed class CompiledBuiltinPythonModule : CompiledPythonModule { public CompiledBuiltinPythonModule(string moduleName, IPythonModule stub, IServiceContainer services) - : base(moduleName, ModuleType.Compiled, MakeFakeFilePath(moduleName, services), stub, services) { } + : base(moduleName, ModuleType.CompiledBuiltin, MakeFakeFilePath(moduleName, services), stub, services) { } protected override IEnumerable GetScrapeArguments(IPythonInterpreter interpreter) => !InstallPath.TryGetFile("scrape_module.py", out var sm) diff --git a/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs b/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs index a5997bc29..bbe8230f1 100644 --- a/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs +++ b/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs @@ -35,7 +35,13 @@ public interface IModuleManagement: IModuleResolution { IModuleCache ModuleCache { get; } - void AddModulePath(string path); + bool TryAddModulePath(in string path, out string fullName); + + /// + /// Sets user search paths. This changes . + /// + /// Added roots. + IEnumerable SetUserSearchPaths(in IEnumerable searchPaths); /// /// Provides ability to specialize module by replacing module import by diff --git a/src/Analysis/Ast/Impl/Modules/Definitions/IModuleResolution.cs b/src/Analysis/Ast/Impl/Modules/Definitions/IModuleResolution.cs index caa66b332..def49f596 100644 --- a/src/Analysis/Ast/Impl/Modules/Definitions/IModuleResolution.cs +++ b/src/Analysis/Ast/Impl/Modules/Definitions/IModuleResolution.cs @@ -41,15 +41,15 @@ public interface IModuleResolution { /// /// Returns an IPythonModule for a given module name. Returns null if - /// the module does not exist. The import is performed asynchronously. + /// the module has not been imported. /// - Task ImportModuleAsync(string name, CancellationToken cancellationToken = default); + IPythonModule GetImportedModule(string name); /// - /// Returns an IPythonModule for a given module name. Returns null if - /// the module has not been imported. + /// Returns an IPythonModule for a given module name. + /// Returns null if the module wasn't found. /// - IPythonModule GetImportedModule(string name); + IPythonModule GetOrLoadModule(string name); /// /// Sets user search paths. This changes . diff --git a/src/Analysis/Ast/Impl/Modules/ModuleCache.cs b/src/Analysis/Ast/Impl/Modules/ModuleCache.cs index e4c30565d..c7e2300ed 100644 --- a/src/Analysis/Ast/Impl/Modules/ModuleCache.cs +++ b/src/Analysis/Ast/Impl/Modules/ModuleCache.cs @@ -66,7 +66,7 @@ public async Task ImportFromCacheAsync(string name, CancellationToken var rdt = _services.GetService(); var mco = new ModuleCreationOptions { ModuleName = name, - ModuleType = ModuleType.Stub, + ModuleType = ModuleType.Compiled, FilePath = cache }; var module = rdt.AddModule(mco); diff --git a/src/Analysis/Ast/Impl/Modules/PythonModule.cs b/src/Analysis/Ast/Impl/Modules/PythonModule.cs index d0cc0cea2..a54b48e2e 100644 --- a/src/Analysis/Ast/Impl/Modules/PythonModule.cs +++ b/src/Analysis/Ast/Impl/Modules/PythonModule.cs @@ -42,7 +42,7 @@ namespace Microsoft.Python.Analysis.Modules { /// [DebuggerDisplay("{Name} : {ModuleType}")] public class PythonModule : IDocument, IAnalyzable, IEquatable { - protected enum State { + private enum State { None, Loading, Loaded, @@ -52,15 +52,12 @@ protected enum State { Analyzed } - private readonly AsyncLocal _awaiting = new AsyncLocal(); private readonly DocumentBuffer _buffer = new DocumentBuffer(); private readonly CancellationTokenSource _allProcessingCts = new CancellationTokenSource(); private IReadOnlyList _parseErrors = Array.Empty(); private readonly IDiagnosticsService _diagnosticsService; private string _documentation; // Must be null initially. - private TaskCompletionSource _analysisTcs; - private CancellationTokenSource _linkedAnalysisCts; // cancellation token combined with the 'dispose' cts private CancellationTokenSource _parseCts; private CancellationTokenSource _linkedParseCts; // combined with 'dispose' cts private Task _parsingTask; @@ -69,8 +66,8 @@ protected enum State { protected ILogger Log { get; } protected IFileSystem FileSystem { get; } protected IServiceContainer Services { get; } - protected object AnalysisLock { get; } = new object(); - protected State ContentState { get; set; } = State.None; + private object AnalysisLock { get; } = new object(); + private State ContentState { get; set; } = State.None; protected PythonModule(string name, ModuleType moduleType, IServiceContainer services) { Name = name ?? throw new ArgumentNullException(nameof(name)); @@ -196,13 +193,10 @@ public virtual IEnumerable GetMemberNames() { /// loaded (lazy) modules may choose to defer content retrieval and /// analysis until later time, when module members are actually needed. /// - public virtual Task LoadAndAnalyzeAsync(CancellationToken cancellationToken = default) { - if (_awaiting.Value) { - return Task.FromResult(Analysis); - } - _awaiting.Value = true; + public async Task LoadAndAnalyzeAsync(CancellationToken cancellationToken = default) { InitializeContent(null); - return GetAnalysisAsync(cancellationToken); + await GetAstAsync(cancellationToken); + await Services.GetService().GetAnalysisAsync(this, -1, cancellationToken); } protected virtual string LoadContent() { @@ -222,11 +216,6 @@ private void InitializeContent(string content) { LoadContent(content); var startParse = ContentState < State.Parsing && _parsingTask == null; - var startAnalysis = startParse | (ContentState < State.Analyzing && _analysisTcs?.Task == null); - - if (startAnalysis) { - ExpectNewAnalysis(); - } if (startParse) { Parse(); } @@ -317,9 +306,6 @@ public async Task GetAstAsync(CancellationToken cancellationToken = d public void Update(IEnumerable changes) { lock (AnalysisLock) { - ExpectNewAnalysis(); - _linkedAnalysisCts?.Cancel(); - _parseCts?.Cancel(); _parseCts = new CancellationTokenSource(); @@ -340,8 +326,6 @@ public void Reset(string content) { } private void Parse() { - _awaiting.Value = false; - _parseCts?.Cancel(); _parseCts = new CancellationTokenSource(); @@ -388,25 +372,20 @@ private void Parse(CancellationToken cancellationToken) { _diagnosticsService?.Replace(Uri, _parseErrors.Concat(Analysis.Diagnostics)); } - _parsingTask = null; ContentState = State.Parsed; } NewAst?.Invoke(this, EventArgs.Empty); if (ContentState < State.Analyzing) { - Log?.Log(TraceEventType.Verbose, $"Analysis queued: {Name}"); ContentState = State.Analyzing; - _linkedAnalysisCts?.Dispose(); - _linkedAnalysisCts = CancellationTokenSource.CreateLinkedTokenSource(_allProcessingCts.Token, cancellationToken); - var analyzer = Services.GetService(); - if (ModuleType == ModuleType.User || ModuleType == ModuleType.Library) { - analyzer.AnalyzeDocumentDependencyChainAsync(this, _linkedAnalysisCts.Token).DoNotWait(); - } else { - analyzer.AnalyzeDocumentAsync(this, _linkedAnalysisCts.Token).DoNotWait(); - } + analyzer.EnqueueDocumentForAnalysis(this, ast, version, _allProcessingCts.Token); + } + + lock (AnalysisLock) { + _parsingTask = null; } } @@ -420,65 +399,29 @@ public override void Add(string message, SourceSpan span, int errorCode, Severit #endregion #region IAnalyzable - /// - /// Expected version of the analysis when asynchronous operations complete. - /// Typically every change to the document or documents that depend on it - /// increment the expected version. At the end of the analysis if the expected - /// version is still the same, the analysis is applied to the document and - /// becomes available to consumers. - /// - public int ExpectedAnalysisVersion { get; private set; } - - /// - /// Notifies document that analysis is now pending. Typically document increments - /// the expected analysis version. The method can be called repeatedly without - /// calling `CompleteAnalysis` first. The method is invoked for every dependency - /// in the chain to ensure that objects know that their dependencies have been - /// modified and the current analysis is no longer up to date. - /// - public void NotifyAnalysisPending() { + public void NotifyAnalysisComplete(IDocumentAnalysis analysis) { lock (AnalysisLock) { - // The notification comes from the analyzer when it needs to invalidate - // current analysis since one of the dependencies changed. If text - // buffer changed then the notification won't come since the analyzer - // filters out original initiator of the analysis. - ExpectNewAnalysis(); - //Log?.Log(TraceEventType.Verbose, $"Analysis pending: {Name}"); - } - } + if (analysis.Version < Analysis.Version) { + return; + } - public virtual bool NotifyAnalysisComplete(IDocumentAnalysis analysis) { - lock (AnalysisLock) { - // Log?.Log(TraceEventType.Verbose, $"Analysis complete: {Name}, Version: {analysis.Version}, Expected: {ExpectedAnalysisVersion}"); - if (analysis.Version == ExpectedAnalysisVersion) { - Analysis = analysis; - GlobalScope = analysis.GlobalScope; - - // Derived classes can override OnAnalysisComplete if they want - // to perform additional actions on the completed analysis such - // as declare additional variables, etc. - OnAnalysisComplete(); - ContentState = State.Analyzed; - - // Do not report issues with libraries or stubs - if (ModuleType == ModuleType.User) { - _diagnosticsService?.Replace(Uri, _parseErrors.Concat(Analysis.Diagnostics)); - } + Analysis = analysis; + GlobalScope = analysis.GlobalScope; - var tcs = _analysisTcs; - _analysisTcs = null; - tcs.TrySetResult(analysis); + // Derived classes can override OnAnalysisComplete if they want + // to perform additional actions on the completed analysis such + // as declare additional variables, etc. + OnAnalysisComplete(); + ContentState = State.Analyzed; + } - NewAnalysis?.Invoke(this, EventArgs.Empty); - return true; - } - Debug.Assert(ExpectedAnalysisVersion > analysis.Version); - return false; + // Do not report issues with libraries or stubs + if (ModuleType == ModuleType.User) { + _diagnosticsService?.Replace(Uri, _parseErrors.Concat(analysis.Diagnostics)); } - } - public void NotifyAnalysisCanceled() => _analysisTcs?.TrySetCanceled(); - public void NotifyAnalysisFailed(Exception ex) => _analysisTcs?.TrySetException(ex); + NewAnalysis?.Invoke(this, EventArgs.Empty); + } protected virtual void OnAnalysisComplete() { } #endregion @@ -486,17 +429,10 @@ protected virtual void OnAnalysisComplete() { } #region Analysis public IDocumentAnalysis GetAnyAnalysis() => Analysis; - public Task GetAnalysisAsync(CancellationToken cancellationToken = default) { - lock (AnalysisLock) { - return _analysisTcs?.Task ?? Task.FromResult(Analysis); - } - } - #endregion + public Task GetAnalysisAsync(int waitTime = 200, CancellationToken cancellationToken = default) + => Services.GetService().GetAnalysisAsync(this, waitTime, cancellationToken); - private void ExpectNewAnalysis() { - ExpectedAnalysisVersion++; - _analysisTcs = _analysisTcs ?? new TaskCompletionSource(); - } + #endregion private string TryGetDocFromModuleInitFile() { if (string.IsNullOrEmpty(FilePath) || !FileSystem.FileExists(FilePath)) { diff --git a/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs b/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs index 039736b04..ed7eeb0e0 100644 --- a/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs +++ b/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs @@ -21,11 +21,13 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Analysis.Analyzer; using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Core.Interpreter; using Microsoft.Python.Analysis.Documents; using Microsoft.Python.Analysis.Types; using Microsoft.Python.Core; +using Microsoft.Python.Core.Diagnostics; namespace Microsoft.Python.Analysis.Modules.Resolution { internal sealed class MainModuleResolution : ModuleResolutionBase, IModuleManagement { @@ -45,7 +47,8 @@ internal async Task InitializeAsync(CancellationToken cancellationToken = defaul var modulePath = ModuleCache.GetCacheFilePath(_interpreter.Configuration.InterpreterPath); var b = new BuiltinsPythonModule(moduleName, modulePath, _services); - _modules[BuiltinModuleName] = BuiltinsModule = b; + BuiltinsModule = b; + _modules[BuiltinModuleName] = new ModuleRef(b); } public async Task> GetSearchPathsAsync(CancellationToken cancellationToken = default) { @@ -55,9 +58,11 @@ public async Task> GetSearchPathsAsync(CancellationToken c _searchPaths = await GetInterpreterSearchPathsAsync(cancellationToken); Debug.Assert(_searchPaths != null, "Should have search paths"); - _searchPaths = _searchPaths != null - ? _searchPaths.Concat(Configuration.SearchPaths ?? Array.Empty()).ToArray() - : Array.Empty(); + _searchPaths = _searchPaths != null + ? Configuration.SearchPaths != null + ? _searchPaths.Concat(Configuration.SearchPaths).ToArray() + : _searchPaths + : Array.Empty(); _log?.Log(TraceEventType.Information, "SearchPaths:"); foreach (var s in _searchPaths) { @@ -66,7 +71,7 @@ public async Task> GetSearchPathsAsync(CancellationToken c return _searchPaths; } - protected override async Task DoImportAsync(string name, CancellationToken cancellationToken = default) { + protected override IPythonModule CreateModule(string name) { var moduleImport = CurrentPathResolver.GetModuleImportFromModuleName(name); if (moduleImport == null) { _log?.Log(TraceEventType.Verbose, "Import not found: ", name); @@ -75,19 +80,25 @@ protected override async Task DoImportAsync(string name, Cancella // If there is a stub, make sure it is loaded and attached // First check stub next to the module. - var stub = await GetModuleStubAsync(name, moduleImport.ModulePath, cancellationToken); - // If nothing found, try Typeshed. - stub = stub ?? await _interpreter.TypeshedResolution.ImportModuleAsync(moduleImport.IsBuiltin ? name : moduleImport.FullName, cancellationToken); + if (!TryCreateModuleStub(name, moduleImport.ModulePath, out var stub)) { + // If nothing found, try Typeshed. + stub = _interpreter.TypeshedResolution.GetOrLoadModule(moduleImport.IsBuiltin ? name : moduleImport.FullName); + } + + // If stub is created and its path equals to module, return stub instead of module + if (stub != null && stub.FilePath.PathEquals(moduleImport.ModulePath)) { + return stub; + } IPythonModule module; if (moduleImport.IsBuiltin) { - _log?.Log(TraceEventType.Verbose, "Import built-in compiled (scraped) module: ", name, Configuration.InterpreterPath); + _log?.Log(TraceEventType.Verbose, "Create built-in compiled (scraped) module: ", name, Configuration.InterpreterPath); module = new CompiledBuiltinPythonModule(name, stub, _services); } else if (moduleImport.IsCompiled) { - _log?.Log(TraceEventType.Verbose, "Import compiled (scraped): ", moduleImport.FullName, moduleImport.ModulePath, moduleImport.RootPath); + _log?.Log(TraceEventType.Verbose, "Create compiled (scraped): ", moduleImport.FullName, moduleImport.ModulePath, moduleImport.RootPath); module = new CompiledPythonModule(moduleImport.FullName, ModuleType.Compiled, moduleImport.ModulePath, stub, _services); } else { - _log?.Log(TraceEventType.Verbose, "Import: ", moduleImport.FullName, moduleImport.ModulePath); + _log?.Log(TraceEventType.Verbose, "Create: ", moduleImport.FullName, moduleImport.ModulePath); var rdt = _services.GetService(); // TODO: handle user code and library module separately. var mco = new ModuleCreationOptions { @@ -99,7 +110,6 @@ protected override async Task DoImportAsync(string name, Cancella module = rdt.AddModule(mco); } - await module.LoadAndAnalyzeAsync(cancellationToken); return module; } @@ -143,6 +153,8 @@ public IPythonModule GetSpecializedModule(string name) internal async Task LoadBuiltinTypesAsync(CancellationToken cancellationToken = default) { await BuiltinsModule.LoadAndAnalyzeAsync(cancellationToken); + Check.InvalidOperation(!(BuiltinsModule.Analysis is EmptyAnalysis), "After await"); + // Add built-in module names var builtinModuleNamesMember = BuiltinsModule.GetAnyMember("__builtin_module_names__"); if (builtinModuleNamesMember.TryGetConstant(out var s)) { @@ -151,7 +163,7 @@ internal async Task LoadBuiltinTypesAsync(CancellationToken cancellationToken = } } - public override async Task ReloadAsync(CancellationToken cancellationToken = default) { + public async Task ReloadAsync(CancellationToken cancellationToken = default) { ModuleCache = new ModuleCache(_interpreter, _services); PathResolver = new PathResolver(_interpreter.LanguageVersion); @@ -169,25 +181,27 @@ public override async Task ReloadAsync(CancellationToken cancellationToken = def ReloadModulePaths(addedRoots); } + public IEnumerable SetUserSearchPaths(in IEnumerable searchPaths) + => PathResolver.SetUserSearchPaths(searchPaths); + // For tests - internal void AddUnimportableModule(string moduleName) - => _modules[moduleName] = new SentinelModule(moduleName, _services); + internal void AddUnimportableModule(string moduleName) + => _modules[moduleName] = new ModuleRef(new SentinelModule(moduleName, _services)); - private async Task GetModuleStubAsync(string name, string modulePath, CancellationToken cancellationToken = default) { + private bool TryCreateModuleStub(string name, string modulePath, out IPythonModule module) { // First check stub next to the module. if (!string.IsNullOrEmpty(modulePath)) { var pyiPath = Path.ChangeExtension(modulePath, "pyi"); if (_fs.FileExists(pyiPath)) { - return await CreateStubModuleAsync(name, pyiPath, cancellationToken); + module = new StubPythonModule(name, pyiPath, false, _services); + return true; } } // Try location of stubs that are in a separate folder next to the package. var stubPath = CurrentPathResolver.GetPossibleModuleStubPaths(name).FirstOrDefault(p => _fs.FileExists(p)); - if (!string.IsNullOrEmpty(stubPath)) { - return await CreateStubModuleAsync(name, stubPath, cancellationToken); - } - return null; + module = !string.IsNullOrEmpty(stubPath) ? new StubPythonModule(name, stubPath, false, _services) : null; + return module != null; } } } diff --git a/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs b/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs index b46a3b0ca..6af91641b 100644 --- a/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs +++ b/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs @@ -1,4 +1,4 @@ -// Copyright(c) Microsoft Corporation +// Copyright(c) Microsoft Corporation // All rights reserved. // // Licensed under the Apache License, Version 2.0 (the License); you may not use @@ -13,16 +13,14 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. -using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics; using System.IO; using System.Linq; using System.Threading; -using System.Threading.Tasks; using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Core.Interpreter; +using Microsoft.Python.Analysis.Documents; using Microsoft.Python.Analysis.Types; using Microsoft.Python.Core; using Microsoft.Python.Core.IO; @@ -30,7 +28,7 @@ namespace Microsoft.Python.Analysis.Modules.Resolution { internal abstract class ModuleResolutionBase { - protected readonly ConcurrentDictionary _modules = new ConcurrentDictionary(); + protected readonly ConcurrentDictionary _modules = new ConcurrentDictionary(); protected readonly IServiceContainer _services; protected readonly IPythonInterpreter _interpreter; protected readonly IFileSystem _fs; @@ -65,8 +63,7 @@ protected ModuleResolutionBase(string root, IServiceContainer services) { /// public IBuiltinsPythonModule BuiltinsModule { get; protected set; } - public abstract Task ReloadAsync(CancellationToken cancellationToken = default); - protected abstract Task DoImportAsync(string name, CancellationToken cancellationToken = default); + protected abstract IPythonModule CreateModule(string name); public IReadOnlyCollection GetPackagesFromDirectory(string searchPath, CancellationToken cancellationToken) { return ModulePath.GetModulesInPath( @@ -77,18 +74,25 @@ public IReadOnlyCollection GetPackagesFromDirectory(string searchPath, C ).Select(mp => mp.ModuleName).Where(n => !string.IsNullOrEmpty(n)).TakeWhile(_ => !cancellationToken.IsCancellationRequested).ToList(); } - public IPythonModule GetImportedModule(string name) { + public IPythonModule GetImportedModule(string name) + => _modules.TryGetValue(name, out var moduleRef) ? moduleRef.Value : _interpreter.ModuleResolution.GetSpecializedModule(name); + + public IPythonModule GetOrLoadModule(string name) { + if (_modules.TryGetValue(name, out var moduleRef)) { + return moduleRef.GetOrCreate(name, this); + } + var module = _interpreter.ModuleResolution.GetSpecializedModule(name); if (module != null) { return module; } - return _modules.TryGetValue(name, out module) ? module : null; - } - public IEnumerable SetUserSearchPaths(in IEnumerable searchPaths) - => PathResolver.SetUserSearchPaths(searchPaths); + moduleRef = _modules.GetOrAdd(name, new ModuleRef()); + return moduleRef.GetOrCreate(name, this); + } - public void AddModulePath(string path) => PathResolver.TryAddModulePath(path, out var _); + public bool TryAddModulePath(in string path, out string fullModuleName) + => PathResolver.TryAddModulePath(path, true, out fullModuleName); public ModulePath FindModule(string filePath) { var bestLibraryPath = string.Empty; @@ -103,127 +107,57 @@ public ModulePath FindModule(string filePath) { return ModulePath.FromFullPath(filePath, bestLibraryPath); } - public async Task ImportModuleAsync(string name, CancellationToken cancellationToken = default) { - if (name == BuiltinModuleName) { - return BuiltinsModule; - } - var module = _interpreter.ModuleResolution.GetSpecializedModule(name); - if (module != null) { - return module; + protected void ReloadModulePaths(in IEnumerable rootPaths) { + foreach (var modulePath in rootPaths.Where(Directory.Exists).SelectMany(p => PathUtils.EnumerateFiles(p))) { + PathResolver.TryAddModulePath(modulePath, false, out _); } - return await DoImportModuleAsync(name, cancellationToken); } - private async Task DoImportModuleAsync(string name, CancellationToken cancellationToken = default) { - for (var retries = 5; retries > 0; --retries) { - cancellationToken.ThrowIfCancellationRequested(); - - // The call should be cancelled by the cancellation token, but since we - // are blocking here we wait for slightly longer. Timeouts are handled - // gracefully by TryImportModuleAsync(), so we want those to trigger if - // possible, but if all else fails then we'll abort and treat it as an - // error. - // (And if we've got a debugger attached, don't time out at all.) - TryImportModuleResult result; - try { - result = await TryImportModuleAsync(name, cancellationToken); - } catch (OperationCanceledException) { - _log?.Log(TraceEventType.Error, $"Import timeout: {name}"); - Debug.Fail("Import timeout"); - return null; - } - - switch (result.Status) { - case TryImportModuleResultCode.Success: - return result.Module; - case TryImportModuleResultCode.ModuleNotFound: - _log?.Log(TraceEventType.Information, $"Import not found: {name}"); - return null; - case TryImportModuleResultCode.NeedRetry: - case TryImportModuleResultCode.Timeout: - break; - case TryImportModuleResultCode.NotSupported: - _log?.Log(TraceEventType.Error, $"Import not supported: {name}"); - return null; - } - } - // Never succeeded, so just log the error and fail - _log?.Log(TraceEventType.Error, $"Retry import failed: {name}"); - return null; - } + protected class ModuleRef { + private readonly object _syncObj = new object(); + private IPythonModule _module; + private bool _creating; - private async Task TryImportModuleAsync(string name, CancellationToken cancellationToken = default) { - if (string.IsNullOrEmpty(name)) { - return TryImportModuleResult.ModuleNotFound; - } - if (name == BuiltinModuleName) { - return new TryImportModuleResult(BuiltinsModule); + public ModuleRef(IPythonModule module) { + _module = module; } - Debug.Assert(!name.EndsWithOrdinal("."), $"{name} should not end with '.'"); - // Return any existing module - if (_modules.TryGetValue(name, out var module) && module != null) { - if (module is SentinelModule) { - // TODO: we can't just wait here or we hang. There are two cases: - // a. Recursion on the same analysis chain (A -> B -> A) - // b. Call from another chain (A -> B -> C and D -> B -> E). - // TODO: Both should be resolved at the dependency chain level. - //_log?.Log(TraceEventType.Warning, $"Recursive import: {name}"); - } - return new TryImportModuleResult(module); - } + public ModuleRef() {} - // Set up a sentinel so we can detect recursive imports - var sentinelValue = new SentinelModule(name, _services); - if (!_modules.TryAdd(name, sentinelValue)) { - // Try to get the new module, in case we raced with a .Clear() - if (_modules.TryGetValue(name, out module) && !(module is SentinelModule)) { - return new TryImportModuleResult(module); + public IPythonModule Value { + get { + lock (_syncObj) { + return _module; + } } - // If we reach here, the race is too complicated to recover - // from. Signal the caller to try importing again. - _log?.Log(TraceEventType.Warning, $"Retry import: {name}"); - return TryImportModuleResult.NeedRetry; } - // Do normal searches - try { - module = await DoImportAsync(name, cancellationToken); - } catch (OperationCanceledException) { - _log?.Log(TraceEventType.Error, $"Import timeout {name}"); - return TryImportModuleResult.Timeout; - } + public IPythonModule GetOrCreate(string name, ModuleResolutionBase mrb) { + bool create = false; + lock (_syncObj) { + if (_module != null) { + return _module; + } - if (ModuleCache != null) { - module = module ?? await ModuleCache.ImportFromCacheAsync(name, cancellationToken); - } + if (!_creating) { + create = true; + _creating = true; + } + } - // Replace our sentinel - if (!_modules.TryUpdate(name, module, sentinelValue)) { - // Try to get the new module, in case we raced - if (_modules.TryGetValue(name, out module) && !(module is SentinelModule)) { - return new TryImportModuleResult(module); + if (!create) { + return null; } - // If we reach here, the race is too complicated to recover - // from. Signal the caller to try importing again. - _log?.Log(TraceEventType.Warning, $"Retry import: {name}"); - return TryImportModuleResult.NeedRetry; - } - return new TryImportModuleResult(module); - } + var module = mrb.CreateModule(name); + ((IDocument)module)?.Reset(null); - protected void ReloadModulePaths(in IEnumerable rootPaths) { - foreach (var modulePath in rootPaths.Where(Directory.Exists).SelectMany(p => PathUtils.EnumerateFiles(p))) { - PathResolver.TryAddModulePath(modulePath, out _); + lock (_syncObj) { + _creating = false; + _module = module; + return module; + } } } - - protected async Task CreateStubModuleAsync(string moduleName, string filePath, CancellationToken cancellationToken = default) { - _log?.Log(TraceEventType.Verbose, "Import type stub", moduleName, filePath); - var module = new StubPythonModule(moduleName, filePath, _services); - await module.LoadAndAnalyzeAsync(cancellationToken); - return module; - } } -} +} \ No newline at end of file diff --git a/src/Analysis/Ast/Impl/Modules/Resolution/TypeshedResolution.cs b/src/Analysis/Ast/Impl/Modules/Resolution/TypeshedResolution.cs index 51376d6d6..751289c1e 100644 --- a/src/Analysis/Ast/Impl/Modules/Resolution/TypeshedResolution.cs +++ b/src/Analysis/Ast/Impl/Modules/Resolution/TypeshedResolution.cs @@ -30,7 +30,9 @@ internal sealed class TypeshedResolution : ModuleResolutionBase, IModuleResoluti private readonly IReadOnlyList _typeStubPaths; public TypeshedResolution(IServiceContainer services) : base(null, services) { - _modules[BuiltinModuleName] = BuiltinsModule = _interpreter.ModuleResolution.BuiltinsModule; + BuiltinsModule = _interpreter.ModuleResolution.BuiltinsModule; + _modules[BuiltinModuleName] = new ModuleRef(BuiltinsModule); + _root = _interpreter.Configuration?.TypeshedPath; // TODO: merge with user-provided stub paths _typeStubPaths = GetTypeShedPaths(_interpreter.Configuration?.TypeshedPath).ToArray(); @@ -44,14 +46,14 @@ public TypeshedResolution(IServiceContainer services) : base(null, services) { internal Task InitializeAsync(CancellationToken cancellationToken = default) => ReloadAsync(cancellationToken); - protected override async Task DoImportAsync(string name, CancellationToken cancellationToken = default) { + protected override IPythonModule CreateModule(string name) { var mp = FindModuleInSearchPath(_typeStubPaths, null, name); if (mp != null) { if (mp.Value.IsCompiled) { _log?.Log(TraceEventType.Warning, "Unsupported native module in stubs", mp.Value.FullName, mp.Value.SourceFile); return null; } - return await CreateStubModuleAsync(mp.Value.FullName, mp.Value.SourceFile, cancellationToken); + return new StubPythonModule(mp.Value.FullName, mp.Value.SourceFile, true, _services); } var i = name.IndexOf('.'); @@ -61,10 +63,10 @@ protected override async Task DoImportAsync(string name, Cancella } var stubPath = CurrentPathResolver.GetPossibleModuleStubPaths(name).FirstOrDefault(p => _fs.FileExists(p)); - return stubPath != null ? await CreateStubModuleAsync(name, stubPath, cancellationToken) : null; + return stubPath != null ? new StubPythonModule(name, stubPath, true, _services) : null; } - public override Task ReloadAsync(CancellationToken cancellationToken = default) { + public Task ReloadAsync(CancellationToken cancellationToken = default) { PathResolver = new PathResolver(_interpreter.LanguageVersion); var addedRoots = PathResolver.SetRoot(_root); diff --git a/src/Analysis/Ast/Impl/Modules/StubPythonModule.cs b/src/Analysis/Ast/Impl/Modules/StubPythonModule.cs index e3b2f7bb2..6606b0538 100644 --- a/src/Analysis/Ast/Impl/Modules/StubPythonModule.cs +++ b/src/Analysis/Ast/Impl/Modules/StubPythonModule.cs @@ -24,8 +24,11 @@ namespace Microsoft.Python.Analysis.Modules { /// Represents module that contains stub code such as from typeshed. /// internal class StubPythonModule : CompiledPythonModule { - public StubPythonModule(string moduleName, string stubPath, IServiceContainer services) + public bool IsTypeshed { get; } + + public StubPythonModule(string moduleName, string stubPath, bool isTypeshed, IServiceContainer services) : base(moduleName, ModuleType.Stub, stubPath, null, services) { + IsTypeshed = isTypeshed; } protected override string LoadContent() { diff --git a/src/Analysis/Ast/Impl/Types/PythonFunctionType.cs b/src/Analysis/Ast/Impl/Types/PythonFunctionType.cs index 05b3f012e..9f808c466 100644 --- a/src/Analysis/Ast/Impl/Types/PythonFunctionType.cs +++ b/src/Analysis/Ast/Impl/Types/PythonFunctionType.cs @@ -19,6 +19,7 @@ using System.Linq; using Microsoft.Python.Analysis.Values; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; using Microsoft.Python.Parsing.Ast; namespace Microsoft.Python.Analysis.Types { @@ -29,7 +30,6 @@ internal class PythonFunctionType : PythonType, IPythonFunctionType { private readonly object _lock = new object(); private bool _isAbstract; private bool _isSpecialized; - private string[] _dependencies = Array.Empty(); /// /// Creates function for specializations @@ -137,10 +137,12 @@ internal override void SetDocumentationProvider(Func provider) { internal void Specialize(string[] dependencies) { _isSpecialized = true; - _dependencies = dependencies ?? Array.Empty(); + Dependencies = dependencies != null + ? ImmutableArray.Create(dependencies) + : ImmutableArray.Empty; } - internal IEnumerable Dependencies => _dependencies; + internal ImmutableArray Dependencies { get; private set; } = ImmutableArray.Empty; internal void AddOverload(IPythonFunctionOverload overload) { lock (_lock) { diff --git a/src/Analysis/Ast/Test/AnalysisTestBase.cs b/src/Analysis/Ast/Test/AnalysisTestBase.cs index 15a2795c7..2c7223b3b 100644 --- a/src/Analysis/Ast/Test/AnalysisTestBase.cs +++ b/src/Analysis/Ast/Test/AnalysisTestBase.cs @@ -81,12 +81,8 @@ protected async Task CreateServicesAsync(string root, Interpret sm.AddService(ds); } - TestLogger.Log(TraceEventType.Information, "Create TestDependencyResolver"); - var dependencyResolver = new TestDependencyResolver(); - sm.AddService(dependencyResolver); - TestLogger.Log(TraceEventType.Information, "Create PythonAnalyzer"); - var analyzer = new PythonAnalyzer(sm, root); + var analyzer = new PythonAnalyzer(sm); sm.AddService(analyzer); TestLogger.Log(TraceEventType.Information, "Create PythonInterpreter"); @@ -159,16 +155,12 @@ protected async Task GetAnalysisAsync( TestLogger.Log(TraceEventType.Information, "Ast end"); TestLogger.Log(TraceEventType.Information, "Analysis begin"); - var analysis = await doc.GetAnalysisAsync(CancellationToken.None); + await services.GetService().WaitForCompleteAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); analysis.Should().NotBeNull(); TestLogger.Log(TraceEventType.Information, "Analysis end"); return analysis; } - - private sealed class TestDependencyResolver : IDependencyResolver { - public Task GetDependencyChainAsync(IDocument document, CancellationToken cancellationToken) - => Task.FromResult(new DependencyChainNode(document)); - } } } diff --git a/src/Analysis/Ast/Test/DependencyResolverTests.cs b/src/Analysis/Ast/Test/DependencyResolverTests.cs new file mode 100644 index 000000000..98b647781 --- /dev/null +++ b/src/Analysis/Ast/Test/DependencyResolverTests.cs @@ -0,0 +1,182 @@ +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Python.Analysis.Dependencies; +using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.Python.UnitTests.Core.FluentAssertions; +using TestUtilities; + +namespace Microsoft.Python.Analysis.Tests { + [TestClass] + public class DependencyResolverTests { + public TestContext TestContext { get; set; } + + [TestInitialize] + public void TestInitialize() + => TestEnvironmentImpl.TestInitialize($"{TestContext.FullyQualifiedTestClassName}.{TestContext.TestName}"); + + [TestCleanup] + public void Cleanup() => TestEnvironmentImpl.TestCleanup(); + +// ReSharper disable StringLiteralTypo + [DataRow("A:BC|B:C|C", "CBA")] + [DataRow("C|A:BC|B:C", "CBA")] + [DataRow("C|B:AC|A:BC", "CBABA")] + [DataRow("A:CE|B:A|C:B|D:B|E", "[BE]CAB[DC]A")] + [DataRow("A:D|B:DA|C:BA|D:AE|E", "[AE]DADBC")] + [DataRow("A:C|C:B|B:A|D:AF|F:CE|E:BD", "ABCA[DB][EC]FDEF")] + [DataRow("A:BC|B:AC|C:BA|D:BC", "ACBACBD")] + [DataRow("A|B|C|D:AB|E:BC", "[ABC][DE]")] + [DataRow("A:CE|B:A|C:B|D:BC|E|F:C", "[BE]CABC[FDA]")] +// ReSharper restore StringLiteralTypo + [DataTestMethod] + public async Task AddChangesAsync(string input, string output) { + var resolver = new DependencyResolver(new AddChangesAsyncTestDependencyFinder()); + var splitInput = input.Split("|"); + + var walker = default(IDependencyChainWalker); + foreach (var value in splitInput) { + walker = await resolver.AddChangesAsync(value.Split(":")[0], value, 0, default); + } + + var result = new StringBuilder(); + var tasks = new List>>(); + while (!walker.IsCompleted) { + var nodeTask = walker.GetNextAsync(default); + if (!nodeTask.IsCompleted) { + if (tasks.Count > 1) { + result.Append('['); + } + + foreach (var task in tasks) { + result.Append(task.Result.Value[0]); + task.Result.MarkCompleted(); + } + + if (tasks.Count > 1) { + result.Append(']'); + } + + tasks.Clear(); + } + tasks.Add(nodeTask); + } + + result.ToString().Should().Be(output); + } + + [TestMethod] + public void AddChangesAsync_Parallel() { + var resolver = new DependencyResolver(new AddChangesAsyncParallelTestDependencyFinder()); + var tasks = new List>> { + resolver.AddChangesAsync(0, 0, 0, default), + resolver.AddChangesAsync(1, 1, 0, default), + resolver.AddChangesAsync(0, 0, 1, default), + resolver.AddChangesAsync(1, 1, 1, default), + resolver.AddChangesAsync(0, 0, 2, default), + resolver.AddChangesAsync(1, 1, 2, default) + }; + + tasks[0].Should().BeCanceled(); + tasks[1].Should().BeCanceled(); + tasks[2].Should().BeCanceled(); + tasks[3].Should().BeCanceled(); + tasks[4].Should().BeCanceled(); + tasks[5].Should().NotBeCompleted(); + } + + [TestMethod] + public void AddChangesAsync_WrongOrder() { + var resolver = new DependencyResolver(new AddChangesAsyncParallelTestDependencyFinder()); + var tasks = new List>> { + resolver.AddChangesAsync(0, 0, 2, default), + resolver.AddChangesAsync(0, 0, 0, default), + resolver.AddChangesAsync(0, 0, 1, default), + resolver.AddChangesAsync(1, 1, 0, default), + resolver.AddChangesAsync(1, 1, 1, default) + }; + + + tasks[0].Should().NotBeCompleted(); + tasks[1].Should().BeCanceled(); + tasks[2].Should().BeCanceled(); + tasks[3].Should().BeCanceled(); + tasks[4].Should().NotBeCompleted(); + } + + [TestMethod] + public async Task AddChangesAsync_RepeatedChange() { + var resolver = new DependencyResolver(new AddChangesAsyncTestDependencyFinder()); + resolver.AddChangesAsync("A", "A:B", 0, default).DoNotWait(); + resolver.AddChangesAsync("B", "B:C", 0, default).DoNotWait(); + var walker = await resolver.AddChangesAsync("C", "C", 0, default); + + var result = new StringBuilder(); + while (!walker.IsCompleted) { + var node = await walker.GetNextAsync(default); + result.Append(node.Value[0]); + node.MarkCompleted(); + } + + result.ToString().Should().Be("CBA"); + + walker = await resolver.AddChangesAsync("B", "B:C", 1, default); + result = new StringBuilder(); + while (!walker.IsCompleted) { + var node = await walker.GetNextAsync(default); + result.Append(node.Value[0]); + node.MarkCompleted(); + } + + result.ToString().Should().Be("BA"); + } + + [TestMethod] + public async Task AddChangesAsync_RepeatedChange2() { + var resolver = new DependencyResolver(new AddChangesAsyncTestDependencyFinder()); + resolver.AddChangesAsync("A", "A:B", 0, default).DoNotWait(); + resolver.AddChangesAsync("B", "B", 0, default).DoNotWait(); + resolver.AddChangesAsync("C", "C:D", 0, default).DoNotWait(); + var walker = await resolver.AddChangesAsync("D", "D", 0, default); + + var result = new StringBuilder(); + while (!walker.IsCompleted) { + var node = await walker.GetNextAsync(default); + result.Append(node.Value[0]); + node.MarkCompleted(); + } + + result.ToString().Should().Be("BDAC"); + + + resolver.AddChangesAsync("D", "D", 1, default).DoNotWait(); + walker = await resolver.AddChangesAsync("B", "B:C", 1, default); + result = new StringBuilder(); + while (!walker.IsCompleted) { + var node = await walker.GetNextAsync(default); + result.Append(node.Value[0]); + node.MarkCompleted(); + } + + result.ToString().Should().Be("DCBA"); + } + + private sealed class AddChangesAsyncParallelTestDependencyFinder : IDependencyFinder { + public Task> FindDependenciesAsync(int value, CancellationToken cancellationToken) + => new TaskCompletionSource>().Task.ContinueWith(t => t.GetAwaiter().GetResult(), cancellationToken); + } + + private sealed class AddChangesAsyncTestDependencyFinder : IDependencyFinder { + public Task> FindDependenciesAsync(string value, CancellationToken cancellationToken) { + var kv = value.Split(":"); + var dependencies = kv.Length == 1 ? ImmutableArray.Empty : ImmutableArray.Create(kv[1].Select(c => c.ToString()).ToList()); + return Task.FromResult(dependencies); + } + } + } +} diff --git a/src/Analysis/Ast/Test/ScrapeTests.cs b/src/Analysis/Ast/Test/ScrapeTests.cs index 647ed39d2..0c5472116 100644 --- a/src/Analysis/Ast/Test/ScrapeTests.cs +++ b/src/Analysis/Ast/Test/ScrapeTests.cs @@ -96,7 +96,8 @@ private async Task CompiledBuiltinScrapeAsync(InterpreterConfiguration configura } Console.WriteLine(@"Importing {0} from {1}", mp.ModuleName, mp.SourceFile); - var mod = await interpreter.ModuleResolution.ImportModuleAsync(mp.ModuleName); + var mod = interpreter.ModuleResolution.GetOrLoadModule(mp.ModuleName); + await mod.LoadAndAnalyzeAsync(); Assert.IsInstanceOfType(mod, typeof(CompiledPythonModule)); await ((ModuleCache)interpreter.ModuleResolution.ModuleCache).CacheWritingTask; @@ -147,7 +148,8 @@ private async Task BuiltinScrape(InterpreterConfiguration configuration) { var services = await CreateServicesAsync(moduleDirectory, configuration); var interpreter = services.GetService(); - var mod = await interpreter.ModuleResolution.ImportModuleAsync(interpreter.ModuleResolution.BuiltinModuleName, new CancellationTokenSource(5000).Token); + var mod = interpreter.ModuleResolution.GetOrLoadModule(interpreter.ModuleResolution.BuiltinModuleName); + await mod.LoadAndAnalyzeAsync(new CancellationTokenSource(5000).Token); Assert.IsInstanceOfType(mod, typeof(BuiltinsPythonModule)); var modPath = interpreter.ModuleResolution.ModuleCache.GetCacheFilePath(interpreter.Configuration.InterpreterPath); @@ -277,28 +279,29 @@ private async Task FullStdLibTest(InterpreterConfiguration configuration, params foreach (var r in set) { var modName = r.Item1; - var mod = await interpreter.ModuleResolution.ImportModuleAsync(r.Item2); + + var mod = interpreter.ModuleResolution.GetOrLoadModule(r.Item2); + await mod.LoadAndAnalyzeAsync(new CancellationTokenSource(10000).Token); anyExtensionSeen |= modName.IsNativeExtension; switch (mod) { case null: Trace.TraceWarning("failed to import {0} from {1}", modName.ModuleName, modName.SourceFile); break; - case CompiledPythonModule _: { - var errors = ((IDocument)mod).GetParseErrors().ToArray(); - if (errors.Any()) { - anyParseError = true; - Trace.TraceError("Parse errors in {0}", modName.SourceFile); - foreach (var e in errors) { - Trace.TraceError(e.Message); - } - } else { - anySuccess = true; - anyExtensionSuccess |= modName.IsNativeExtension; + case CompiledPythonModule compiledPythonModule: + var errors = compiledPythonModule.GetParseErrors().ToArray(); + if (errors.Any()) { + anyParseError = true; + Trace.TraceError("Parse errors in {0}", modName.SourceFile); + foreach (var e in errors) { + Trace.TraceError(e.Message); } - - break; + } else { + anySuccess = true; + anyExtensionSuccess |= modName.IsNativeExtension; } + + break; case IPythonModule _: { var filteredErrors = ((IDocument)mod).GetParseErrors().Where(e => !e.Message.Contains("encoding problem")).ToArray(); if (filteredErrors.Any()) { diff --git a/src/Analysis/Core/Impl/DependencyResolution/AstUtilities.cs b/src/Analysis/Core/Impl/DependencyResolution/AstUtilities.cs index 132d06ff1..85a3274c7 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/AstUtilities.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/AstUtilities.cs @@ -18,6 +18,9 @@ namespace Microsoft.Python.Analysis.Core.DependencyResolution { public static class AstUtilities { + public static IImportSearchResult FindImports(this PathResolverSnapshot pathResolver, string modulePath, ModuleName importName, bool forceAbsolute) + => pathResolver.GetImportsFromAbsoluteName(modulePath, importName.Names.Select(n => n.Name), forceAbsolute); + public static IImportSearchResult FindImports(this PathResolverSnapshot pathResolver, string modulePath, FromImportStatement fromImportStatement) { var rootNames = fromImportStatement.Root.Names.Select(n => n.Name); return fromImportStatement.Root is RelativeModuleName relativeName diff --git a/src/Analysis/Core/Impl/DependencyResolution/PathResolver.cs b/src/Analysis/Core/Impl/DependencyResolution/PathResolver.cs index cf73189c8..d1be64623 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/PathResolver.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/PathResolver.cs @@ -41,8 +41,8 @@ public IEnumerable SetInterpreterSearchPaths(in IEnumerable sear public void SetBuiltins(in IEnumerable builtinModuleNames) => _currentSnapshot = _currentSnapshot.SetBuiltins(builtinModuleNames); public void RemoveModulePath(in string path) => _currentSnapshot = _currentSnapshot.RemoveModulePath(path); - public bool TryAddModulePath(in string path, out string fullModuleName) { - _currentSnapshot = _currentSnapshot.AddModulePath(path, out fullModuleName); + public bool TryAddModulePath(in string path, in bool allowNonRooted, out string fullModuleName) { + _currentSnapshot = _currentSnapshot.AddModulePath(path, allowNonRooted, out fullModuleName); return fullModuleName != null; } diff --git a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Edge.cs b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Edge.cs index 37b1ce011..10847d27d 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Edge.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Edge.cs @@ -14,9 +14,10 @@ // permissions and limitations under the License. using System.Diagnostics; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Analysis.Core.DependencyResolution { - public partial struct PathResolverSnapshot { + public partial class PathResolverSnapshot { /// /// Represents the edge between two nodes in the tree /// @@ -73,7 +74,7 @@ private Edge(ImmutableArray<(int index, Node node)> vertices, int index) { public Edge Append(int nextVertexIndex) { var nextVertex = End.Children[nextVertexIndex]; var trimLength = _vertices.Count - _index - 1; - var vertices = _vertices.TrimEnd(trimLength).Add((nextVertexIndex, nextVertex)); + var vertices = _vertices.ReplaceAt(_index + 1, trimLength, (nextVertexIndex, nextVertex)); return new Edge(vertices, _index + 1); } diff --git a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.ImmutableArray.cs b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.ImmutableArray.cs deleted file mode 100644 index ed4d79181..000000000 --- a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.ImmutableArray.cs +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright(c) Microsoft Corporation -// All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the License); you may not use -// this file except in compliance with the License. You may obtain a copy of the -// License at http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS -// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY -// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABILITY OR NON-INFRINGEMENT. -// -// See the Apache Version 2.0 License for specific language governing -// permissions and limitations under the License. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics.Contracts; -using System.Runtime.CompilerServices; - -namespace Microsoft.Python.Analysis.Core.DependencyResolution { - public partial struct PathResolverSnapshot { - /// - /// This type is a compromise between an array (fast access, slow and expensive copying for every immutable change) and binary tree based immutable types - /// Access is almost as fast as in array, adding is as fast as in List (almost identical implementation), - /// setting new value and removal of anything but last element always requires full array copying. - /// Can't be made public because TrimEnd changes the length, but preserves original array, so referenced objects are persisted, which can be a problem for bigger objects. - /// - /// - private struct ImmutableArray : IEnumerable { - private readonly T[] _items; - private readonly int _size; // Size of the part of array that is used. Equal or less than _items.Length - - private ImmutableArray(T[] items, int size, int count) { - _items = items; - _size = size; - Count = count; - } - - public static ImmutableArray Empty { get; } = new ImmutableArray(Array.Empty(), 0, 0); - - public T this[int index] => _items[index]; - public int Count { get; } // Length of the ImmutableArray. Equal or less than _size. - - [Pure] - public ImmutableArray Add(T item) { - var newCount = Count + 1; - var newItems = _items; - - if (_size != Count || newCount > _items.Length) { - var capacity = GetCapacity(newCount); - newItems = new T[capacity]; - Array.Copy(_items, 0, newItems, 0, Count); - } - - newItems[Count] = item; - return new ImmutableArray(newItems, newCount, newCount); - } - - [Pure] - public ImmutableArray AddRange(T[] items) { - if (items.Length == 0) { - return this; - } - - var newCount = Count + items.Length; - var newItems = _items; - - if (_size != Count || newCount > _items.Length) { - var capacity = GetCapacity(newCount); - newItems = new T[capacity]; - Array.Copy(_items, 0, newItems, 0, Count); - } - - Array.Copy(items, 0, newItems, Count, items.Length); - return new ImmutableArray(newItems, newCount, newCount); - } - - [Pure] - public ImmutableArray RemoveAt(int index) { - var newCount = Count - 1; - if (index == newCount) { - return new ImmutableArray(_items, _size, newCount); - } - - var capacity = GetCapacity(newCount); - var newArray = new T[capacity]; - - if (index > 0) { - Array.Copy(_items, newArray, index); - } - - Array.Copy(_items, index + 1, newArray, index, newCount - index); - return new ImmutableArray(newArray, newCount, newCount); - } - - [Pure] - public ImmutableArray TrimEnd(int trimLength) - => trimLength >= Count ? Empty : new ImmutableArray(_items, _size, Count - trimLength); - - [Pure] - public ImmutableArray ReplaceAt(int index, T value) { - var capacity = GetCapacity(Count); - var newArray = new T[capacity]; - Array.Copy(_items, newArray, Count); - newArray[index] = value; - return new ImmutableArray(newArray, Count, Count); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private int GetCapacity(int length) { - var capacity = _items.Length; - - if (capacity == 0) { - capacity = 4; - } - - while (length > capacity) { - capacity = capacity * 2; - } - - while (length < capacity / 2 && capacity > 4) { - capacity = capacity / 2; - } - - return capacity; - } - - private bool Equals(ImmutableArray other) - => Equals(_items, other._items) && _size == other._size && Count == other.Count; - - public override bool Equals(object obj) => obj is ImmutableArray other && Equals(other); - - public override int GetHashCode() { - unchecked { - var hashCode = (_items != null ? _items.GetHashCode() : 0); - hashCode = (hashCode * 397) ^ _size; - hashCode = (hashCode * 397) ^ Count; - return hashCode; - } - } - - public static bool operator ==(ImmutableArray left, ImmutableArray right) { - return left.Equals(right); - } - - public static bool operator !=(ImmutableArray left, ImmutableArray right) { - return !left.Equals(right); - } - - public Enumerator GetEnumerator() - => new Enumerator(this); - - IEnumerator IEnumerable.GetEnumerator() - => new Enumerator(this); - - IEnumerator IEnumerable.GetEnumerator() - => new Enumerator(this); - - - public struct Enumerator : IEnumerator { - private readonly ImmutableArray _owner; - private int _index; - - internal Enumerator(ImmutableArray owner) { - _owner = owner; - _index = 0; - Current = default; - } - - public void Dispose() {} - - public bool MoveNext() { - var localList = _owner; - - if (_index < localList.Count) { - Current = localList._items[_index]; - _index++; - return true; - } - - _index = _owner._size + 1; - Current = default; - return false; - } - - public T Current { get; private set; } - - object IEnumerator.Current => Current; - - void IEnumerator.Reset() { - _index = 0; - Current = default; - } - } - } - } -} diff --git a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Node.cs b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Node.cs index da6147a3b..191ee5a7b 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Node.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.Node.cs @@ -16,11 +16,12 @@ using System.Diagnostics; using System.Text; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; using Microsoft.Python.Core.IO; using Microsoft.Python.Core.Text; namespace Microsoft.Python.Analysis.Core.DependencyResolution { - public partial struct PathResolverSnapshot { + public partial class PathResolverSnapshot { [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] private class Node { public readonly ImmutableArray Children; diff --git a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs index 6f773a0f3..9807c34ed 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs @@ -21,15 +21,13 @@ using System.Runtime.InteropServices; using System.Text; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; using Microsoft.Python.Core.IO; using Microsoft.Python.Core.Text; using Microsoft.Python.Parsing; namespace Microsoft.Python.Analysis.Core.DependencyResolution { - public partial struct PathResolverSnapshot { - private static readonly bool IgnoreCaseInPaths = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX); - private static readonly StringComparison PathsStringComparison = IgnoreCaseInPaths ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal; - + public partial class PathResolverSnapshot { // This root contains module paths that don't belong to any known search path. // The directory of the module is stored on the first level, and name is stored on the second level // For example, "c:\dir\sub_dir2\module1.py" will be stored like this: @@ -68,17 +66,16 @@ private PathResolverSnapshot(PythonLanguageVersion pythonLanguageVersion, string Version = version; } - public IEnumerable GetAllModuleNames() => GetModuleNames(_roots.Prepend(_nonRooted).Append(_builtins)); + public IEnumerable GetAllModuleNames() => GetModuleNames(_roots.Prepend(_nonRooted)); public IEnumerable GetInterpreterModuleNames() => GetModuleNames(_roots.Skip(_userRootsCount).Append(_builtins)); - private IEnumerable GetModuleNames(IEnumerable roots) { - var builtins = new HashSet(_builtins.Children); - return roots.SelectMany(r => r.TraverseBreadthFirst(n => n.IsModule? Enumerable.Empty() : n.Children)) - .Where(n => n.IsModule || builtins.Contains(n)) + private IEnumerable GetModuleNames(IEnumerable roots) => roots + .SelectMany(r => r.TraverseBreadthFirst(n => n.IsModule ? Enumerable.Empty() : n.Children)) + .Where(n => n.IsModule) + .Concat(_builtins.Children) .Select(n => n.FullModuleName); - } - public ModuleImport GetModuleImportFromModuleName(in string fullModuleName) { + public ModuleImport GetModuleImportFromModuleName(in string fullModuleName) { foreach (var root in _roots) { var node = root; var matched = true; @@ -176,8 +173,9 @@ public IImportSearchResult GetImportsFromAbsoluteName(in string modulePath, in I var possibleFullName = string.Join(".", fullNameList); var rootPath = shortestPath.FirstEdge.End.Name; var existingModuleFullName = shortestPath.End.FullModuleName; + var existingModulePath = shortestPath.End.ModulePath; var remainingNameParts = fullNameList.Skip(shortestPath.PathLength - 1).ToList(); - return new PossibleModuleImport(possibleFullName, rootPath, existingModuleFullName, remainingNameParts); + return new PossibleModuleImport(possibleFullName, rootPath, existingModuleFullName, existingModulePath, remainingNameParts); } return new ImportNotFound(string.Join(".", fullNameList)); @@ -308,7 +306,7 @@ public PathResolverSnapshot SetWorkDirectory(in string workDirectory, out IEnume ? PathUtils.NormalizePath(workDirectory) : string.Empty; - if (_workDirectory.Equals(normalizedRootDirectory, PathsStringComparison)) { + if (_workDirectory.PathEquals(normalizedRootDirectory)) { addedRoots = Enumerable.Empty(); return this; } @@ -401,11 +399,11 @@ private ImmutableArray AddRootsFromSearchPaths(string[] userSearchPaths, s } private Node GetOrCreateRoot(string path) - => _roots.FirstOrDefault(r => r.Name.Equals(path, PathsStringComparison)) ?? Node.CreateRoot(path); + => _roots.FirstOrDefault(r => r.Name.PathEquals(path)) ?? Node.CreateRoot(path); - public PathResolverSnapshot AddModulePath(in string modulePath, out string fullModuleName) { + public PathResolverSnapshot AddModulePath(in string modulePath, in bool allowNonRooted, out string fullModuleName) { var isFound = TryFindModule(modulePath, out var lastEdge, out var unmatchedPathSpan); - if (unmatchedPathSpan.Source == default) { + if (unmatchedPathSpan.Source == default || (!allowNonRooted && lastEdge.IsNonRooted)) { // Not a module fullModuleName = null; return this; @@ -694,7 +692,7 @@ private bool TryFindModule(string modulePath, out Edge lastEdge, out StringSpan var rootIndex = 0; while (rootIndex < _roots.Count) { var rootPath = _roots[rootIndex].Name; - if (normalizedPath.StartsWithOrdinal(rootPath, IgnoreCaseInPaths) && IsRootedPathEndsWithValidNames(normalizedPath, rootPath.Length)) { + if (normalizedPath.PathStartsWith(rootPath) && IsRootedPathEndsWithValidNames(normalizedPath, rootPath.Length)) { break; } @@ -734,10 +732,10 @@ private static bool IsValidIdentifier(string str, int start, int length) => str[start].IsLatin1LetterOrUnderscore() && str.CharsAreLatin1LetterOrDigitOrUnderscore(start + 1, length - 1); private static bool IsPythonFile(string rootedPath) - => rootedPath.EndsWithAnyOrdinal(new[] { ".py", ".pyi", ".pyw" }, IgnoreCaseInPaths); + => rootedPath.PathEndsWithAny(".py", ".pyi", ".pyw"); private static bool IsPythonCompiled(string rootedPath) - => rootedPath.EndsWithAnyOrdinal(new[] { ".pyd", ".so", ".dylib" }, IgnoreCaseInPaths); + => rootedPath.PathEndsWithAny(".pyd", ".so", ".dylib"); private static int GetModuleNameStart(string rootedModulePath) => rootedModulePath.LastIndexOf(Path.DirectorySeparatorChar) + 1; diff --git a/src/Analysis/Core/Impl/DependencyResolution/PossibleModuleImport.cs b/src/Analysis/Core/Impl/DependencyResolution/PossibleModuleImport.cs index aca2d8b1d..5a3f42801 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/PossibleModuleImport.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/PossibleModuleImport.cs @@ -20,12 +20,14 @@ public class PossibleModuleImport : IImportSearchResult { public string PossibleModuleFullName { get; } public string RootPath { get; } public string PrecedingModuleFullName { get; } + public string PrecedingModulePath { get; } public IReadOnlyList RemainingNameParts { get; } - public PossibleModuleImport(string possibleModuleFullName, string rootPath, string precedingModuleFullName, IReadOnlyList remainingNameParts) { + public PossibleModuleImport(string possibleModuleFullName, string rootPath, string precedingModuleFullName, string precedingModulePath, IReadOnlyList remainingNameParts) { PossibleModuleFullName = possibleModuleFullName; RootPath = rootPath; PrecedingModuleFullName = precedingModuleFullName; + PrecedingModulePath = precedingModulePath; RemainingNameParts = remainingNameParts; } } diff --git a/src/Core/Impl/Collections/ImmutableArray.cs b/src/Core/Impl/Collections/ImmutableArray.cs new file mode 100644 index 000000000..64fd01443 --- /dev/null +++ b/src/Core/Impl/Collections/ImmutableArray.cs @@ -0,0 +1,308 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Runtime.CompilerServices; + +namespace Microsoft.Python.Core.Collections { + /// + /// This type is a compromise between an array (fast access, slow and expensive copying for every immutable change) and binary tree based immutable types + /// Access is almost as fast as in array, adding is as fast as in List (almost identical implementation), + /// setting new value and removal of anything but last element always requires full array copying. + /// + /// + public struct ImmutableArray : IReadOnlyList, IEquatable> { + private readonly T[] _items; + + private ImmutableArray(T[] items, int count) { + _items = items; + Count = count; + } + + public static ImmutableArray Empty { get; } = new ImmutableArray(Array.Empty(), 0); + + public static ImmutableArray Create(T item) { + var items = new T[1]; + items[0] = item; + return new ImmutableArray(items, 1); + } + + public static ImmutableArray Create(T[] array) { + var items = new T[array.Length]; + Array.Copy(array, items, array.Length); + return new ImmutableArray(items, items.Length); + } + + public static ImmutableArray Create(List list) { + var items = new T[list.Count]; + list.CopyTo(items); + return new ImmutableArray(items, items.Length); + } + + public static ImmutableArray Create(HashSet hashSet) { + var items = new T[hashSet.Count]; + hashSet.CopyTo(items); + return new ImmutableArray(items, items.Length); + } + + public static ImmutableArray Create(Dictionary.ValueCollection collection) { + var items = new T[collection.Count]; + collection.CopyTo(items, 0); + return new ImmutableArray(items, items.Length); + } + + public T this[int index] => _items[index]; + public int Count { get; } // Length of the ImmutableArray. + + [Pure] + public ImmutableArray Add(T item) { + var newCount = Count + 1; + var newItems = _items; + + if (newCount > _items.Length) { + var capacity = GetCapacity(newCount); + newItems = new T[capacity]; + Array.Copy(_items, 0, newItems, 0, Count); + } + + newItems[Count] = item; + return new ImmutableArray(newItems, newCount); + } + + [Pure] + public ImmutableArray AddRange(T[] items) { + if (items.Length == 0) { + return this; + } + + var newCount = Count + items.Length; + var newItems = _items; + + if (newCount > _items.Length) { + var capacity = GetCapacity(newCount); + newItems = new T[capacity]; + Array.Copy(_items, 0, newItems, 0, Count); + } + + Array.Copy(items, 0, newItems, Count, items.Length); + return new ImmutableArray(newItems, newCount); + } + + [Pure] + public ImmutableArray Remove(T value) { + var index = IndexOf(value); + return index >= 0 ? RemoveAt(index) : this; + } + + [Pure] + public ImmutableArray RemoveAt(int index) { + var newCount = Count - 1; + + var capacity = GetCapacity(newCount); + var newArray = new T[capacity]; + + if (index > 0) { + Array.Copy(_items, newArray, index); + } + + if (index < newCount) { + Array.Copy(_items, index + 1, newArray, index, newCount - index); + } + + return new ImmutableArray(newArray, newCount); + } + + [Pure] + public ImmutableArray InsertAt(int index, T value) { + if (index > Count) { + throw new IndexOutOfRangeException(); + } + + if (index == Count) { + return Add(value); + } + + var newCount = Count + 1; + var capacity = GetCapacity(newCount); + var newArray = new T[capacity]; + + if (index > 0) { + Array.Copy(_items, newArray, index); + } + + newArray[index] = value; + Array.Copy(_items, index, newArray, index + 1, Count - index); + + return new ImmutableArray(newArray, newCount); + } + + [Pure] + public ImmutableArray ReplaceAt(int startIndex, int length, T value) { + if (length == 0) { + return InsertAt(startIndex, value); + } + + if (length == 1) { + return ReplaceAt(startIndex, value); + } + + var newCount = Math.Max(Count - length + 1, startIndex + 1); + var capacity = GetCapacity(newCount); + var newArray = new T[capacity]; + + if (startIndex > 0) { + Array.Copy(_items, newArray, startIndex); + } + + newArray[startIndex + 1] = value; + + if (startIndex + 2 < newCount) { + Array.Copy(_items, startIndex + length + 1, newArray, startIndex + 1, newCount - startIndex - 2); + } + + return new ImmutableArray(newArray, newCount); + } + + [Pure] + public ImmutableArray ReplaceAt(int index, T value) { + var capacity = GetCapacity(Count); + var newArray = new T[capacity]; + Array.Copy(_items, newArray, Count); + newArray[index] = value; + return new ImmutableArray(newArray, Count); + } + + [Pure] + public ImmutableArray Where(Func predicate) { + var count = 0; + for (var i = 0; i < Count; i++) { + if (predicate(_items[i])) { + count++; + } + } + + var index = 0; + var items = new T[count]; + for (var i = 0; i < Count; i++) { + if (predicate(_items[i])) { + items[index] = _items[i]; + index++; + } + } + + return new ImmutableArray(items, items.Length); + } + + [Pure] + public ImmutableArray Select(Func selector) { + var items = new TResult[Count]; + for (var i = 0; i < Count; i++) { + items[i] = selector(_items[i]); + } + return new ImmutableArray(items, items.Length); + } + + [Pure] + public int IndexOf(T value) => Array.IndexOf(_items, value, 0, Count); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int GetCapacity(int length) { + var capacity = _items.Length; + + if (capacity == 0) { + capacity = 4; + } + + while (length > capacity) { + capacity = capacity * 2; + } + + while (length < capacity / 2 && capacity > 4) { + capacity = capacity / 2; + } + + return capacity; + } + + public bool Equals(ImmutableArray other) + => Equals(_items, other._items) && Count == other.Count; + + public override bool Equals(object obj) => obj is ImmutableArray other && Equals(other); + + public override int GetHashCode() { + unchecked { + var hashCode = (_items != null ? _items.GetHashCode() : 0); + hashCode = (hashCode * 397) ^ Count; + return hashCode; + } + } + + public static bool operator ==(ImmutableArray left, ImmutableArray right) { + return left.Equals(right); + } + + public static bool operator !=(ImmutableArray left, ImmutableArray right) { + return !left.Equals(right); + } + + public Enumerator GetEnumerator() + => new Enumerator(this); + + IEnumerator IEnumerable.GetEnumerator() + => new Enumerator(this); + + IEnumerator IEnumerable.GetEnumerator() + => new Enumerator(this); + + public struct Enumerator : IEnumerator { + private readonly ImmutableArray _owner; + private int _index; + + internal Enumerator(ImmutableArray owner) { + _owner = owner; + _index = 0; + Current = default; + } + + public void Dispose() { } + + public bool MoveNext() { + var localList = _owner; + + if (_index < localList.Count) { + Current = localList._items[_index]; + _index++; + return true; + } + + _index = _owner.Count + 1; + Current = default; + return false; + } + + public T Current { get; private set; } + + object IEnumerator.Current => Current; + + void IEnumerator.Reset() { + _index = 0; + Current = default; + } + } + } +} diff --git a/src/Core/Impl/Diagnostics/Check.cs b/src/Core/Impl/Diagnostics/Check.cs index 9437ca48e..21879458c 100644 --- a/src/Core/Impl/Diagnostics/Check.cs +++ b/src/Core/Impl/Diagnostics/Check.cs @@ -59,6 +59,20 @@ public static void ArgumentOutOfRange(string argumentName, Func predicate) } } + [DebuggerStepThrough] + public static void ArgumentOutOfRange(string argumentName, bool isInRange) { + if (isInRange) { + throw new ArgumentOutOfRangeException(argumentName); + } + } + + [DebuggerStepThrough] + public static void ArgumentOutOfRange(string argumentName, T value, params T[] allowedValues) where T : Enum { + if (Array.IndexOf(allowedValues, value) == -1) { + throw new ArgumentOutOfRangeException(argumentName); + } + } + [DebuggerStepThrough] public static void InvalidOperation(Func predicate, string message = null) { if (!predicate()) { @@ -66,6 +80,13 @@ public static void InvalidOperation(Func predicate, string message = null) } } + [DebuggerStepThrough] + public static void InvalidOperation(bool condition, string message = null) { + if (!condition) { + throw new InvalidOperationException(message ?? string.Empty); + } + } + [DebuggerStepThrough] public static void Argument(string argumentName, Func predicate) { if (!predicate()) { diff --git a/src/Core/Impl/Extensions/EnumerableExtensions.cs b/src/Core/Impl/Extensions/EnumerableExtensions.cs index f64014627..709b4ee43 100644 --- a/src/Core/Impl/Extensions/EnumerableExtensions.cs +++ b/src/Core/Impl/Extensions/EnumerableExtensions.cs @@ -89,6 +89,48 @@ public static IEnumerable IndexWhere(this IEnumerable source, Func ToDictionary(this IEnumerable source, Func keySelector, Func valueSelector) { + var dictionary = source is IReadOnlyCollection collection + ? new Dictionary(collection.Count) + : new Dictionary(); + + var index = 0; + foreach (var item in source) { + var key = keySelector(item, index); + var value = valueSelector(item, index); + dictionary.Add(key, value); + index++; + } + + return dictionary; + } + + public static IEnumerable TraverseDepthFirst(this T root, Func> selectChildren) { + var items = new Stack(); + var reverseChildren = new Stack(); + + items.Push(root); + while (items.Count > 0) { + var item = items.Pop(); + yield return item; + + var children = selectChildren(item); + if (children == null) { + continue; + } + + foreach (var child in children) { + reverseChildren.Push(child); + } + + foreach (var child in reverseChildren) { + items.Push(child); + } + + reverseChildren.Clear(); + } + } + public static IEnumerable TraverseBreadthFirst(this T root, Func> selectChildren) { var items = new Queue(); items.Enqueue(root); diff --git a/src/Core/Impl/Extensions/StringExtensions.cs b/src/Core/Impl/Extensions/StringExtensions.cs index 385cfb725..97babbc0a 100644 --- a/src/Core/Impl/Extensions/StringExtensions.cs +++ b/src/Core/Impl/Extensions/StringExtensions.cs @@ -18,11 +18,15 @@ using System.Diagnostics; using System.Globalization; using System.Linq; +using System.Runtime.InteropServices; using System.Text.RegularExpressions; using Microsoft.Python.Core.Text; namespace Microsoft.Python.Core { public static class StringExtensions { + private static readonly bool IgnoreCaseInPaths = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX); + private static readonly StringComparison PathsStringComparison = IgnoreCaseInPaths ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal; + #if DEBUG private static readonly Regex SubstitutionRegex = new Regex( @"\{(\d+)", @@ -141,12 +145,18 @@ public static string QuoteArgument(this string arg) { return "\"{0}\"".FormatInvariant(arg); } + public static bool PathStartsWith(this string s, string prefix) + => s?.StartsWith(prefix, PathsStringComparison) ?? false; + public static bool StartsWithOrdinal(this string s, string prefix, bool ignoreCase = false) => s?.StartsWith(prefix, ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal) ?? false; public static bool EndsWithOrdinal(this string s, string suffix, bool ignoreCase = false) => s?.EndsWith(suffix, ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal) ?? false; + public static bool PathEndsWithAny(this string s, params string[] values) + => s.EndsWithAnyOrdinal(values, IgnoreCaseInPaths); + public static bool EndsWithAnyOrdinal(this string s, params string[] values) => s.EndsWithAnyOrdinal(values, false); @@ -190,9 +200,15 @@ public static bool EqualsIgnoreCase(this string s, string other) public static bool EqualsOrdinal(this string s, string other) => string.Equals(s, other, StringComparison.Ordinal); + public static bool PathEquals(this string s, string other) + => string.Equals(s, other, PathsStringComparison); + public static bool EqualsOrdinal(this string s, int index, string other, int otherIndex, int length, bool ignoreCase = false) => string.Compare(s, index, other, otherIndex, length, ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal) == 0; + public static int GetPathHashCode(this string s) + => IgnoreCaseInPaths ? StringComparer.OrdinalIgnoreCase.GetHashCode(s) : StringComparer.Ordinal.GetHashCode(s); + public static string[] Split(this string s, char separator, int startIndex, int length) { var count = 0; var endIndex = startIndex + length; diff --git a/src/Core/Impl/Extensions/TaskCompletionSourceExtensions.cs b/src/Core/Impl/Extensions/TaskCompletionSourceExtensions.cs index 7054a2cd0..20f72fbcb 100644 --- a/src/Core/Impl/Extensions/TaskCompletionSourceExtensions.cs +++ b/src/Core/Impl/Extensions/TaskCompletionSourceExtensions.cs @@ -32,18 +32,6 @@ public static CancellationTokenRegistration RegisterForCancellation(this Task return cancellationToken.Register(action.Invoke); } - private struct TrySetResultStateAction { - public TaskCompletionSource Tcs { get; } - public T Result { get; } - - public TrySetResultStateAction(TaskCompletionSource tcs, T result) { - Tcs = tcs; - Result = result; - } - - public void Invoke(object state) => Tcs.TrySetResult(Result); - } - private struct CancelOnTokenAction { private readonly TaskCompletionSource _taskCompletionSource; private readonly CancellationToken _cancellationToken; diff --git a/src/Core/Impl/Microsoft.Python.Core.csproj b/src/Core/Impl/Microsoft.Python.Core.csproj index 4b3da4419..6bcb47038 100644 --- a/src/Core/Impl/Microsoft.Python.Core.csproj +++ b/src/Core/Impl/Microsoft.Python.Core.csproj @@ -9,7 +9,7 @@ 1701, 1702 - "You may need to supply assembly policy" --> 1701;1702;$(NoWarn) - 7.2 + 7.3 diff --git a/src/Core/Impl/OS/ProcessHelper.cs b/src/Core/Impl/OS/ProcessHelper.cs index 75c25dbfc..0f4710735 100644 --- a/src/Core/Impl/OS/ProcessHelper.cs +++ b/src/Core/Impl/OS/ProcessHelper.cs @@ -24,7 +24,7 @@ namespace Microsoft.Python.Core.OS { public sealed class ProcessHelper : IDisposable { private Process _process; private int? _exitCode; - private readonly SemaphoreSlim _seenNullOutput, _seenNullError; + private readonly AsyncManualResetEvent _seenNullOutput, _seenNullError; public ProcessHelper(string filename, IEnumerable arguments, string workingDir = null) { if (!File.Exists(filename)) { @@ -44,8 +44,8 @@ public ProcessHelper(string filename, IEnumerable arguments, string work RedirectStandardError = true }; - _seenNullOutput = new SemaphoreSlim(1); - _seenNullError = new SemaphoreSlim(1); + _seenNullOutput = new AsyncManualResetEvent(); + _seenNullError = new AsyncManualResetEvent(); } public ProcessStartInfo StartInfo { get; } @@ -56,19 +56,20 @@ public ProcessHelper(string filename, IEnumerable arguments, string work public Action OnErrorLine { get; set; } public void Dispose() { - _seenNullOutput.Dispose(); - _seenNullError.Dispose(); + _seenNullOutput.Set(); + _seenNullError.Set(); _process?.Dispose(); } public void Start() { - _seenNullOutput.Wait(0); - _seenNullError.Wait(0); + _seenNullOutput.Reset(); + _seenNullError.Reset(); var p = new Process { StartInfo = StartInfo }; + p.Exited += Process_Exited; p.OutputDataReceived += Process_OutputDataReceived; p.ErrorDataReceived += Process_ErrorDataReceived; @@ -79,10 +80,11 @@ public void Start() { // clean up. _exitCode = ex.HResult; OnErrorLine?.Invoke(ex.ToString()); - _seenNullError.Release(); - _seenNullOutput.Release(); + _seenNullOutput.Set(); + _seenNullError.Set(); p.OutputDataReceived -= Process_OutputDataReceived; p.ErrorDataReceived -= Process_ErrorDataReceived; + p.Exited -= Process_Exited; return; } @@ -100,7 +102,7 @@ public void Start() { private void Process_ErrorDataReceived(object sender, DataReceivedEventArgs e) { try { if (e.Data == null) { - _seenNullError.Release(); + _seenNullError.Set(); } else { OnErrorLine?.Invoke(e.Data.TrimEnd()); } @@ -109,10 +111,15 @@ private void Process_ErrorDataReceived(object sender, DataReceivedEventArgs e) { } } + private void Process_Exited(object sender, EventArgs eventArgs) { + _seenNullOutput.Set(); + _seenNullError.Set(); + } + private void Process_OutputDataReceived(object sender, DataReceivedEventArgs e) { try { if (e.Data == null) { - _seenNullOutput.Release(); + _seenNullOutput.Set(); } else { OnOutputLine?.Invoke(e.Data.TrimEnd()); } diff --git a/src/Core/Impl/Threading/AsyncCountdownEvent.cs b/src/Core/Impl/Threading/AsyncCountdownEvent.cs new file mode 100644 index 000000000..ea3b8b761 --- /dev/null +++ b/src/Core/Impl/Threading/AsyncCountdownEvent.cs @@ -0,0 +1,66 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Core.Disposables; + +namespace Microsoft.Python.Core { + public class AsyncCountdownEvent { + private readonly AsyncManualResetEvent _mre = new AsyncManualResetEvent(); + private int _count; + + public AsyncCountdownEvent(int initialCount) { + if (initialCount < 0) { + throw new ArgumentOutOfRangeException(nameof(initialCount)); + } + + _count = initialCount; + if (initialCount == 0) { + _mre.Set(); + } + } + + public Task WaitAsync() => _mre.WaitAsync(); + + public Task WaitAsync(CancellationToken cancellationToken) => _mre.WaitAsync(cancellationToken); + + public void Signal() { + if (_count <= 0) { + throw new InvalidOperationException(); + } + + var count = Interlocked.Decrement(ref _count); + if (count < 0) { + throw new InvalidOperationException(); + } + + if (count == 0) { + _mre.Set(); + } + } + + public void AddOne() { + _mre.Reset(); + Interlocked.Increment(ref _count); + } + + public IDisposable AddOneDisposable() { + AddOne(); + return Disposable.Create(Signal); + } + } +} diff --git a/src/Core/Impl/Threading/AsyncManualResetEvent.cs b/src/Core/Impl/Threading/AsyncManualResetEvent.cs new file mode 100644 index 000000000..087647485 --- /dev/null +++ b/src/Core/Impl/Threading/AsyncManualResetEvent.cs @@ -0,0 +1,76 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Python.Core { + public class AsyncManualResetEvent { + private TaskCompletionSource _tcs; + public Task WaitAsync() => _tcs.Task; + public void Set() => _tcs.TrySetResult(true); + + public Task WaitAsync(CancellationToken cancellationToken) { + if (cancellationToken.IsCancellationRequested) { + return Task.FromCanceled(cancellationToken); + } + + var tcs = new TaskCompletionSource(); + cancellationToken.Register(CancelTcs, tcs); + _tcs.Task.ContinueWith(WaitContinuation, tcs, TaskContinuationOptions.ExecuteSynchronously); + return tcs.Task; + } + + private void WaitContinuation(Task task, object state) { + var tcs = (TaskCompletionSource)state; + switch (task.Status) { + case TaskStatus.Faulted: + tcs.TrySetException(task.Exception); + break; + case TaskStatus.Canceled: + tcs.TrySetCanceled(); + break; + case TaskStatus.RanToCompletion: + tcs.TrySetResult(task.Result); + break; + } + } + + public AsyncManualResetEvent() { + _tcs = new TaskCompletionSource(); + } + + public void Reset() { + while (true) { + var tcs = _tcs; + if (!tcs.Task.IsCompleted) { + return; + } + + if (Interlocked.CompareExchange(ref _tcs, new TaskCompletionSource(), tcs) == tcs) { + return; + } + } + } + + private static void CancelTcs(object obj) { + var tcs = (TaskCompletionSource)obj; + tcs.TrySetCanceled(); + } + } +} diff --git a/src/LanguageServer/Impl/Completion/ClassDefinitionCompletion.cs b/src/LanguageServer/Impl/Completion/ClassDefinitionCompletion.cs index 47366fabe..1bfea4fc7 100644 --- a/src/LanguageServer/Impl/Completion/ClassDefinitionCompletion.cs +++ b/src/LanguageServer/Impl/Completion/ClassDefinitionCompletion.cs @@ -30,7 +30,7 @@ public static bool NoCompletions(ClassDefinition cd, CompletionContext context, return true; } - if (cd.Bases.Length > 0 && context.Position >= cd.Bases[0].StartIndex) { + if (cd.Bases.Count > 0 && context.Position >= cd.Bases[0].StartIndex) { foreach (var p in cd.Bases.Reverse()) { if (context.Position >= p.StartIndex) { if (p.Name == null && context.Ast.LanguageVersion.Is3x() && cd.Bases.All(b => b.Name != @"metaclass")) { diff --git a/src/LanguageServer/Impl/Completion/CompletionSource.cs b/src/LanguageServer/Impl/Completion/CompletionSource.cs index 532f27a8f..1b6687518 100644 --- a/src/LanguageServer/Impl/Completion/CompletionSource.cs +++ b/src/LanguageServer/Impl/Completion/CompletionSource.cs @@ -40,7 +40,8 @@ public async Task GetCompletionsAsync(IDocumentAnalysis analys return new CompletionResult(await ExpressionCompletion.GetCompletionsFromMembersAsync(me.Target, scope, context, cancellationToken)); case ConstantExpression ce1 when ce1.Value is double || ce1.Value is float: // no completions on integer ., the user is typing a float - case ConstantExpression ce2 when ce2.Value is string: + return CompletionResult.Empty; + case ConstantExpression ce2 when ce2.Value is string: // no completions in strings case null when context.Ast.IsInsideComment(context.Location): case null when context.Ast.IsInsideString(context.Location): diff --git a/src/LanguageServer/Impl/Implementation/Server.Documents.cs b/src/LanguageServer/Impl/Implementation/Server.Documents.cs index ea8de0cba..74760817c 100644 --- a/src/LanguageServer/Impl/Implementation/Server.Documents.cs +++ b/src/LanguageServer/Impl/Implementation/Server.Documents.cs @@ -65,18 +65,14 @@ public void DidCloseTextDocument(DidCloseTextDocumentParams @params) { _rdt.CloseDocument(@params.textDocument.uri); } - private IDocumentAnalysis GetAnalysis(Uri uri, CancellationToken cancellationToken) { + private Task GetAnalysisAsync(Uri uri, CancellationToken cancellationToken) { var document = _rdt.GetDocument(uri); - if (document != null) { - try { - document.GetAnalysisAsync(cancellationToken).Wait(200); - return document.GetAnyAnalysis(); - } catch (OperationCanceledException) { - return null; - } + if (document == null) { + _log?.Log(TraceEventType.Error, $"Unable to find document {uri}"); + return Task.FromResult(default(IDocumentAnalysis)); } - _log?.Log(TraceEventType.Error, $"Unable to find document {uri}"); - return null; + + return document.GetAnalysisAsync(200, cancellationToken); } } } diff --git a/src/LanguageServer/Impl/Implementation/Server.Editor.cs b/src/LanguageServer/Impl/Implementation/Server.Editor.cs index e83214514..52c01f2c3 100644 --- a/src/LanguageServer/Impl/Implementation/Server.Editor.cs +++ b/src/LanguageServer/Impl/Implementation/Server.Editor.cs @@ -33,7 +33,7 @@ public async Task Completion(CompletionParams @params, Cancellat _log?.Log(TraceEventType.Verbose, $"Completions in {uri} at {@params.position}"); var res = new CompletionList(); - var analysis = GetAnalysis(uri, cancellationToken); + var analysis = await GetAnalysisAsync(uri, cancellationToken); if(analysis != null) { var result = await _completionSource.GetCompletionsAsync(analysis, @params.position, cancellationToken); res.items = result.Completions.ToArray(); @@ -48,7 +48,7 @@ public async Task Hover(TextDocumentPositionParams @params, CancellationT var uri = @params.textDocument.uri; _log?.Log(TraceEventType.Verbose, $"Hover in {uri} at {@params.position}"); - var analysis = GetAnalysis(uri, cancellationToken); + var analysis = await GetAnalysisAsync(uri, cancellationToken); if (analysis != null) { return await _hoverSource.GetHoverAsync(analysis, @params.position, cancellationToken); } @@ -59,7 +59,7 @@ public async Task SignatureHelp(TextDocumentPositionParams @param var uri = @params.textDocument.uri; _log?.Log(TraceEventType.Verbose, $"Signatures in {uri} at {@params.position}"); - var analysis = GetAnalysis(uri, cancellationToken); + var analysis = await GetAnalysisAsync(uri, cancellationToken); if (analysis != null) { return await _signatureSource.GetSignatureAsync(analysis, @params.position, cancellationToken); } @@ -70,7 +70,7 @@ public async Task GotoDefinition(TextDocumentPositionParams @params var uri = @params.textDocument.uri; _log?.Log(TraceEventType.Verbose, $"Goto Definition in {uri} at {@params.position}"); - var analysis = GetAnalysis(uri, cancellationToken); + var analysis = await GetAnalysisAsync(uri, cancellationToken); var ds = new DefinitionSource(); var reference = await ds.FindDefinitionAsync(analysis, @params.position, cancellationToken); return reference != null ? new[] { reference } : Array.Empty(); diff --git a/src/LanguageServer/Impl/Implementation/Server.cs b/src/LanguageServer/Impl/Implementation/Server.cs index cb66d4eb8..a0241f6ce 100644 --- a/src/LanguageServer/Impl/Implementation/Server.cs +++ b/src/LanguageServer/Impl/Implementation/Server.cs @@ -105,7 +105,7 @@ public async Task InitializeAsync(InitializeParams @params, Ca _services.AddService(new DiagnosticsService(_services)); - var analyzer = new PythonAnalyzer(_services, @params.rootPath); + var analyzer = new PythonAnalyzer(_services); _services.AddService(analyzer); _services.AddService(new RunningDocumentTable(@params.rootPath, _services)); diff --git a/src/LanguageServer/Impl/Sources/SignatureSource.cs b/src/LanguageServer/Impl/Sources/SignatureSource.cs index b3f8afb73..e4aa64845 100644 --- a/src/LanguageServer/Impl/Sources/SignatureSource.cs +++ b/src/LanguageServer/Impl/Sources/SignatureSource.cs @@ -89,10 +89,10 @@ public async Task GetSignatureAsync(IDocumentAnalysis analysis, S if (activeParameter >= 0) { // TODO: Better selection of active signature by argument set activeSignature = signatures - .Select((s, i) => Tuple.Create(s, i)) - .OrderBy(t => t.Item1.parameters.Length) - .FirstOrDefault(t => t.Item1.parameters.Length > activeParameter) - ?.Item2 ?? -1; + .Select((s, i) => Tuple.Create(s, i)) + .OrderBy(t => t.Item1.parameters.Length) + .FirstOrDefault(t => t.Item1.parameters.Length > activeParameter) + ?.Item2 ?? -1; } activeSignature = activeSignature >= 0 diff --git a/src/LanguageServer/Test/CompletionTests.cs b/src/LanguageServer/Test/CompletionTests.cs index bb587cfb7..c9495a56b 100644 --- a/src/LanguageServer/Test/CompletionTests.cs +++ b/src/LanguageServer/Test/CompletionTests.cs @@ -1,4 +1,4 @@ -// Copyright(c) Microsoft Corporation +// Copyright(c) Microsoft Corporation // All rights reserved. // // Licensed under the Apache License, Version 2.0 (the License); you may not use @@ -17,7 +17,6 @@ using System.Linq; using System.Threading.Tasks; using FluentAssertions; -using FluentAssertions.Common; using Microsoft.Python.Analysis.Types; using Microsoft.Python.Core; using Microsoft.Python.Core.Text; @@ -68,7 +67,7 @@ public async Task StringMembers() { var analysis = await GetAnalysisAsync(code); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(3, 3)); - comps.Should().HaveLabels(@"isupper", @"capitalize", @"split" ); + comps.Should().HaveLabels(@"isupper", @"capitalize", @"split"); } [TestMethod, Priority(0)] diff --git a/src/LanguageServer/Test/DiagnosticsTests.cs b/src/LanguageServer/Test/DiagnosticsTests.cs index 6c1b1bf1c..29cbba4ff 100644 --- a/src/LanguageServer/Test/DiagnosticsTests.cs +++ b/src/LanguageServer/Test/DiagnosticsTests.cs @@ -55,7 +55,8 @@ public async Task BasicChange() { ReplacedSpan = new SourceSpan(1, 5, 1, 5) } }); - await doc.GetAnalysisAsync(); + await doc.GetAstAsync(); + await doc.GetAnalysisAsync(0); ds.Diagnostics[doc.Uri].Count.Should().Be(0); doc.Update(new[] {new DocumentChange { @@ -63,7 +64,8 @@ public async Task BasicChange() { ReplacedSpan = new SourceSpan(1, 5, 1, 6) } }); - await doc.GetAnalysisAsync(); + await doc.GetAstAsync(); + await doc.GetAnalysisAsync(0); ds.Diagnostics[doc.Uri].Count.Should().Be(1); } @@ -87,7 +89,8 @@ public async Task TwoDocuments() { ReplacedSpan = new SourceSpan(1, 5, 1, 5) } }); - await doc2.GetAnalysisAsync(); + await doc2.GetAstAsync(); + await doc2.GetAnalysisAsync(0); ds.Diagnostics[doc1.Uri].Count.Should().Be(1); ds.Diagnostics[doc2.Uri].Count.Should().Be(0); @@ -96,7 +99,8 @@ public async Task TwoDocuments() { ReplacedSpan = new SourceSpan(1, 5, 1, 6) } }); - await doc2.GetAnalysisAsync(); + await doc2.GetAstAsync(); + await doc2.GetAnalysisAsync(0); ds.Diagnostics[doc2.Uri].Count.Should().Be(1); doc1.Dispose(); @@ -130,7 +134,7 @@ public async Task Publish() { ReplacedSpan = new SourceSpan(1, 5, 1, 5) } }); - await doc.GetAnalysisAsync(); + await doc.GetAnalysisAsync(0); idle.Idle += Raise.EventWith(null, EventArgs.Empty); } diff --git a/src/LanguageServer/Test/ImportsTests.cs b/src/LanguageServer/Test/ImportsTests.cs index 82ad89b28..98d3382da 100644 --- a/src/LanguageServer/Test/ImportsTests.cs +++ b/src/LanguageServer/Test/ImportsTests.cs @@ -19,6 +19,7 @@ using System.Threading.Tasks; using FluentAssertions; using Microsoft.Python.Analysis; +using Microsoft.Python.Analysis.Analyzer; using Microsoft.Python.Analysis.Documents; using Microsoft.Python.Core.Text; using Microsoft.Python.LanguageServer.Completion; @@ -66,7 +67,7 @@ import projectB.foo.baz rdt.OpenDocument(new Uri(init4Path), string.Empty); var doc = rdt.OpenDocument(new Uri(appPath), appCode, appPath); - var analysis = await doc.GetAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(7, 10)); @@ -94,7 +95,7 @@ public async Task SysModuleChain() { rdt.OpenDocument(uri2, content2); rdt.OpenDocument(uri3, content3); - var analysis = await doc1.GetAnalysisAsync(); + var analysis = await doc1.GetAnalysisAsync(0); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(2, 5)); @@ -115,7 +116,9 @@ await TestData.CreateTestSpecificFileAsync("module2.py", @"import sys var rdt = Services.GetService(); var doc = rdt.OpenDocument(TestData.GetDefaultModuleUri(), content); - var analysis = await doc.GetAnalysisAsync(); + await doc.GetAstAsync(); + await Services.GetService().WaitForCompleteAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(2, 5)); @@ -143,7 +146,8 @@ public async Task UncSearchPaths() { rdt.OpenDocument(new Uri(module2Path), "Y = 6 * 9"); var doc = rdt.OpenDocument(new Uri(appPath), appCode1); - var analysis = await doc.GetAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); + var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(1, 21)); @@ -156,7 +160,8 @@ public async Task UncSearchPaths() { } }); - analysis = await doc.GetAnalysisAsync(); + await doc.GetAstAsync(); + analysis = await doc.GetAnalysisAsync(0); comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(2, 9)); comps.Should().HaveLabels("X").And.NotContainLabels("Y"); @@ -198,7 +203,7 @@ def method2(): var mainPath = Path.Combine(root, "main.py"); var doc = rdt.OpenDocument(new Uri(mainPath), mainContent); - var analysis = await doc.GetAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(2, 6)); @@ -237,7 +242,7 @@ import package.sub_package.module2 rdt.OpenDocument(new Uri(module2Path), "Y = 6 * 9"); var doc = rdt.OpenDocument(new Uri(appPath), appCode); - var analysis = await doc.GetAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(5, 9)); @@ -280,20 +285,20 @@ public async Task FromImport_ModuleAffectsPackage(string appCodeImport) { rdt.OpenDocument(new Uri(modulePath), "X = 42"); var doc = rdt.OpenDocument(new Uri(appPath), appCode1); - var analysis = await doc.GetAnalysisAsync(); + var analysis = await doc.GetAnalysisAsync(0); var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion); var comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(2, 13)); comps.Should().OnlyHaveLabels("module"); - doc.Update(new [] { + doc.Update(new[] { new DocumentChange { InsertedText = appCode2, ReplacedSpan = new SourceSpan(1, 1, 2, 13) } }); - analysis = await doc.GetAnalysisAsync(); + analysis = await doc.GetAnalysisAsync(0); comps = await cs.GetCompletionsAsync(analysis, new SourceLocation(2, 21)); comps.Should().HaveLabels("X"); } diff --git a/src/Parsing/Impl/Ast/AndExpression.cs b/src/Parsing/Impl/Ast/AndExpression.cs index c0ba6f464..cca5eceae 100644 --- a/src/Parsing/Impl/Ast/AndExpression.cs +++ b/src/Parsing/Impl/Ast/AndExpression.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -34,6 +35,11 @@ public AndExpression(Expression left, Expression right, int andIndex) { public override string NodeName => "and expression"; + public override IEnumerable GetChildNodes() { + yield return Left; + if (Right != null) yield return Right; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Left.Walk(walker); diff --git a/src/Parsing/Impl/Ast/Arg.cs b/src/Parsing/Impl/Ast/Arg.cs index cdd15157a..506ccb4e2 100644 --- a/src/Parsing/Impl/Ast/Arg.cs +++ b/src/Parsing/Impl/Ast/Arg.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -40,6 +41,10 @@ public int EndIndexIncludingWhitespace { public override string ToString() => base.ToString() + ":" + NameExpression; + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/AssertStatement.cs b/src/Parsing/Impl/Ast/AssertStatement.cs index de407caef..78145bf11 100644 --- a/src/Parsing/Impl/Ast/AssertStatement.cs +++ b/src/Parsing/Impl/Ast/AssertStatement.cs @@ -14,6 +14,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -26,9 +27,13 @@ public AssertStatement(Expression test, Expression message) { } public Expression Test { get; } - public Expression Message { get; } + public override IEnumerable GetChildNodes() { + if (Test != null) yield return Test; + if (Message != null) yield return Message; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Test?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/AssignmentStatement.cs b/src/Parsing/Impl/Ast/AssignmentStatement.cs index 35b56e68c..09d22059e 100644 --- a/src/Parsing/Impl/Ast/AssignmentStatement.cs +++ b/src/Parsing/Impl/Ast/AssignmentStatement.cs @@ -34,6 +34,13 @@ public AssignmentStatement(Expression[] left, Expression right) { public Expression Right { get; } + public override IEnumerable GetChildNodes() { + foreach (var expression in _left) { + yield return expression; + } + if (Right != null) yield return Right; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { foreach (var e in _left) { diff --git a/src/Parsing/Impl/Ast/AugmentedAssignStatement.cs b/src/Parsing/Impl/Ast/AugmentedAssignStatement.cs index ecb660e4e..6b5e4efac 100644 --- a/src/Parsing/Impl/Ast/AugmentedAssignStatement.cs +++ b/src/Parsing/Impl/Ast/AugmentedAssignStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -31,6 +32,11 @@ public AugmentedAssignStatement(PythonOperator op, Expression left, Expression r public Expression Right { get; } + public override IEnumerable GetChildNodes() { + if (Left != null) yield return Left; + if (Right != null) yield return Right; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Left?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/AwaitExpression.cs b/src/Parsing/Impl/Ast/AwaitExpression.cs index 6b8df18f6..bcfe7eac5 100644 --- a/src/Parsing/Impl/Ast/AwaitExpression.cs +++ b/src/Parsing/Impl/Ast/AwaitExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -29,6 +30,10 @@ public AwaitExpression(Expression expression) { public Expression Expression { get; } + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/BackQuoteExpression.cs b/src/Parsing/Impl/Ast/BackQuoteExpression.cs index 7e1c247d3..4c49f09c8 100644 --- a/src/Parsing/Impl/Ast/BackQuoteExpression.cs +++ b/src/Parsing/Impl/Ast/BackQuoteExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -25,6 +26,10 @@ public BackQuoteExpression(Expression expression) { public Expression Expression { get; } + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/BinaryExpression.cs b/src/Parsing/Impl/Ast/BinaryExpression.cs index 28650a2c0..3b0497b30 100644 --- a/src/Parsing/Impl/Ast/BinaryExpression.cs +++ b/src/Parsing/Impl/Ast/BinaryExpression.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Diagnostics; using System.Text; using System.Threading; @@ -61,6 +62,11 @@ private bool IsComparison() { public override string NodeName => "binary operator"; + public override IEnumerable GetChildNodes() { + if (Left != null) yield return Left; + if (Right != null) yield return Right; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Left?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/BreakStatement.cs b/src/Parsing/Impl/Ast/BreakStatement.cs index c0543889e..3d7a4cc9c 100644 --- a/src/Parsing/Impl/Ast/BreakStatement.cs +++ b/src/Parsing/Impl/Ast/BreakStatement.cs @@ -13,14 +13,17 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; namespace Microsoft.Python.Parsing.Ast { public class BreakStatement : Statement { - public BreakStatement() { - } + public BreakStatement() {} + + public override IEnumerable GetChildNodes() => Enumerable.Empty(); public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { diff --git a/src/Parsing/Impl/Ast/CallExpression.cs b/src/Parsing/Impl/Ast/CallExpression.cs index dcee3b745..36f6cdc0e 100644 --- a/src/Parsing/Impl/Ast/CallExpression.cs +++ b/src/Parsing/Impl/Ast/CallExpression.cs @@ -18,25 +18,24 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class CallExpression : Expression { - private readonly Arg[] _args; - - public CallExpression(Expression target, Arg[] args) { + public CallExpression(Expression target, ImmutableArray args) { Target = target; - _args = args; + Args = args; } public Expression Target { get; } - public IList Args => _args; + public ImmutableArray Args { get; } public bool NeedsLocalsDictionary() { if (!(Target is NameExpression nameExpr)) { return false; } - if (_args.Length == 0) { + if (Args.Count == 0) { switch (nameExpr.Name) { case "locals": case "vars": @@ -47,13 +46,13 @@ public bool NeedsLocalsDictionary() { } } - if (_args.Length == 1 && (nameExpr.Name == "dir" || nameExpr.Name == "vars")) { - if (_args[0].Name == "*" || _args[0].Name == "**") { + if (Args.Count == 1 && (nameExpr.Name == "dir" || nameExpr.Name == "vars")) { + if (Args[0].Name == "*" || Args[0].Name == "**") { // could be splatting empty list or dict resulting in 0-param call which needs context return true; } - } else if (_args.Length == 2 && (nameExpr.Name == "dir" || nameExpr.Name == "vars")) { - if (_args[0].Name == "*" && _args[1].Name == "**") { + } else if (Args.Count == 2 && (nameExpr.Name == "dir" || nameExpr.Name == "vars")) { + if (Args[0].Name == "*" && Args[1].Name == "**") { // could be splatting empty list and dict resulting in 0-param call which needs context return true; } @@ -71,10 +70,17 @@ public bool NeedsLocalsDictionary() { internal override string CheckDelete() => "can't delete function call"; + public override IEnumerable GetChildNodes() { + if (Target != null) yield return Target; + foreach (var arg in Args) { + yield return arg; + } + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Target?.Walk(walker); - foreach (var arg in _args.MaybeEnumerate()) { + foreach (var arg in Args) { arg.Walk(walker); } } @@ -87,7 +93,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (Target != null) { await Target.WalkAsync(walker, cancellationToken); } - foreach (var arg in _args.MaybeEnumerate()) { + foreach (var arg in Args) { await arg.WalkAsync(walker, cancellationToken); } } @@ -106,14 +112,14 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo res.Append('('); - if (_args.Length == 0) { + if (Args.Count == 0) { if (format.SpaceWithinEmptyCallArgumentList != null && format.SpaceWithinEmptyCallArgumentList.Value) { res.Append(' '); } } else { var listWhiteSpace = format.SpaceBeforeComma == null ? this.GetListWhiteSpace(ast) : null; var spaceAfterComma = format.SpaceAfterComma.HasValue ? (format.SpaceAfterComma.Value ? " " : string.Empty) : (string)null; - for (var i = 0; i < _args.Length; i++) { + for (var i = 0; i < Args.Count; i++) { if (i > 0) { if (format.SpaceBeforeComma == true) { res.Append(' '); @@ -122,14 +128,14 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo } res.Append(','); } else if (format.SpaceWithinCallParens != null) { - _args[i].AppendCodeString(res, ast, format, format.SpaceWithinCallParens.Value ? " " : string.Empty); + Args[i].AppendCodeString(res, ast, format, format.SpaceWithinCallParens.Value ? " " : string.Empty); continue; } - _args[i].AppendCodeString(res, ast, format, spaceAfterComma); + Args[i].AppendCodeString(res, ast, format, spaceAfterComma); } - if (listWhiteSpace != null && listWhiteSpace.Length == _args.Length) { + if (listWhiteSpace != null && listWhiteSpace.Length == Args.Count) { // trailing comma res.Append(listWhiteSpace[listWhiteSpace.Length - 1]); res.Append(","); diff --git a/src/Parsing/Impl/Ast/ClassDefinition.cs b/src/Parsing/Impl/Ast/ClassDefinition.cs index fed23c9ea..588e2b5e1 100644 --- a/src/Parsing/Impl/Ast/ClassDefinition.cs +++ b/src/Parsing/Impl/Ast/ClassDefinition.cs @@ -15,22 +15,23 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; using Microsoft.Python.Core.Text; namespace Microsoft.Python.Parsing.Ast { public class ClassDefinition : ScopeStatement { private readonly NameExpression/*!*/ _name; private readonly Statement _body; - private readonly Arg[] _bases; private DecoratorStatement _decorators; - public ClassDefinition(NameExpression/*!*/ name, Arg[] bases, Statement body) { + public ClassDefinition(NameExpression/*!*/ name, ImmutableArray bases, Statement body) { _name = name; - _bases = bases; + Bases = bases; _body = body; } @@ -42,7 +43,7 @@ public ClassDefinition(NameExpression/*!*/ name, Arg[] bases, Statement body) { public NameExpression/*!*/ NameExpression => _name; - public Arg[] Bases => _bases ?? Array.Empty(); + public ImmutableArray Bases { get; } public override Statement Body => _body; @@ -128,11 +129,20 @@ internal override PythonVariable BindReference(PythonNameBinder binder, string n return null; } + public override IEnumerable GetChildNodes() { + if (_name != null) yield return _name; + if (_decorators != null) yield return _decorators; + foreach (var b in Bases) { + yield return b; + } + if (_body != null) yield return _body; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { _name?.Walk(walker); _decorators?.Walk(walker); - foreach (var b in _bases.MaybeEnumerate()) { + foreach (var b in Bases) { b.Walk(walker); } _body?.Walk(walker); @@ -149,7 +159,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (_decorators != null) { await _decorators.WalkAsync(walker, cancellationToken); } - foreach (var b in _bases.MaybeEnumerate()) { + foreach (var b in Bases) { await b.WalkAsync(walker, cancellationToken); } if (_body != null) { @@ -183,7 +193,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co res.Append('('); } - if (Bases.Length != 0) { + if (Bases.Count != 0) { ListExpression.AppendItems( res, ast, @@ -191,7 +201,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co "", "", this, - Bases.Length, + Bases.Count, (i, sb) => { if (format.SpaceWithinClassDeclarationParens != null && i == 0) { // need to remove any leading whitespace which was preserved for @@ -209,7 +219,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co } if (!this.IsAltForm(ast) && !this.IsMissingCloseGrouping(ast)) { - if (Bases.Length != 0 || + if (Bases.Count != 0 || format.SpaceWithinEmptyBaseClassList == null || !string.IsNullOrWhiteSpace(this.GetFourthWhiteSpace(ast))) { format.Append( diff --git a/src/Parsing/Impl/Ast/Comprehension.cs b/src/Parsing/Impl/Ast/Comprehension.cs index 7fa73c3ad..29f22b4cf 100644 --- a/src/Parsing/Impl/Ast/Comprehension.cs +++ b/src/Parsing/Impl/Ast/Comprehension.cs @@ -18,17 +18,16 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public abstract class ComprehensionIterator : Node { } public abstract class Comprehension : Expression { - public abstract IList Iterators { get; } + public abstract ImmutableArray Iterators { get; } public abstract override string NodeName { get; } - public abstract override void Walk(PythonWalker walker); - internal void AppendCodeString(StringBuilder res, PythonAst ast, CodeFormattingOptions format, string start, string end, Expression item) { if (!string.IsNullOrEmpty(start)) { format.ReflowComment(res, this.GetPreceedingWhiteSpace(ast)); @@ -49,23 +48,28 @@ internal void AppendCodeString(StringBuilder res, PythonAst ast, CodeFormattingO } public sealed class ListComprehension : Comprehension { - private readonly ComprehensionIterator[] _iterators; - - public ListComprehension(Expression item, ComprehensionIterator[] iterators) { + public ListComprehension(Expression item, ImmutableArray iterators) { Item = item; - _iterators = iterators; + Iterators = iterators; } public Expression Item { get; } - public override IList Iterators => _iterators; + public override ImmutableArray Iterators { get; } public override string NodeName => "list comprehension"; + public override IEnumerable GetChildNodes() { + if (Item != null) yield return Item; + foreach (var iterator in Iterators) { + yield return iterator; + } + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Item?.Walk(walker); - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators) { ci.Walk(walker); } } @@ -77,7 +81,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (Item != null) { await Item.WalkAsync(walker, cancellationToken); } - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators) { await ci.WalkAsync(walker, cancellationToken); } } @@ -88,23 +92,28 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken } public sealed class SetComprehension : Comprehension { - private readonly ComprehensionIterator[] _iterators; - - public SetComprehension(Expression item, ComprehensionIterator[] iterators) { + public SetComprehension(Expression item, ImmutableArray iterators) { Item = item; - _iterators = iterators; + Iterators = iterators; } public Expression Item { get; } - public override IList Iterators => _iterators; + public override ImmutableArray Iterators { get; } public override string NodeName => "set comprehension"; + public override IEnumerable GetChildNodes() { + if (Item != null) yield return Item; + foreach (var iterator in Iterators) { + yield return iterator; + } + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Item?.Walk(walker); - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators.MaybeEnumerate()) { ci.Walk(walker); } } @@ -116,7 +125,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (Item != null) { await Item.WalkAsync(walker, cancellationToken); } - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators.MaybeEnumerate()) { await ci.WalkAsync(walker, cancellationToken); } } @@ -127,26 +136,32 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken } public sealed class DictionaryComprehension : Comprehension { - private readonly ComprehensionIterator[] _iterators; private readonly SliceExpression _value; - public DictionaryComprehension(SliceExpression value, ComprehensionIterator[] iterators) { + public DictionaryComprehension(SliceExpression value, ImmutableArray iterators) { _value = value; - _iterators = iterators; + Iterators = iterators; } public Expression Key => _value.SliceStart; public Expression Value => _value.SliceStop; - public override IList Iterators => _iterators; + public override ImmutableArray Iterators { get; } public override string NodeName => "dict comprehension"; + public override IEnumerable GetChildNodes() { + if (_value != null) yield return _value; + foreach (var iterator in Iterators) { + yield return iterator; + } + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { _value?.Walk(walker); - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators.MaybeEnumerate()) { ci.Walk(walker); } } @@ -158,7 +173,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (_value != null) { await _value.WalkAsync(walker, cancellationToken); } - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators.MaybeEnumerate()) { await ci.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/ComprehensionFor.cs b/src/Parsing/Impl/Ast/ComprehensionFor.cs index 8efb417ef..e21685b75 100644 --- a/src/Parsing/Impl/Ast/ComprehensionFor.cs +++ b/src/Parsing/Impl/Ast/ComprehensionFor.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -35,6 +36,11 @@ public ComprehensionFor(Expression lhs, Expression list, bool isAsync) public bool IsAsync { get; } + public override IEnumerable GetChildNodes() { + if (Left != null) yield return Left; + if (List != null) yield return List; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Left?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ComprehensionIf.cs b/src/Parsing/Impl/Ast/ComprehensionIf.cs index 9ae87ebb7..8b19d0041 100644 --- a/src/Parsing/Impl/Ast/ComprehensionIf.cs +++ b/src/Parsing/Impl/Ast/ComprehensionIf.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -25,6 +26,10 @@ public ComprehensionIf(Expression test) { public Expression Test { get; } + public override IEnumerable GetChildNodes() { + if (Test != null) yield return Test; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Test?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ConditionalExpression.cs b/src/Parsing/Impl/Ast/ConditionalExpression.cs index 017a5245f..8cb8c9201 100644 --- a/src/Parsing/Impl/Ast/ConditionalExpression.cs +++ b/src/Parsing/Impl/Ast/ConditionalExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -35,6 +36,12 @@ public ConditionalExpression(Expression testExpression, Expression trueExpressio public int IfIndex { get; } public int ElseIndex { get; } + public override IEnumerable GetChildNodes() { + if (Test != null) yield return Test; + if (TrueExpression != null) yield return TrueExpression; + if (FalseExpression != null) yield return FalseExpression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Test?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ConstantExpression.cs b/src/Parsing/Impl/Ast/ConstantExpression.cs index 59960280f..4e1892dea 100644 --- a/src/Parsing/Impl/Ast/ConstantExpression.cs +++ b/src/Parsing/Impl/Ast/ConstantExpression.cs @@ -14,7 +14,9 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Globalization; +using System.Linq; using System.Numerics; using System.Text; using System.Threading; @@ -31,6 +33,8 @@ public ConstantExpression(object value) { internal override string CheckAssign() => Value == null ? "assignment to None" : "can't assign to literal"; + public override IEnumerable GetChildNodes() => Enumerable.Empty(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { } diff --git a/src/Parsing/Impl/Ast/ContinueStatement.cs b/src/Parsing/Impl/Ast/ContinueStatement.cs index 8ee939161..070a29280 100644 --- a/src/Parsing/Impl/Ast/ContinueStatement.cs +++ b/src/Parsing/Impl/Ast/ContinueStatement.cs @@ -13,15 +13,15 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; namespace Microsoft.Python.Parsing.Ast { - - public class ContinueStatement : Statement { - public ContinueStatement() { - } + public class ContinueStatement : Statement { + public override IEnumerable GetChildNodes() => Enumerable.Empty(); public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { diff --git a/src/Parsing/Impl/Ast/DecoratorStatement.cs b/src/Parsing/Impl/Ast/DecoratorStatement.cs index 23ade24de..aebb3b9e1 100644 --- a/src/Parsing/Impl/Ast/DecoratorStatement.cs +++ b/src/Parsing/Impl/Ast/DecoratorStatement.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -27,9 +28,11 @@ public DecoratorStatement(Expression[] decorators) { public Expression[] Decorators { get; } + public override IEnumerable GetChildNodes() => Decorators.ExcludeDefault(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var decorator in Decorators.MaybeEnumerate()) { + foreach (var decorator in Decorators) { decorator?.Walk(walker); } } @@ -38,7 +41,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var decorator in Decorators.MaybeEnumerate().ExcludeDefault()) { + foreach (var decorator in Decorators.ExcludeDefault()) { await decorator.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/DelStatement.cs b/src/Parsing/Impl/Ast/DelStatement.cs index 5da27a348..dbb3e2789 100644 --- a/src/Parsing/Impl/Ast/DelStatement.cs +++ b/src/Parsing/Impl/Ast/DelStatement.cs @@ -18,21 +18,21 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { - public class DelStatement : Statement { - private readonly Expression[] _expressions; - - public DelStatement(Expression[] expressions) { - _expressions = expressions; + public DelStatement(ImmutableArray expressions) { + Expressions = expressions; } - public IList Expressions => _expressions; + public ImmutableArray Expressions { get; } + + public override IEnumerable GetChildNodes() => Expressions; public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var expression in _expressions.MaybeEnumerate()) { + foreach (var expression in Expressions.MaybeEnumerate()) { expression.Walk(walker); } } @@ -41,7 +41,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var expression in _expressions.MaybeEnumerate()) { + foreach (var expression in Expressions.MaybeEnumerate()) { await expression.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/DictionaryExpression.cs b/src/Parsing/Impl/Ast/DictionaryExpression.cs index 6bf6e3470..0feb297be 100644 --- a/src/Parsing/Impl/Ast/DictionaryExpression.cs +++ b/src/Parsing/Impl/Ast/DictionaryExpression.cs @@ -18,22 +18,23 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class DictionaryExpression : Expression { - private readonly SliceExpression[] _items; - - public DictionaryExpression(params SliceExpression[] items) { - _items = items; + public DictionaryExpression(ImmutableArray items) { + Items = items; } - public IList Items => _items; + public ImmutableArray Items { get; } public override string NodeName => "dictionary display"; + public override IEnumerable GetChildNodes() => Items; + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var s in _items.MaybeEnumerate()) { + foreach (var s in Items) { s.Walk(walker); } } @@ -42,7 +43,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var s in _items.MaybeEnumerate()) { + foreach (var s in Items) { await s.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/DottedName.cs b/src/Parsing/Impl/Ast/DottedName.cs index a63590f01..6ddc5e6ef 100644 --- a/src/Parsing/Impl/Ast/DottedName.cs +++ b/src/Parsing/Impl/Ast/DottedName.cs @@ -14,33 +14,36 @@ // permissions and limitations under the License. using System.Collections.Generic; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class DottedName : Node { - private readonly NameExpression[] _names; - - public DottedName(NameExpression[]/*!*/ names) { - _names = names; + public DottedName(ImmutableArray names) { + Names = names; } - public IList Names => _names; + public ImmutableArray Names { get; } public virtual string MakeString() { - if (_names.Length == 0) { + if (Names.Count == 0) { return string.Empty; } - var ret = new StringBuilder(_names[0].Name); - for (var i = 1; i < _names.Length; i++) { + var ret = new StringBuilder(Names[0].Name); + for (var i = 1; i < Names.Count; i++) { ret.Append('.'); - ret.Append(_names[i].Name); + ret.Append(Names[i].Name); } return ret.ToString(); } + public override IEnumerable GetChildNodes() => Names; + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { } @@ -56,7 +59,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFormattingOptions format) { var whitespace = this.GetNamesWhiteSpace(ast); - for (int i = 0, whitespaceIndex = 0; i < _names.Length; i++) { + for (int i = 0, whitespaceIndex = 0; i < Names.Count; i++) { if (whitespace != null) { res.Append(whitespace[whitespaceIndex++]); } @@ -66,7 +69,7 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo res.Append(whitespace[whitespaceIndex++]); } } - _names[i].AppendCodeString(res, ast, format); + Names[i].AppendCodeString(res, ast, format); } } diff --git a/src/Parsing/Impl/Ast/EmptyStatement.cs b/src/Parsing/Impl/Ast/EmptyStatement.cs index ac72c941f..4d276ec86 100644 --- a/src/Parsing/Impl/Ast/EmptyStatement.cs +++ b/src/Parsing/Impl/Ast/EmptyStatement.cs @@ -13,6 +13,8 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -24,6 +26,8 @@ public EmptyStatement() { public override int KeywordLength => 4; + public override IEnumerable GetChildNodes() => Enumerable.Empty(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { } diff --git a/src/Parsing/Impl/Ast/ErrorExpression.cs b/src/Parsing/Impl/Ast/ErrorExpression.cs index e9297a651..f15c2ee5d 100644 --- a/src/Parsing/Impl/Ast/ErrorExpression.cs +++ b/src/Parsing/Impl/Ast/ErrorExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -41,6 +42,11 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo _nested?.AppendCodeString(res, ast, format); } + public override IEnumerable GetChildNodes() { + if (_preceding != null) yield return _preceding; + if (_nested != null) yield return _nested; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { _preceding?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ErrorStatement.cs b/src/Parsing/Impl/Ast/ErrorStatement.cs index a7eb61655..7c9dbf4a1 100644 --- a/src/Parsing/Impl/Ast/ErrorStatement.cs +++ b/src/Parsing/Impl/Ast/ErrorStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -33,6 +34,8 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co res.Append(this.GetVerbatimImage(ast) ?? ""); } + public override IEnumerable GetChildNodes() => _preceeding.WhereNotNull(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { foreach (var preceeding in _preceeding.MaybeEnumerate()) { diff --git a/src/Parsing/Impl/Ast/ExecStatement.cs b/src/Parsing/Impl/Ast/ExecStatement.cs index 90adbd54a..911147218 100644 --- a/src/Parsing/Impl/Ast/ExecStatement.cs +++ b/src/Parsing/Impl/Ast/ExecStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading; @@ -38,6 +39,13 @@ public ExecStatement(Expression code, Expression locals, Expression globals, Tup public bool NeedsLocalsDictionary() => Globals == null && Locals == null; + public override IEnumerable GetChildNodes() { + if (_code != null) yield return _code; + if (_codeTuple != null) yield return _codeTuple; + if (_locals != null) yield return _locals; + if (_globals != null) yield return _globals; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { _code?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ExpressionStatement.cs b/src/Parsing/Impl/Ast/ExpressionStatement.cs index 02790875c..04c42ed2b 100644 --- a/src/Parsing/Impl/Ast/ExpressionStatement.cs +++ b/src/Parsing/Impl/Ast/ExpressionStatement.cs @@ -13,12 +13,12 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; namespace Microsoft.Python.Parsing.Ast { - public class ExpressionStatement : Statement { public ExpressionStatement(Expression expression) { Expression = expression; @@ -26,6 +26,10 @@ public ExpressionStatement(Expression expression) { public Expression Expression { get; } + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ExpressionWithAnnotation.cs b/src/Parsing/Impl/Ast/ExpressionWithAnnotation.cs index 4574f101f..39211c41d 100644 --- a/src/Parsing/Impl/Ast/ExpressionWithAnnotation.cs +++ b/src/Parsing/Impl/Ast/ExpressionWithAnnotation.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -26,7 +27,7 @@ public ExpressionWithAnnotation(Expression expression, Expression annotation) { public override string ToString() { if (Annotation != null) { - return Expression.ToString() + ":" + Annotation.ToString(); + return Expression + ":" + Annotation; } return Expression.ToString(); } @@ -40,6 +41,8 @@ public override string ToString() { internal override string CheckAugmentedAssign() => "cannot assign to " + NodeName; internal override string CheckDelete() => "cannot delete " + NodeName; + public override IEnumerable GetChildNodes() => new[] {Expression, Annotation}; + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ForStatement.cs b/src/Parsing/Impl/Ast/ForStatement.cs index 874249ce6..abfc15568 100644 --- a/src/Parsing/Impl/Ast/ForStatement.cs +++ b/src/Parsing/Impl/Ast/ForStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -47,6 +48,13 @@ public ForStatement(Expression left, Expression list, Statement body, Statement public Statement Else { get; } public bool IsAsync { get; } + public override IEnumerable GetChildNodes() { + if (Left != null) yield return Left; + if (List != null) yield return List; + if (Body != null) yield return Body; + if (Else != null) yield return Else; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Left?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/FromImportStatement.cs b/src/Parsing/Impl/Ast/FromImportStatement.cs index d8c3859d0..ef89925bf 100644 --- a/src/Parsing/Impl/Ast/FromImportStatement.cs +++ b/src/Parsing/Impl/Ast/FromImportStatement.cs @@ -16,14 +16,15 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { - public class FromImportStatement : Statement { - public FromImportStatement(ModuleName/*!*/ root, NameExpression/*!*/[] names, NameExpression/*!*/[] asNames, bool fromFuture, bool forceAbsolute, int importIndex) { + public FromImportStatement(ModuleName/*!*/ root, ImmutableArray names, ImmutableArray asNames, bool fromFuture, bool forceAbsolute, int importIndex) { Root = root; Names = names; AsNames = asNames; @@ -32,11 +33,11 @@ public FromImportStatement(ModuleName/*!*/ root, NameExpression/*!*/[] names, Na ImportIndex = importIndex; } - public ModuleName Root { get; } + public ModuleName/*!*/ Root { get; } + public ImmutableArray Names { get; } + public ImmutableArray AsNames { get; } public bool IsFromFuture { get; } public bool ForceAbsolute { get; } - public IList Names { get; } - public IList AsNames { get; } public int ImportIndex { get; } public override int KeywordLength => 4; @@ -50,6 +51,9 @@ public FromImportStatement(ModuleName/*!*/ root, NameExpression/*!*/[] names, Na public PythonReference[] GetReferences(PythonAst ast) => GetVariableReferences(this, ast); + // TODO: return names and aliases when they are united into one node + public override IEnumerable GetChildNodes() => Enumerable.Empty(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { } @@ -63,106 +67,6 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken await walker.PostWalkAsync(this, cancellationToken); } - /// - /// Returns a new FromImport statement that is identical to this one but has - /// removed the specified import statement. Otherwise preserves any attributes - /// for the statement. - /// - /// New in 1.1. - /// The parent AST whose attributes should be updated for the new node. - /// The index in Names of the import to be removed. - /// - public FromImportStatement RemoveImport(PythonAst ast, int index) { - if (index < 0 || index >= Names.Count) { - throw new ArgumentOutOfRangeException("index"); - } - if (ast == null) { - throw new ArgumentNullException("ast"); - } - - var names = new NameExpression[Names.Count - 1]; - var asNames = AsNames == null ? null : new NameExpression[AsNames.Count - 1]; - var asNameWhiteSpace = this.GetNamesWhiteSpace(ast); - var newAsNameWhiteSpace = new List(); - var importIndex = ImportIndex; - var asIndex = 0; - for (int i = 0, write = 0; i < Names.Count; i++) { - var includingCurrentName = i != index; - - // track the white space, this needs to be kept in sync w/ ToCodeString and how the - // parser creates the white space. - - if (asNameWhiteSpace != null && asIndex < asNameWhiteSpace.Length) { - if (write > 0) { - if (includingCurrentName) { - newAsNameWhiteSpace.Add(asNameWhiteSpace[asIndex++]); - } else { - asIndex++; - } - } else if (i > 0) { - asIndex++; - } - } - - if (asNameWhiteSpace != null && asIndex < asNameWhiteSpace.Length) { - if (includingCurrentName) { - if (newAsNameWhiteSpace.Count == 0) { - // no matter what we want the 1st entry to have the whitespace after the import keyword - newAsNameWhiteSpace.Add(asNameWhiteSpace[0]); - asIndex++; - } else { - newAsNameWhiteSpace.Add(asNameWhiteSpace[asIndex++]); - } - } else { - asIndex++; - } - } - - if (includingCurrentName) { - names[write] = Names[i]; - - if (AsNames != null) { - asNames[write] = AsNames[i]; - } - - write++; - } - - if (AsNames != null && AsNames[i] != null) { - if (asNameWhiteSpace != null && asIndex < asNameWhiteSpace.Length) { - if (i != index) { - newAsNameWhiteSpace.Add(asNameWhiteSpace[asIndex++]); - } else { - asIndex++; - } - } - - if (AsNames[i].Name.Length != 0) { - if (asNameWhiteSpace != null && asIndex < asNameWhiteSpace.Length) { - if (i != index) { - newAsNameWhiteSpace.Add(asNameWhiteSpace[asIndex++]); - } else { - asIndex++; - } - } - } else { - asIndex++; - } - } - } - - if (asNameWhiteSpace != null && asIndex < asNameWhiteSpace.Length) { - // trailing comma - newAsNameWhiteSpace.Add(asNameWhiteSpace[asNameWhiteSpace.Length - 1]); - } - - var res = new FromImportStatement(Root, names, asNames, IsFromFuture, ForceAbsolute, importIndex); - ast.CopyAttributes(this, res); - ast.SetAttribute(res, NodeAttributes.NamesWhiteSpace, newAsNameWhiteSpace.ToArray()); - - return res; - } - internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, CodeFormattingOptions format) { format.ReflowComment(res, this.GetPreceedingWhiteSpace(ast)); res.Append("from"); @@ -193,7 +97,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co } Names[i].AppendCodeString(res, ast, format); - if (AsNames != null && AsNames[i] != null) { + if (AsNames[i] != null) { if (asNameWhiteSpace != null && asIndex < asNameWhiteSpace.Length) { res.Append(asNameWhiteSpace[asIndex++]); } diff --git a/src/Parsing/Impl/Ast/FunctionDefinition.cs b/src/Parsing/Impl/Ast/FunctionDefinition.cs index 86ca85838..1b70c2c44 100644 --- a/src/Parsing/Impl/Ast/FunctionDefinition.cs +++ b/src/Parsing/Impl/Ast/FunctionDefinition.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Diagnostics; using System.Text; using System.Threading; @@ -26,7 +27,6 @@ namespace Microsoft.Python.Parsing.Ast { public class FunctionDefinition : ScopeStatement, IMaybeAsyncStatement { internal static readonly object WhitespaceAfterAsync = new object(); - private readonly Parameter[] _parameters; private int? _keywordEndIndex; protected Statement _body; @@ -43,14 +43,14 @@ public FunctionDefinition(NameExpression name, Parameter[] parameters, Statement NameExpression = name; } - _parameters = parameters; + Parameters = parameters ?? Array.Empty(); _body = body; Decorators = decorators; } public bool IsLambda { get; } - public Parameter[] Parameters => _parameters ?? Array.Empty(); + public Parameter[] Parameters { get; } public override int ArgCount => Parameters.Length; @@ -185,10 +185,20 @@ public int GetIndexOfDef(PythonAst ast) { return DefIndex + NodeAttributes.GetWhiteSpace(this, ast, WhitespaceAfterAsync).Length + 5; } + public override IEnumerable GetChildNodes() { + if (NameExpression != null) yield return NameExpression; + foreach (var parameter in Parameters) { + yield return parameter; + } + if (Decorators != null) yield return Decorators; + if (_body != null) yield return _body; + if (ReturnAnnotation != null) yield return ReturnAnnotation; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { NameExpression?.Walk(walker); - foreach (var p in _parameters.MaybeEnumerate()) { + foreach (var p in Parameters) { p.Walk(walker); } Decorators?.Walk(walker); @@ -204,7 +214,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken await NameExpression.WalkAsync(walker, cancellationToken); } - foreach (var p in _parameters.MaybeEnumerate()) { + foreach (var p in Parameters) { await p.WalkAsync(walker, cancellationToken); } diff --git a/src/Parsing/Impl/Ast/GeneratorExpression.cs b/src/Parsing/Impl/Ast/GeneratorExpression.cs index 464a47e39..f976d4e43 100644 --- a/src/Parsing/Impl/Ast/GeneratorExpression.cs +++ b/src/Parsing/Impl/Ast/GeneratorExpression.cs @@ -18,17 +18,16 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public sealed class GeneratorExpression : Comprehension { - private readonly ComprehensionIterator[] _iterators; - - public GeneratorExpression(Expression item, ComprehensionIterator[] iterators) { + public GeneratorExpression(Expression item, ImmutableArray iterators) { Item = item; - _iterators = iterators; + Iterators = iterators; } - public override IList Iterators => _iterators; + public override ImmutableArray Iterators { get; } public override string NodeName => "generator"; @@ -40,10 +39,17 @@ public GeneratorExpression(Expression item, ComprehensionIterator[] iterators) { internal override string CheckDelete() => "can't delete generator expression"; + public override IEnumerable GetChildNodes() { + if (Item != null) yield return Item; + foreach (var iterator in Iterators) { + yield return iterator; + } + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Item?.Walk(walker); - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators.MaybeEnumerate()) { ci.Walk(walker); } } @@ -55,7 +61,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (Item != null) { await Item.WalkAsync(walker, cancellationToken); } - foreach (var ci in _iterators.MaybeEnumerate()) { + foreach (var ci in Iterators.MaybeEnumerate()) { await ci.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/GlobalStatement.cs b/src/Parsing/Impl/Ast/GlobalStatement.cs index a6abf8098..5d49c6e5a 100644 --- a/src/Parsing/Impl/Ast/GlobalStatement.cs +++ b/src/Parsing/Impl/Ast/GlobalStatement.cs @@ -18,21 +18,22 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class GlobalStatement : Statement { - private readonly NameExpression[] _names; - - public GlobalStatement(NameExpression[] names) { - _names = names; + public GlobalStatement(ImmutableArray names) { + Names = names; } - public IList Names => _names; + public ImmutableArray Names { get; } + + public override IEnumerable GetChildNodes() => Names; public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var n in _names.MaybeEnumerate()) { - n?.Walk(walker); + foreach (var n in Names) { + n.Walk(walker); } } walker.PostWalk(this); @@ -40,7 +41,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var n in _names.MaybeEnumerate().ExcludeDefault()) { + foreach (var n in Names) { await n.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/IfStatement.cs b/src/Parsing/Impl/Ast/IfStatement.cs index 04764de37..ef3cb1fa2 100644 --- a/src/Parsing/Impl/Ast/IfStatement.cs +++ b/src/Parsing/Impl/Ast/IfStatement.cs @@ -18,24 +18,32 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class IfStatement : Statement { - private readonly IfStatementTest[] _tests; - - public IfStatement(IfStatementTest[] tests, Statement else_) { - _tests = tests; + public IfStatement(ImmutableArray tests, Statement else_) { + Tests = tests; ElseStatement = else_; } - public IList Tests => _tests; + public ImmutableArray Tests { get; } + public Statement ElseStatement { get; } public int ElseIndex { get; set; } + public override IEnumerable GetChildNodes() { + foreach (var test in Tests) { + yield return test; + } + + if (ElseStatement != null) yield return ElseStatement; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var test in _tests.MaybeEnumerate()) { + foreach (var test in Tests) { test.Walk(walker); } @@ -46,7 +54,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var test in _tests.MaybeEnumerate()) { + foreach (var test in Tests) { await test.WalkAsync(walker, cancellationToken); } if (ElseStatement != null) { @@ -58,7 +66,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, CodeFormattingOptions format) { var itemWhiteSpace = this.GetListWhiteSpace(ast); - for (var i = 0; i < _tests.Length; i++) { + for (var i = 0; i < Tests.Count; i++) { if (itemWhiteSpace != null) { format.ReflowComment(res, itemWhiteSpace[i]); } @@ -68,7 +76,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co } else { res.Append("elif"); } - _tests[i].AppendCodeString(res, ast, format); + Tests[i].AppendCodeString(res, ast, format); } if (ElseStatement != null) { diff --git a/src/Parsing/Impl/Ast/IfStatementTest.cs b/src/Parsing/Impl/Ast/IfStatementTest.cs index 4dfb3cabe..3aa785614 100644 --- a/src/Parsing/Impl/Ast/IfStatementTest.cs +++ b/src/Parsing/Impl/Ast/IfStatementTest.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -31,6 +32,11 @@ public IfStatementTest(Expression test, Statement body) { public Statement Body { get; set; } + public override IEnumerable GetChildNodes() { + if (Test != null) yield return Test; + if (Body != null) yield return Body; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Test?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ImportStatement.cs b/src/Parsing/Impl/Ast/ImportStatement.cs index 1962289bf..1160e6182 100644 --- a/src/Parsing/Impl/Ast/ImportStatement.cs +++ b/src/Parsing/Impl/Ast/ImportStatement.cs @@ -16,18 +16,17 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class ImportStatement : Statement { - private readonly ModuleName[] _names; - private readonly NameExpression[] _asNames; - - public ImportStatement(ModuleName[] names, NameExpression[] asNames, bool forceAbsolute) { - _names = names; - _asNames = asNames; + public ImportStatement(ImmutableArray names, ImmutableArray asNames, bool forceAbsolute) { + Names = names; + AsNames = asNames; ForceAbsolute = forceAbsolute; } @@ -40,8 +39,11 @@ public ImportStatement(ModuleName[] names, NameExpression[] asNames, bool forceA public PythonReference[] GetReferences(PythonAst ast) => GetVariableReferences(this, ast); - public IList Names => _names; - public IList AsNames => _asNames; + public ImmutableArray Names { get; } + public ImmutableArray AsNames { get; } + + // TODO: return names and aliases when they are united into one node + public override IEnumerable GetChildNodes() => Enumerable.Empty(); public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { @@ -56,78 +58,13 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken await walker.PostWalkAsync(this, cancellationToken); } - /// - /// Removes the import at the specified index (which must be in the range of - /// the Names property) and returns a new ImportStatement which is the same - /// as this one minus the imported name. Preserves all round-tripping metadata - /// in the process. - /// - /// New in 1.1. - /// - public ImportStatement RemoveImport(PythonAst ast, int index) { - if (index < 0 || index >= _names.Length) { - throw new ArgumentOutOfRangeException("index"); - } - if (ast == null) { - throw new ArgumentNullException("ast"); - } - - var names = new ModuleName[_names.Length - 1]; - var asNames = _asNames == null ? null : new NameExpression[_asNames.Length - 1]; - var asNameWhiteSpace = this.GetNamesWhiteSpace(ast); - var itemWhiteSpace = this.GetListWhiteSpace(ast); - var newAsNameWhiteSpace = new List(); - var newListWhiteSpace = new List(); - var asIndex = 0; - for (int i = 0, write = 0; i < _names.Length; i++) { - var includingCurrentName = i != index; - - // track the white space, this needs to be kept in sync w/ ToCodeString and how the - // parser creates the white space. - if (i > 0 && itemWhiteSpace != null) { - if (includingCurrentName) { - newListWhiteSpace.Add(itemWhiteSpace[i - 1]); - } - } - - if (includingCurrentName) { - names[write] = _names[i]; - - if (_asNames != null) { - asNames[write] = _asNames[i]; - } - - write++; - } - - if (AsNames[i] != null && includingCurrentName) { - if (asNameWhiteSpace != null) { - newAsNameWhiteSpace.Add(asNameWhiteSpace[asIndex++]); - } - - if (_asNames[i].Name.Length != 0) { - if (asNameWhiteSpace != null) { - newAsNameWhiteSpace.Add(asNameWhiteSpace[asIndex++]); - } - } - } - } - - var res = new ImportStatement(names, asNames, ForceAbsolute); - ast.CopyAttributes(this, res); - ast.SetAttribute(res, NodeAttributes.NamesWhiteSpace, newAsNameWhiteSpace.ToArray()); - ast.SetAttribute(res, NodeAttributes.ListWhiteSpace, newListWhiteSpace.ToArray()); - - return res; - } - internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, CodeFormattingOptions format) { var asNameWhiteSpace = this.GetNamesWhiteSpace(ast); if (format.ReplaceMultipleImportsWithMultipleStatements) { var proceeding = this.GetPreceedingWhiteSpace(ast); var additionalProceeding = format.GetNextLineProceedingText(proceeding); - for (int i = 0, asIndex = 0; i < _names.Length; i++) { + for (int i = 0, asIndex = 0; i < Names.Count; i++) { if (i == 0) { format.ReflowComment(res, proceeding) ; } else { @@ -135,7 +72,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co } res.Append("import"); - _names[i].AppendCodeString(res, ast, format); + Names[i].AppendCodeString(res, ast, format); AppendAs(res, ast, format, asNameWhiteSpace, i, ref asIndex); } return; @@ -144,13 +81,13 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co res.Append("import"); var itemWhiteSpace = this.GetListWhiteSpace(ast); - for (int i = 0, asIndex = 0; i < _names.Length; i++) { + for (int i = 0, asIndex = 0; i < Names.Count; i++) { if (i > 0 && itemWhiteSpace != null) { res.Append(itemWhiteSpace[i - 1]); res.Append(','); } - _names[i].AppendCodeString(res, ast, format); + Names[i].AppendCodeString(res, ast, format); AppendAs(res, ast, format, asNameWhiteSpace, i, ref asIndex); } } @@ -163,12 +100,12 @@ private void AppendAs(StringBuilder res, PythonAst ast, CodeFormattingOptions fo } res.Append("as"); - if (_asNames[i].Name.Length != 0) { + if (AsNames[i].Name.Length != 0) { if (asNameWhiteSpace != null) { res.Append(asNameWhiteSpace[asIndex++]); } - _asNames[i].AppendCodeString(res, ast, format); + AsNames[i].AppendCodeString(res, ast, format); } } } diff --git a/src/Parsing/Impl/Ast/IndexExpression.cs b/src/Parsing/Impl/Ast/IndexExpression.cs index d3cce0cf5..9042fc63c 100644 --- a/src/Parsing/Impl/Ast/IndexExpression.cs +++ b/src/Parsing/Impl/Ast/IndexExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -32,6 +33,11 @@ public IndexExpression(Expression target, Expression index) { internal override string CheckDelete() => null; + public override IEnumerable GetChildNodes() { + if (Target != null) yield return Target; + if (Index != null) yield return Index; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Target?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/LambdaExpression.cs b/src/Parsing/Impl/Ast/LambdaExpression.cs index 3cdfc176b..1c9c47d1c 100644 --- a/src/Parsing/Impl/Ast/LambdaExpression.cs +++ b/src/Parsing/Impl/Ast/LambdaExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Diagnostics; using System.Text; using System.Threading; @@ -26,6 +27,10 @@ public LambdaExpression(FunctionDefinition function) { public FunctionDefinition Function { get; } + public override IEnumerable GetChildNodes() { + if (Function != null) yield return Function; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Function?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ListExpression.cs b/src/Parsing/Impl/Ast/ListExpression.cs index a6523822c..3abd5bbc4 100644 --- a/src/Parsing/Impl/Ast/ListExpression.cs +++ b/src/Parsing/Impl/Ast/ListExpression.cs @@ -19,15 +19,18 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class ListExpression : SequenceExpression { - public ListExpression(params Expression[] items) + public ListExpression(ImmutableArray items) : base(items) { } public override string NodeName => "list display"; + public override IEnumerable GetChildNodes() => Items; + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { foreach (var e in Items.MaybeEnumerate()) { @@ -61,7 +64,7 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo } } - internal static void AppendItems(StringBuilder res, PythonAst ast, CodeFormattingOptions format, string start, string end, Node node, IList items, bool? delimiterWhiteSpace = null) where T : Expression { + internal static void AppendItems(StringBuilder res, PythonAst ast, CodeFormattingOptions format, string start, string end, Node node, ImmutableArray items, bool? delimiterWhiteSpace = null) where T : Expression { string initialWs = null, ws = null; if (delimiterWhiteSpace.HasValue) { initialWs = delimiterWhiteSpace.Value ? " " : string.Empty; diff --git a/src/Parsing/Impl/Ast/MemberExpression.cs b/src/Parsing/Impl/Ast/MemberExpression.cs index a4f394231..ccac71cea 100644 --- a/src/Parsing/Impl/Ast/MemberExpression.cs +++ b/src/Parsing/Impl/Ast/MemberExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -50,6 +51,10 @@ public void SetLoc(int start, int name, int end) { internal override string CheckDelete() => null; + public override IEnumerable GetChildNodes() { + if (Target != null) yield return Target; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Target?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ModuleName.cs b/src/Parsing/Impl/Ast/ModuleName.cs index e52ed6eaf..85e75a59e 100644 --- a/src/Parsing/Impl/Ast/ModuleName.cs +++ b/src/Parsing/Impl/Ast/ModuleName.cs @@ -13,9 +13,11 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using Microsoft.Python.Core.Collections; + namespace Microsoft.Python.Parsing.Ast { public class ModuleName : DottedName { - public ModuleName(NameExpression[]/*!*/ names) + public ModuleName(ImmutableArray names) : base(names) { } } diff --git a/src/Parsing/Impl/Ast/NameExpression.cs b/src/Parsing/Impl/Ast/NameExpression.cs index 827b4e1c1..b9748400c 100644 --- a/src/Parsing/Impl/Ast/NameExpression.cs +++ b/src/Parsing/Impl/Ast/NameExpression.cs @@ -13,6 +13,8 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -32,6 +34,8 @@ public NameExpression(string name) { internal override string CheckAssign() => null; internal override string CheckDelete() => null; + public override IEnumerable GetChildNodes() => Enumerable.Empty(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { } diff --git a/src/Parsing/Impl/Ast/Node.cs b/src/Parsing/Impl/Ast/Node.cs index 58aa67997..7edf6da23 100644 --- a/src/Parsing/Impl/Ast/Node.cs +++ b/src/Parsing/Impl/Ast/Node.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -35,6 +36,7 @@ public int StartIndex { set => IndexSpan = new IndexSpan(value, 0); } + public abstract IEnumerable GetChildNodes(); public abstract void Walk(PythonWalker walker); public abstract Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default); diff --git a/src/Parsing/Impl/Ast/NonlocalStatement.cs b/src/Parsing/Impl/Ast/NonlocalStatement.cs index 80f84ad3c..1fde33aa6 100644 --- a/src/Parsing/Impl/Ast/NonlocalStatement.cs +++ b/src/Parsing/Impl/Ast/NonlocalStatement.cs @@ -18,20 +18,21 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class NonlocalStatement : Statement { - private readonly NameExpression[] _names; - - public NonlocalStatement(NameExpression[] names) { - _names = names; + public NonlocalStatement(ImmutableArray names) { + Names = names; } - public IList Names => _names; + public ImmutableArray Names { get; } + + public override IEnumerable GetChildNodes() => Names; public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var n in _names.MaybeEnumerate().ExcludeDefault()) { + foreach (var n in Names) { n.Walk(walker); } } @@ -40,7 +41,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var n in _names.MaybeEnumerate().ExcludeDefault()) { + foreach (var n in Names) { await n.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/OrExpression.cs b/src/Parsing/Impl/Ast/OrExpression.cs index 210f682dc..d634cbb8c 100644 --- a/src/Parsing/Impl/Ast/OrExpression.cs +++ b/src/Parsing/Impl/Ast/OrExpression.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -36,6 +37,11 @@ public OrExpression(Expression left, Expression right, int orIndex) { public override string NodeName => "or expression"; + public override IEnumerable GetChildNodes() { + yield return Left; + if (Right != null) yield return Right; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Left.Walk(walker); diff --git a/src/Parsing/Impl/Ast/Parameter.cs b/src/Parsing/Impl/Ast/Parameter.cs index 108c7a455..624d3dc4d 100644 --- a/src/Parsing/Impl/Ast/Parameter.cs +++ b/src/Parsing/Impl/Ast/Parameter.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -53,6 +54,11 @@ public Parameter(NameExpression name, ParameterKind kind) { public ParameterKind Kind { get; } + public override IEnumerable GetChildNodes() { + if (Annotation != null) yield return Annotation; + if (DefaultValue != null) yield return DefaultValue; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Annotation?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/ParenthesisExpression.cs b/src/Parsing/Impl/Ast/ParenthesisExpression.cs index 16e44dca0..eef3a7b2f 100644 --- a/src/Parsing/Impl/Ast/ParenthesisExpression.cs +++ b/src/Parsing/Impl/Ast/ParenthesisExpression.cs @@ -13,36 +13,38 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; namespace Microsoft.Python.Parsing.Ast { - public class ParenthesisExpression : Expression { - private readonly Expression _expression; - public ParenthesisExpression(Expression expression) { - _expression = expression; + Expression = expression; } - public Expression Expression => _expression; + public Expression Expression { get; } - internal override string CheckAssign() => _expression.CheckAssign(); + internal override string CheckAssign() => Expression.CheckAssign(); - internal override string CheckDelete() => _expression.CheckDelete(); + internal override string CheckDelete() => Expression.CheckDelete(); + + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - _expression?.Walk(walker); + Expression?.Walk(walker); } walker.PostWalk(this); } public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - if (_expression != null) { - await _expression.WalkAsync(walker, cancellationToken); + if (Expression != null) { + await Expression.WalkAsync(walker, cancellationToken); } } await walker.PostWalkAsync(this, cancellationToken); @@ -53,7 +55,7 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo format.ReflowComment(res, this.GetPreceedingWhiteSpace(ast)); res.Append('('); - _expression.AppendCodeString( + Expression.AppendCodeString( res, ast, format, diff --git a/src/Parsing/Impl/Ast/PrintStatement.cs b/src/Parsing/Impl/Ast/PrintStatement.cs index a912bec00..314bf2d5e 100644 --- a/src/Parsing/Impl/Ast/PrintStatement.cs +++ b/src/Parsing/Impl/Ast/PrintStatement.cs @@ -18,27 +18,33 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class PrintStatement : Statement { - private readonly Expression[] _expressions; - - public PrintStatement(Expression destination, Expression[] expressions, bool trailingComma) { + public PrintStatement(Expression destination, ImmutableArray expressions, bool trailingComma) { Destination = destination; - _expressions = expressions; + Expressions = expressions; TrailingComma = trailingComma; } public Expression Destination { get; } - public IList Expressions => _expressions; + public ImmutableArray Expressions { get; } public bool TrailingComma { get; } + public override IEnumerable GetChildNodes() { + if (Destination != null) yield return Destination; + foreach (var expression in Expressions) { + yield return expression; + } + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Destination?.Walk(walker); - foreach (var expression in _expressions.MaybeEnumerate()) { + foreach (var expression in Expressions.MaybeEnumerate()) { expression.Walk(walker); } } @@ -50,7 +56,7 @@ public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken if (Destination != null) { await Destination.WalkAsync(walker, cancellationToken); } - foreach (var expression in _expressions.MaybeEnumerate()) { + foreach (var expression in Expressions.MaybeEnumerate()) { await expression.WalkAsync(walker, cancellationToken); } } @@ -64,7 +70,7 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co res.Append(this.GetSecondWhiteSpace(ast)); res.Append(">>"); Destination.AppendCodeString(res, ast, format); - if (_expressions.Length > 0) { + if (Expressions.Count > 0) { res.Append(this.GetThirdWhiteSpace(ast)); res.Append(','); } diff --git a/src/Parsing/Impl/Ast/PythonAst.cs b/src/Parsing/Impl/Ast/PythonAst.cs index 39aba7fa9..142c352ab 100644 --- a/src/Parsing/Impl/Ast/PythonAst.cs +++ b/src/Parsing/Impl/Ast/PythonAst.cs @@ -73,6 +73,8 @@ public PythonAst(IEnumerable existingAst) { /// public bool HasVerbatim { get; internal set; } + public override IEnumerable GetChildNodes() => new[] {_body}; + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { _body.Walk(walker); diff --git a/src/Parsing/Impl/Ast/RaiseStatement.cs b/src/Parsing/Impl/Ast/RaiseStatement.cs index d18cf595b..f8038bca7 100644 --- a/src/Parsing/Impl/Ast/RaiseStatement.cs +++ b/src/Parsing/Impl/Ast/RaiseStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -36,6 +37,13 @@ public RaiseStatement(Expression exceptionType, Expression exceptionValue, Expre public override int KeywordLength => 5; + public override IEnumerable GetChildNodes() { + if (ExceptType != null) yield return ExceptType; + if (Value != null) yield return Value; + if (Traceback != null) yield return Traceback; + if (Cause != null) yield return Cause; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { ExceptType?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/RelativeModuleName.cs b/src/Parsing/Impl/Ast/RelativeModuleName.cs index 26a69e83e..43841e0bb 100644 --- a/src/Parsing/Impl/Ast/RelativeModuleName.cs +++ b/src/Parsing/Impl/Ast/RelativeModuleName.cs @@ -15,10 +15,11 @@ // permissions and limitations under the License. using System.Text; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class RelativeModuleName : ModuleName { - public RelativeModuleName(NameExpression[]/*!*/ names, int dotCount) + public RelativeModuleName(ImmutableArray names, int dotCount) : base(names) { DotCount = dotCount; } diff --git a/src/Parsing/Impl/Ast/ReturnStatement.cs b/src/Parsing/Impl/Ast/ReturnStatement.cs index b58d47c3d..f5c5a7c0b 100644 --- a/src/Parsing/Impl/Ast/ReturnStatement.cs +++ b/src/Parsing/Impl/Ast/ReturnStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -25,6 +26,10 @@ public ReturnStatement(Expression expression) { public Expression Expression { get; } + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/SequenceExpression.cs b/src/Parsing/Impl/Ast/SequenceExpression.cs index 0e76b9cf3..1cffa0173 100644 --- a/src/Parsing/Impl/Ast/SequenceExpression.cs +++ b/src/Parsing/Impl/Ast/SequenceExpression.cs @@ -15,16 +15,15 @@ // permissions and limitations under the License. using System.Collections.Generic; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public abstract class SequenceExpression : Expression { - private readonly Expression[] _items; - - protected SequenceExpression(Expression[] items) { - _items = items; + protected SequenceExpression(ImmutableArray items) { + Items = items; } - public IList Items => _items; + public ImmutableArray Items { get; } internal override string CheckAssign() { for (var i = 0; i < Items.Count; i++) { diff --git a/src/Parsing/Impl/Ast/SetExpression.cs b/src/Parsing/Impl/Ast/SetExpression.cs index cd5d7e58a..395bf4a82 100644 --- a/src/Parsing/Impl/Ast/SetExpression.cs +++ b/src/Parsing/Impl/Ast/SetExpression.cs @@ -17,23 +17,25 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class SetExpression : Expression { - private readonly Expression[] _items; - - public SetExpression(params Expression[] items) { - _items = items; + public SetExpression(ImmutableArray items) { + Items = items; } - public IList Items => _items; + public ImmutableArray Items { get; } public override string NodeName => "set display"; + public override IEnumerable GetChildNodes() => Items.WhereNotNull(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var s in _items) { + foreach (var s in Items) { s.Walk(walker); } } @@ -42,7 +44,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var s in _items) { + foreach (var s in Items) { await s.WalkAsync(walker, cancellationToken); } } diff --git a/src/Parsing/Impl/Ast/SliceExpression.cs b/src/Parsing/Impl/Ast/SliceExpression.cs index 022a14c92..ffa79055b 100644 --- a/src/Parsing/Impl/Ast/SliceExpression.cs +++ b/src/Parsing/Impl/Ast/SliceExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -38,6 +39,12 @@ public SliceExpression(Expression start, Expression stop, Expression step, bool /// public bool StepProvided { get; } + public override IEnumerable GetChildNodes() { + if (SliceStart != null) yield return SliceStart; + if (SliceStop != null) yield return SliceStop; + if (SliceStep != null) yield return SliceStep; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { SliceStart?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/StarredExpression.cs b/src/Parsing/Impl/Ast/StarredExpression.cs index b3100f568..90550d3da 100644 --- a/src/Parsing/Impl/Ast/StarredExpression.cs +++ b/src/Parsing/Impl/Ast/StarredExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -30,6 +31,10 @@ public StarredExpression(Expression expr, int starCount) { public int StarCount { get; } + public override IEnumerable GetChildNodes() { + yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression.Walk(walker); diff --git a/src/Parsing/Impl/Ast/SublistParameter.cs b/src/Parsing/Impl/Ast/SublistParameter.cs index bd387ce77..ad5e3d256 100644 --- a/src/Parsing/Impl/Ast/SublistParameter.cs +++ b/src/Parsing/Impl/Ast/SublistParameter.cs @@ -14,6 +14,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using Microsoft.Python.Core; @@ -31,6 +32,11 @@ public SublistParameter(int position, TupleExpression tuple) public TupleExpression Tuple { get; } + public override IEnumerable GetChildNodes() { + if (Tuple != null) yield return Tuple; + if (DefaultValue != null) yield return DefaultValue; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Tuple?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/SuiteStatement.cs b/src/Parsing/Impl/Ast/SuiteStatement.cs index 1556e621f..ddb4c87ac 100644 --- a/src/Parsing/Impl/Ast/SuiteStatement.cs +++ b/src/Parsing/Impl/Ast/SuiteStatement.cs @@ -30,6 +30,8 @@ public SuiteStatement(Statement[] statements) { public IList Statements => _statements; + public override IEnumerable GetChildNodes() => _statements.WhereNotNull(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { foreach (var s in _statements.MaybeEnumerate()) { diff --git a/src/Parsing/Impl/Ast/TryStatement.cs b/src/Parsing/Impl/Ast/TryStatement.cs index bd7c9c445..6b9b66634 100644 --- a/src/Parsing/Impl/Ast/TryStatement.cs +++ b/src/Parsing/Impl/Ast/TryStatement.cs @@ -55,6 +55,18 @@ public TryStatement(Statement body, TryStatementHandler[] handlers, Statement el /// public IList Handlers => _handlers; + public override IEnumerable GetChildNodes() { + if (Body != null) yield return Body; + if (_handlers != null) { + foreach (var handler in _handlers) { + yield return handler; + } + } + + if (Else != null) yield return Else; + if (Finally != null) yield return Finally; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Body?.Walk(walker); @@ -124,6 +136,12 @@ public TryStatementHandler(Expression test, Expression target, Statement body) { public Expression Target { get; } public Statement Body { get; } + public override IEnumerable GetChildNodes() { + if (Test != null) yield return Test; + if (Target != null) yield return Target; + if (Body != null) yield return Body; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Test?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/TupleExpression.cs b/src/Parsing/Impl/Ast/TupleExpression.cs index 114fed653..c0998c229 100644 --- a/src/Parsing/Impl/Ast/TupleExpression.cs +++ b/src/Parsing/Impl/Ast/TupleExpression.cs @@ -13,14 +13,16 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class TupleExpression : SequenceExpression { - public TupleExpression(bool expandable, params Expression[] items) + public TupleExpression(bool expandable, ImmutableArray items) : base(items) { IsExpandable = expandable; } @@ -32,6 +34,8 @@ internal override string CheckAssign() { return base.CheckAssign(); } + public override IEnumerable GetChildNodes() => Items.WhereNotNull(); + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { foreach (var e in Items.MaybeEnumerate()) { diff --git a/src/Parsing/Impl/Ast/UnaryExpression.cs b/src/Parsing/Impl/Ast/UnaryExpression.cs index d9181fb8e..57de8258a 100644 --- a/src/Parsing/Impl/Ast/UnaryExpression.cs +++ b/src/Parsing/Impl/Ast/UnaryExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -37,9 +38,13 @@ internal override void AppendCodeString(StringBuilder res, PythonAst ast, CodeFo Expression.AppendCodeString(res, ast, format); } + public override IEnumerable GetChildNodes() { + yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - Expression?.Walk(walker); + Expression.Walk(walker); } walker.PostWalk(this); } diff --git a/src/Parsing/Impl/Ast/WhileStatement.cs b/src/Parsing/Impl/Ast/WhileStatement.cs index cbd0c792e..26388ed5f 100644 --- a/src/Parsing/Impl/Ast/WhileStatement.cs +++ b/src/Parsing/Impl/Ast/WhileStatement.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -39,6 +40,12 @@ public void SetLoc(int start, int header, int end, int elseIndex) { ElseIndex = elseIndex; } + public override IEnumerable GetChildNodes() { + if (Test != null) yield return Test; + if (Body != null) yield return Body; + if (ElseStatement != null) yield return ElseStatement; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Test?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/WithStatement.cs b/src/Parsing/Impl/Ast/WithStatement.cs index d55fc9dcd..89ac1e603 100644 --- a/src/Parsing/Impl/Ast/WithStatement.cs +++ b/src/Parsing/Impl/Ast/WithStatement.cs @@ -18,23 +18,22 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Core.Collections; namespace Microsoft.Python.Parsing.Ast { public class WithStatement : Statement, IMaybeAsyncStatement { - private readonly WithItem[] _items; private int? _keywordEndIndex; - public WithStatement(WithItem[] items, Statement body) { - _items = items; + public WithStatement(ImmutableArray items, Statement body) { + Items = items; Body = body; } - public WithStatement(WithItem[] items, Statement body, bool isAsync) : this(items, body) { + public WithStatement(ImmutableArray items, Statement body, bool isAsync) : this(items, body) { IsAsync = isAsync; } - - public IList Items => _items; + public ImmutableArray Items { get; } public int HeaderIndex { get; set; } internal void SetKeywordEndIndex(int index) => _keywordEndIndex = index; @@ -44,9 +43,16 @@ public WithStatement(WithItem[] items, Statement body, bool isAsync) : this(item public Statement Body { get; } public bool IsAsync { get; } + public override IEnumerable GetChildNodes() { + foreach (var item in Items) { + yield return item; + } + if (Body != null) yield return Body; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - foreach (var item in _items) { + foreach (var item in Items) { item.Walk(walker); } @@ -59,7 +65,7 @@ public override void Walk(PythonWalker walker) { public override async Task WalkAsync(PythonWalkerAsync walker, CancellationToken cancellationToken = default) { if (await walker.WalkAsync(this, cancellationToken)) { - foreach (var item in _items) { + foreach (var item in Items) { await item.WalkAsync(walker, cancellationToken); } if (Body != null) { @@ -85,8 +91,8 @@ internal override void AppendCodeStringStmt(StringBuilder res, PythonAst ast, Co res.Append("with"); var itemWhiteSpace = this.GetListWhiteSpace(ast); var whiteSpaceIndex = 0; - for (var i = 0; i < _items.Length; i++) { - var item = _items[i]; + for (var i = 0; i < Items.Count; i++) { + var item = Items[i]; if (i != 0) { if (itemWhiteSpace != null) { res.Append(itemWhiteSpace[whiteSpaceIndex++]); @@ -121,6 +127,11 @@ public WithItem(Expression contextManager, Expression variable, int asIndex) { public Expression Variable { get; } public int AsIndex { get; } + public override IEnumerable GetChildNodes() { + if (ContextManager != null) yield return ContextManager; + if (Variable != null) yield return Variable; + } + public override void Walk(PythonWalker walker) { ContextManager?.Walk(walker); Variable?.Walk(walker); diff --git a/src/Parsing/Impl/Ast/YieldExpression.cs b/src/Parsing/Impl/Ast/YieldExpression.cs index dd5a4db16..9d5cc9fd2 100644 --- a/src/Parsing/Impl/Ast/YieldExpression.cs +++ b/src/Parsing/Impl/Ast/YieldExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -29,9 +30,13 @@ public YieldExpression(Expression expression) { public Expression Expression { get; } + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { - Expression?.Walk(walker); + Expression?.Walk(walker); } walker.PostWalk(this); } diff --git a/src/Parsing/Impl/Ast/YieldFromExpression.cs b/src/Parsing/Impl/Ast/YieldFromExpression.cs index 6a7a55fd1..de45d838b 100644 --- a/src/Parsing/Impl/Ast/YieldFromExpression.cs +++ b/src/Parsing/Impl/Ast/YieldFromExpression.cs @@ -13,6 +13,7 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. +using System.Collections.Generic; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -30,6 +31,10 @@ public YieldFromExpression(Expression expression) { public Expression Expression { get; } + public override IEnumerable GetChildNodes() { + if (Expression != null) yield return Expression; + } + public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { Expression?.Walk(walker); diff --git a/src/Parsing/Impl/Parser.cs b/src/Parsing/Impl/Parser.cs index 1ada9a6ae..f066a6781 100644 --- a/src/Parsing/Impl/Parser.cs +++ b/src/Parsing/Impl/Parser.cs @@ -24,6 +24,7 @@ using System.Text; using System.Text.RegularExpressions; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; using Microsoft.Python.Core.Text; using Microsoft.Python.Parsing.Ast; @@ -592,7 +593,7 @@ private Statement ParseDelStmt() { DelStatement ret; if (PeekToken(TokenKind.NewLine) || PeekToken(TokenKind.EndOfFile)) { ReportSyntaxError(curLookahead.Span.Start, curLookahead.Span.End, "expected expression after del"); - ret = new DelStatement(new Expression[0]); + ret = new DelStatement(ImmutableArray.Empty); } else { var l = ParseExprList(out var itemWhiteSpace); foreach (var e in l) { @@ -605,7 +606,7 @@ private Statement ParseDelStmt() { } } - ret = new DelStatement(l.ToArray()); + ret = new DelStatement(ImmutableArray.Create(l)); if (itemWhiteSpace != null) { AddListWhiteSpace(ret, itemWhiteSpace.ToArray()); } @@ -1044,8 +1045,8 @@ private ImportStatement ParseImportStmt() { las.Add(MaybeParseAsName(asNameWhiteSpace)); } } - var names = l.ToArray(); - var asNames = las.ToArray(); + var names = ImmutableArray.Create(l); + var asNames = ImmutableArray.Create(las); var ret = new ImportStatement(names, asNames, AbsoluteImports); if (_verbatim) { @@ -1069,8 +1070,6 @@ private ModuleName ParseModuleName() { return ret; } - private static readonly NameExpression[] EmptyNames = new NameExpression[0]; - // relative_module: "."* module | "."+ private ModuleName ParseRelativeModuleName() { var start = -1; @@ -1099,10 +1098,10 @@ private ModuleName ParseRelativeModuleName() { } List nameWhiteSpace = null; - var names = EmptyNames; + var names = ImmutableArray.Empty; if (PeekToken() is NameToken) { names = ReadDottedName(out nameWhiteSpace); - if (!isStartSetCorrectly && names.Length > 0) { + if (!isStartSetCorrectly && names.Count > 0) { start = names[0].StartIndex; isStartSetCorrectly = true; } @@ -1123,7 +1122,7 @@ private ModuleName ParseRelativeModuleName() { AddListWhiteSpace(ret, dotWhiteSpace.ToArray()); } } else { - if (names.Length == 0) { + if (names.Count == 0) { ReportSyntaxError(_lookahead.Span.Start, _lookahead.Span.End, "missing module name"); } ret = new ModuleName(names); @@ -1137,14 +1136,14 @@ private ModuleName ParseRelativeModuleName() { return ret; } - private NameExpression[] ReadDottedName(out List dotWhiteSpace) { - var l = new List(); + private ImmutableArray ReadDottedName(out List dotWhiteSpace) { + var names = ImmutableArray.Empty; dotWhiteSpace = MakeWhiteSpaceList(); var name = ReadName(); if (name.HasName) { var nameExpr = MakeName(name); - l.Add(nameExpr); + names = names.Add(nameExpr); if (_verbatim) { dotWhiteSpace.Add(_tokenWhiteSpace); @@ -1155,13 +1154,14 @@ private NameExpression[] ReadDottedName(out List dotWhiteSpace) { } name = ReadName(); nameExpr = MakeName(name); - l.Add(nameExpr); + names = names.Add(nameExpr); if (_verbatim) { dotWhiteSpace.Add(_tokenWhiteSpace); } } } - return l.ToArray(); + + return names; } @@ -1181,8 +1181,8 @@ private FromImportStatement ParseFromImportStmt() { var ateParen = ateImport && MaybeEat(TokenKind.LeftParenthesis); var parenWhiteSpace = ateParen ? _tokenWhiteSpace : null; - NameExpression/*!*/[] names; - NameExpression[] asNames; + var names = ImmutableArray.Empty; + var asNames = ImmutableArray.Empty; var fromFuture = false; List namesWhiteSpace = null; @@ -1191,18 +1191,15 @@ private FromImportStatement ParseFromImportStmt() { var las = new List(); ParseAsNameList(l, las, out namesWhiteSpace); - names = l.ToArray(); - asNames = las.ToArray(); + names = ImmutableArray.Create(l); + asNames = ImmutableArray.Create(las); if (_langVersion.Is3x() && ((_functions != null && _functions.Count > 0) || _classDepth > 0)) { foreach (var n in names.Where(n => n.Name == "*")) { ReportSyntaxError(n.StartIndex, n.EndIndex, "import * only allowed at module level"); } } - } else { - names = EmptyNames; - asNames = EmptyNames; - } + } // Process from __future__ statement if (dname.Names.Count == 1 && dname.Names[0].Name == "__future__") { @@ -1236,11 +1233,11 @@ private FromImportStatement ParseFromImportStmt() { return ret; } - private bool ProcessFutureStatements(int start, NameExpression/*!*/[] names, bool fromFuture) { + private bool ProcessFutureStatements(int start, ImmutableArray names, bool fromFuture) { if (!_fromFutureAllowed) { ReportSyntaxError(start, GetEnd(), "from __future__ imports must occur at the beginning of the file"); } - if (names.Length == 1 && names[0].Name == "*") { + if (names.Count == 1 && names[0].Name == "*") { ReportSyntaxError(start, GetEnd(), "future statement does not support import *"); } fromFuture = true; @@ -1389,8 +1386,7 @@ private GlobalStatement ParseGlobalStmt() { var start = GetStart(); var globalWhiteSpace = _tokenWhiteSpace; - var l = ReadNameList(out var commaWhiteSpace, out var namesWhiteSpace); - var names = l.ToArray(); + var names = ReadNameList(out var commaWhiteSpace, out var namesWhiteSpace); var ret = new GlobalStatement(names); ret.SetLoc(start, GetEndForStatement()); if (_verbatim) { @@ -1410,8 +1406,7 @@ private NonlocalStatement ParseNonlocalStmt() { var localWhiteSpace = _tokenWhiteSpace; var start = GetStart(); - var l = ReadNameList(out var commaWhiteSpace, out var namesWhiteSpace); - var names = l.ToArray(); + var names = ReadNameList(out var commaWhiteSpace, out var namesWhiteSpace); var ret = new NonlocalStatement(names); ret.SetLoc(start, GetEndForStatement()); if (_verbatim) { @@ -1422,15 +1417,15 @@ private NonlocalStatement ParseNonlocalStmt() { return ret; } - private List ReadNameList(out List commaWhiteSpace, out List namesWhiteSpace) { - var l = new List(); + private ImmutableArray ReadNameList(out List commaWhiteSpace, out List namesWhiteSpace) { + var l = ImmutableArray.Empty; commaWhiteSpace = MakeWhiteSpaceList(); namesWhiteSpace = MakeWhiteSpaceList(); var name = ReadName(); if (name.HasName) { var nameExpr = MakeName(name); - l.Add(nameExpr); + l = l.Add(nameExpr); if (_verbatim) { namesWhiteSpace.Add(_tokenWhiteSpace); } @@ -1440,7 +1435,7 @@ private List ReadNameList(out List commaWhiteSpace, out } name = ReadName(); nameExpr = MakeName(name); - l.Add(nameExpr); + l = l.Add(nameExpr); if (_verbatim) { namesWhiteSpace.Add(_tokenWhiteSpace); } @@ -1554,7 +1549,7 @@ private PrintStatement ParsePrintStmt() { needNonEmptyTestList = true; end = GetEnd(); } else { - ret = new PrintStatement(dest, new Expression[0], false); + ret = new PrintStatement(dest, ImmutableArray.Empty, false); if (_verbatim) { AddPreceedingWhiteSpace(ret, printWhiteSpace); AddSecondPreceedingWhiteSpace(ret, rightShiftWhiteSpace); @@ -1567,25 +1562,21 @@ private PrintStatement ParsePrintStmt() { var trailingComma = false; List commaWhiteSpace = null; - Expression[] exprs; + var expressions = ImmutableArray.Empty; if (!NeverTestToken(PeekToken())) { var expr = ParseExpression(); if (!MaybeEat(TokenKind.Comma)) { - exprs = new[] { expr }; + expressions = expressions.Add(expr); } else { var exprList = ParseTestListAsExpr(expr, out commaWhiteSpace, out trailingComma); - exprs = exprList.ToArray(); - } - } else { - if (needNonEmptyTestList) { - ReportSyntaxError(start, end, "print statement expected expression to be printed"); - exprs = new[] { Error("") }; - } else { - exprs = new Expression[0]; + expressions = ImmutableArray.Create(exprList); } + } else if (needNonEmptyTestList) { + ReportSyntaxError(start, end, "print statement expected expression to be printed"); + expressions = expressions.Add(Error("")); } - ret = new PrintStatement(dest, exprs, trailingComma); + ret = new PrintStatement(dest, expressions, trailingComma); if (_verbatim) { AddPreceedingWhiteSpace(ret, printWhiteSpace); AddSecondPreceedingWhiteSpace(ret, rightShiftWhiteSpace); @@ -1636,7 +1627,7 @@ private Statement ParseClassDef() { var isParenFree = false; string leftParenWhiteSpace = null, rightParenWhiteSpace = null; List commaWhiteSpace = null; - Arg[] args; + var args = ImmutableArray.Empty; var ateTerminator = true; if (MaybeEat(TokenKind.LeftParenthesis)) { leftParenWhiteSpace = _tokenWhiteSpace; @@ -1652,7 +1643,6 @@ private Statement ParseClassDef() { } } else { isParenFree = true; - args = new Arg[0]; } var mid = _lookahead.Span.Start; @@ -2268,17 +2258,18 @@ private WithStatement ParseWithStmt(bool isAsync) { var withWhiteSpace = _tokenWhiteSpace; var itemWhiteSpace = MakeWhiteSpaceList(); - var items = new List { ParseWithItem(itemWhiteSpace) }; + var items = ImmutableArray.Empty + .Add(ParseWithItem(itemWhiteSpace));; while (MaybeEat(TokenKind.Comma)) { itemWhiteSpace?.Add(_tokenWhiteSpace); - items.Add(ParseWithItem(itemWhiteSpace)); + items = items.Add(ParseWithItem(itemWhiteSpace)); } var header = PeekToken(TokenKind.Colon) ? GetEnd() : -1; var body = ParseSuite(); - var ret = new WithStatement(items.ToArray(), body, isAsync) { HeaderIndex = header }; + var ret = new WithStatement(items, body, isAsync) { HeaderIndex = header }; if (_verbatim) { AddPreceedingWhiteSpace(ret, isAsync ? asyncWhiteSpace : withWhiteSpace); AddSecondPreceedingWhiteSpace(ret, isAsync ? withWhiteSpace : null); @@ -2446,7 +2437,7 @@ private IfStatement ParseIfStmt() { else_ = ParseSuite(); } - var tests = l.ToArray(); + var tests = ImmutableArray.Create(l); var ret = new IfStatement(tests, else_); if (_verbatim) { if (elseWhiteSpace != null) { @@ -2457,7 +2448,7 @@ private IfStatement ParseIfStmt() { } } ret.ElseIndex = elseIndex; - ret.SetLoc(start, else_ != null ? else_.EndIndex : tests[tests.Length - 1].EndIndex); + ret.SetLoc(start, else_ != null ? else_.EndIndex : tests[tests.Count - 1].EndIndex); return ret; } @@ -3271,7 +3262,7 @@ private Expression AddTrailers(Expression ret, bool allowGeneratorExpression) { if (args != null) { call = FinishCallExpr(ret, args); } else { - call = new CallExpression(ret, new Arg[0]); + call = new CallExpression(ret, ImmutableArray.Empty); } if (_verbatim) { @@ -3477,7 +3468,7 @@ private List ParseExprList(out List commaWhiteSpace) { // expression "=" expression rest_of_arguments // expression "for" gen_expr_rest // - private Arg[] FinishArgListOrGenExpr(out List commaWhiteSpace, out bool ateTerminator) { + private ImmutableArray FinishArgListOrGenExpr(out List commaWhiteSpace, out bool ateTerminator) { Arg a = null; commaWhiteSpace = MakeWhiteSpaceList(); @@ -3489,7 +3480,7 @@ private Arg[] FinishArgListOrGenExpr(out List commaWhiteSpace, out bool a = new Arg(e); a.SetLoc(e.StartIndex, e.EndIndex); a.EndIndexIncludingWhitespace = e.EndIndex; - return new[] { a }; + return ImmutableArray.Create(a); } if (MaybeEat(TokenKind.Assign)) { // Keyword argument @@ -3501,7 +3492,7 @@ private Arg[] FinishArgListOrGenExpr(out List commaWhiteSpace, out bool a.SetLoc(e.StartIndex, GetEnd()); ateTerminator = Eat(TokenKind.RightParenthesis); a.EndIndexIncludingWhitespace = GetStart(); - return new Arg[1] { a }; // Generator expression is the argument + return ImmutableArray.Create(a); // Generator expression is the argument } else { a = new Arg(e); a.SetLoc(e.StartIndex, e.EndIndex); @@ -3517,7 +3508,7 @@ private Arg[] FinishArgListOrGenExpr(out List commaWhiteSpace, out bool } else { ateTerminator = Eat(TokenKind.RightParenthesis); a.EndIndexIncludingWhitespace = GetStart(); - return new Arg[1] { a }; + return ImmutableArray.Create(a); } } @@ -3560,7 +3551,7 @@ private void CheckUniqueArgument(List names, Arg arg) { //arglist: (argument ',')* (argument [',']| '*' expression [',' '**' expression] | '**' expression) //argument: [expression '='] expression # Really [keyword '='] expression - private Arg[] FinishArgumentList(Arg first, List commaWhiteSpace, out bool ateTerminator) { + private ImmutableArray FinishArgumentList(Arg first, List commaWhiteSpace, out bool ateTerminator) { const TokenKind terminator = TokenKind.RightParenthesis; var l = new List(); @@ -3618,7 +3609,7 @@ private Arg[] FinishArgumentList(Arg first, List commaWhiteSpace, out bo } } - return l.ToArray(); + return ImmutableArray.Create(l); } private Expression ParseOldExpressionListAsExpr() { @@ -3867,7 +3858,6 @@ private Expression ParseGeneratorExpression(Expression expr, string rightParenWh _inGeneratorExpression = true; try { var iters = ParseCompIter(); - ret = new GeneratorExpression(expr, iters); } finally { _inGeneratorExpression = prevIn; @@ -4080,15 +4070,10 @@ private Expression FinishDictOrSetValue() { Expression ret; if (dictMembers != null || setMembers == null) { - SliceExpression[] exprs; - if (dictMembers != null) { - exprs = dictMembers.ToArray(); - } else { - exprs = new SliceExpression[0]; - } - ret = new DictionaryExpression(exprs); + var expressions = dictMembers != null ? ImmutableArray.Create(dictMembers) : ImmutableArray.Empty; + ret = new DictionaryExpression(expressions); } else { - ret = new SetExpression(setMembers.ToArray()); + ret = new SetExpression(ImmutableArray.Create(setMembers)); } ret.SetLoc(oStart, GetEnd()); if (_verbatim) { @@ -4117,22 +4102,22 @@ private DictionaryComprehension FinishDictComp(SliceExpression value, out bool a } // comp_iter: comp_for | comp_if - private ComprehensionIterator[] ParseCompIter() { - var iters = new List(); + private ImmutableArray ParseCompIter() { + var iterators = ImmutableArray.Empty; var firstFor = ParseCompFor(); - iters.Add(firstFor); + iterators = iterators.Add(firstFor); while (true) { if (PeekTokenForOrAsyncFor) { - iters.Add(ParseCompFor()); + iterators = iterators.Add(ParseCompFor()); } else if (PeekToken(Tokens.KeywordIfToken)) { - iters.Add(ParseCompIf()); + iterators = iterators.Add(ParseCompIf()); } else { break; } } - return iters.ToArray(); + return iterators; } private bool PeekTokenForOrAsyncFor { @@ -4224,7 +4209,7 @@ private Expression FinishListValue() { Expression ret; bool ateRightBracket; if (MaybeEat(TokenKind.RightBracket)) { - ret = new ListExpression(); + ret = new ListExpression(ImmutableArray.Empty); ateRightBracket = true; } else { var prevAllow = _allowIncomplete; @@ -4240,7 +4225,7 @@ private Expression FinishListValue() { var l = ParseTestListAsExpr(t0, out var listWhiteSpace, out var trailingComma); ateRightBracket = Eat(TokenKind.RightBracket); - ret = new ListExpression(l.ToArray()); + ret = new ListExpression(ImmutableArray.Create(l)); if (listWhiteSpace != null) { AddListWhiteSpace(ret, listWhiteSpace.ToArray()); @@ -4249,7 +4234,7 @@ private Expression FinishListValue() { ret = FinishListComp(t0, out ateRightBracket); } else { ateRightBracket = Eat(TokenKind.RightBracket); - ret = new ListExpression(t0); + ret = new ListExpression(ImmutableArray.Create(t0)); } } finally { _allowIncomplete = prevAllow; @@ -4276,11 +4261,11 @@ private ListComprehension FinishListComp(Expression item, out bool ateRightBrack } // list_iter: list_for | list_if - private ComprehensionIterator[] ParseListCompIter() { - var iters = new List(); + private ImmutableArray ParseListCompIter() { + var iterators = ImmutableArray.Empty; var firstFor = ParseListCompFor(); - iters.Add(firstFor); + iterators = iterators.Add(firstFor); while (true) { ComprehensionIterator iterator; @@ -4293,10 +4278,10 @@ private ComprehensionIterator[] ParseListCompIter() { break; } - iters.Add(iterator); + iterators = iterators.Add(iterator); } - return iters.ToArray(); + return iterators; } // list_for: 'for' target_list 'in' old_expression_list [list_iter] @@ -4403,8 +4388,8 @@ private Expression MakeTupleOrExpr(List l, List itemWhiteSpa return l[0]; } - var exprs = l.ToArray(); - var te = new TupleExpression(expandable && !trailingComma, exprs); + var expressions = ImmutableArray.Create(l); + var te = new TupleExpression(expandable && !trailingComma, expressions); if (_verbatim) { if (itemWhiteSpace != null) { AddListWhiteSpace(te, itemWhiteSpace.ToArray()); @@ -4413,8 +4398,8 @@ private Expression MakeTupleOrExpr(List l, List itemWhiteSpa AddIsAltForm(te); } } - if (exprs.Length > 0) { - te.SetLoc(exprs[0].StartIndex, exprs[exprs.Length - 1].EndIndex); + if (expressions.Count > 0) { + te.SetLoc(expressions[0].StartIndex, expressions[expressions.Count - 1].EndIndex); } return te; } @@ -4481,7 +4466,7 @@ private void PushFunction(FunctionDefinition function) { _functions.Push(function); } - private CallExpression FinishCallExpr(Expression target, params Arg[] args) { + private CallExpression FinishCallExpr(Expression target, ImmutableArray args) { var hasArgsTuple = false; var hasKeywordDict = false; var keywordCount = 0; diff --git a/src/Parsing/Test/ParserTests.cs b/src/Parsing/Test/ParserTests.cs index b4eb6f68d..baedb7a12 100644 --- a/src/Parsing/Test/ParserTests.cs +++ b/src/Parsing/Test/ParserTests.cs @@ -24,6 +24,7 @@ using FluentAssertions; using Microsoft.Python.Analysis.Core.Interpreter; using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; using Microsoft.Python.Core.Text; using Microsoft.Python.Parsing.Ast; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -3306,7 +3307,7 @@ private static Action IfTest(Action expectedTest, A }; } - private static Action> IfTests(params Action[] expectedTests) { + private static Action> IfTests(params Action[] expectedTests) { return tests => { Assert.AreEqual(expectedTests.Length, tests.Count); for (var i = 0; i < expectedTests.Length; i++) { @@ -3315,7 +3316,7 @@ private static Action> IfTests(params Action CheckIfStmt(Action> tests, Action _else = null) { + private static Action CheckIfStmt(Action> tests, Action _else = null) { return stmt => { Assert.AreEqual(typeof(IfStatement), stmt.GetType()); var ifStmt = (IfStatement)stmt; @@ -3492,12 +3493,12 @@ private static Action CheckClassDef(string name, Action bo } if (bases != null) { - Assert.AreEqual(bases.Length, classDef.Bases.Length); + Assert.AreEqual(bases.Length, classDef.Bases.Count); for (var i = 0; i < bases.Length; i++) { bases[i](classDef.Bases[i]); } } else { - Assert.AreEqual(0, classDef.Bases.Length); + Assert.AreEqual(0, classDef.Bases.Count); } body(classDef.Body); diff --git a/src/UnitTests/Core/Impl/FluentAssertions/AsyncAssertions.cs b/src/UnitTests/Core/Impl/FluentAssertions/AsyncAssertions.cs new file mode 100644 index 000000000..10cc07ce2 --- /dev/null +++ b/src/UnitTests/Core/Impl/FluentAssertions/AsyncAssertions.cs @@ -0,0 +1,81 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using FluentAssertions.Execution; +using FluentAssertions.Specialized; + +namespace Microsoft.Python.UnitTests.Core.FluentAssertions { + internal sealed class AsyncAssertions { + private readonly Func _asyncAction; + + public AsyncAssertions(Func asyncAction) { + _asyncAction = asyncAction; + } + + public async Task ShouldNotThrowAsync(string because, object[] becauseArgs) { + var exceptions = await InvokeAction(); + + Execute.Assertion + .ForCondition(exceptions.Count == 0) + .BecauseOf(because, becauseArgs) + .FailWith("Did not expect any exception{reason}, but found:{0}.", exceptions.Select(e => $"\n\t{e.GetType()} with message {e.Message}")); + } + + public async Task> ShouldThrowAsync(string because, object[] becauseArgs) + where TException : Exception { + var exceptions = await InvokeAction(); + + Execute.Assertion + .ForCondition(exceptions.Any()) + .BecauseOf(because, becauseArgs) + .FailWith("Expected {0}{reason}, but no exception was thrown.", typeof(TException)); + + var typedExceptions = exceptions.OfType().ToList(); + Execute.Assertion + .ForCondition(typedExceptions.Any()) + .BecauseOf(because, becauseArgs) + .FailWith("Expected {0}{reason}, but found {1}.", typeof(TException), exceptions); + + return new ExceptionAssertions(typedExceptions); + } + + private Task> InvokeAction() => _asyncAction().ContinueWith(GetExceptions); + + public static List GetExceptions(Task task) { + switch (task.Status) { + case TaskStatus.Canceled: + return GetCanceledException(task); + case TaskStatus.Faulted: + return new List(task.Exception.Flatten().InnerExceptions); + default: + return new List(); + } + } + + private static List GetCanceledException(Task task) { + var exceptions = new List(); + try { + task.GetAwaiter().GetResult(); + } catch (Exception ex) { + exceptions.Add(ex); + } + return exceptions; + } + } +} diff --git a/src/UnitTests/Core/Impl/FluentAssertions/TaskAssertions.cs b/src/UnitTests/Core/Impl/FluentAssertions/TaskAssertions.cs new file mode 100644 index 000000000..99b2fe356 --- /dev/null +++ b/src/UnitTests/Core/Impl/FluentAssertions/TaskAssertions.cs @@ -0,0 +1,173 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using FluentAssertions.Execution; +using FluentAssertions.Primitives; + +namespace Microsoft.Python.UnitTests.Core.FluentAssertions { + public sealed class TaskAssertions : TaskAssertionsBase { + public TaskAssertions(Task task) : base(task) { } + } + + public sealed class TaskAssertions : TaskAssertionsBase, TaskAssertions> { + public TaskAssertions(Task task) : base(task) { } + + public Task>> HaveResultAsync(TResult result, int timeout = 10000, string because = "", params object[] reasonArgs) + => BeInTimeAsync(HaveResultAsyncContinuation, result, timeout, because: because, reasonArgs: reasonArgs); + + private AndConstraint> HaveResultAsyncContinuation(Task task, object state) { + var data = (TimeoutContinuationState)state; + var and = AssertStatus(TaskStatus.RanToCompletion, true, data.Because, data.ReasonArgs, + "Expected task to be completed in {0} milliseconds{reason}, but it is {1}.", data.Timeout, Subject.Status); + + Subject.Result.Should().Be(data.Argument); + return and; + } + } + + public abstract class TaskAssertionsBase : ReferenceTypeAssertions + where TTask : Task + where TAssertions : TaskAssertionsBase { + + protected TaskAssertionsBase(TTask task) { + Subject = task; + } + + protected override string Identifier { get; } = "System.Threading.Tasks.Task"; + + public AndConstraint BeCompleted(string because = "", params object[] reasonArgs) { + Subject.Should().NotBeNull(); + + Execute.Assertion.ForCondition(Subject.IsCompleted) + .BecauseOf(because, reasonArgs) + .FailWith($"Expected task to be completed{{reason}}, but it is {Subject.Status}."); + + return new AndConstraint((TAssertions)this); + } + + public AndConstraint NotBeCompleted(string because = "", params object[] reasonArgs) { + Subject.Should().NotBeNull(); + + Execute.Assertion.ForCondition(!Subject.IsCompleted) + .BecauseOf(because, reasonArgs) + .FailWith($"Expected task not to be completed{{reason}}, but {GetNotBeCompletedMessage()}."); + + return new AndConstraint((TAssertions)this); + } + + private string GetNotBeCompletedMessage() { + var exceptions = AsyncAssertions.GetExceptions(Subject); + switch (Subject.Status) { + case TaskStatus.RanToCompletion: + return "it has run to completion successfully"; + case TaskStatus.Canceled: + return $"it is canceled with exception of type {exceptions[0].GetType()}: {exceptions[0].Message}"; + case TaskStatus.Faulted: + return $@"it is faulted with the following exceptions: {string.Join(Environment.NewLine, exceptions.Select(e => $" {e.GetType()}: {e.Message}"))}"; + default: + return string.Empty; + } + } + + public Task> BeCompletedAsync(int timeout = 30000, string because = "", params object[] reasonArgs) + => BeInTimeAsync(BeCompletedAsyncContinuation, false, timeout, because: because, reasonArgs: reasonArgs); + + public Task> BeCanceledAsync(int timeout = 30000, string because = "", params object[] reasonArgs) + => BeInTimeAsync(BeCanceledAsyncContinuation, false, timeout, because: because, reasonArgs: reasonArgs); + + public Task> NotBeCompletedAsync(int timeout = 5000, string because = "", params object[] reasonArgs) + => BeInTimeAsync(NotBeCompletedAsyncContinuation, false, timeout, 5000, because, reasonArgs); + + protected Task> BeInTimeAsync(Func, object, AndConstraint> continuation, TArg argument, int timeout = 10000, int debuggerTimeout = 100000, string because = "", params object[] reasonArgs) { + Subject.Should().NotBeNull(); + if (Debugger.IsAttached) { + timeout = Math.Max(debuggerTimeout, timeout); + } + + var timeoutTask = Task.Delay(timeout); + var state = new TimeoutContinuationState(argument, timeout, because, reasonArgs); + return Task.WhenAny(timeoutTask, Subject) + .ContinueWith(continuation, state, default(CancellationToken), TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + } + + private AndConstraint BeCompletedAsyncContinuation(Task task, object state) { + var data = (TimeoutContinuationState)state; + return AssertStatus(TaskStatus.RanToCompletion, true, data.Because, data.ReasonArgs, + "Expected task to be completed in {0} milliseconds{reason}, but it is {1}.", data.Timeout, Subject.Status); + } + + private AndConstraint BeCanceledAsyncContinuation(Task task, object state) { + var data = (TimeoutContinuationState)state; + return AssertStatus(TaskStatus.Canceled, true, data.Because, data.ReasonArgs, + "Expected task to be canceled in {0} milliseconds{reason}, but it is {1}.", data.Timeout, Subject.Status); + } + + private AndConstraint NotBeCompletedAsyncContinuation(Task task, object state) { + var data = (TimeoutContinuationState)state; + Execute.Assertion.ForCondition(!Subject.IsCompleted) + .BecauseOf(data.Because, data.ReasonArgs) + .FailWith($"Expected task not to be completed in {data.Timeout} milliseconds{{reason}}, but {GetNotBeCompletedMessage()}."); + + return new AndConstraint((TAssertions)this); + } + + public AndConstraint BeRanToCompletion(string because = "", params object[] reasonArgs) + => AssertStatus(TaskStatus.RanToCompletion, true, because, reasonArgs, "Expected task to completed execution successfully{reason}, but it has status {0}.", Subject.Status); + + public AndConstraint BeCanceled(string because = "", params object[] reasonArgs) + => AssertStatus(TaskStatus.Canceled, true, because, reasonArgs, "Expected task to be canceled{reason}, but it has status {0}.", Subject.Status); + + public AndConstraint NotBeCanceled(string because = "", params object[] reasonArgs) + => AssertStatus(TaskStatus.Canceled, false, because, reasonArgs, "Expected task not to be canceled{reason}, but it has status {0}.", Subject.Status); + + public AndConstraint BeFaulted(string because = "", params object[] reasonArgs) + => AssertStatus(TaskStatus.Faulted, true, because, reasonArgs, "Expected task to be faulted{reason}, but it has status {0}.", Subject.Status); + + public AndConstraint BeFaulted(string because = "", params object[] reasonArgs) where TException : Exception + => AssertStatus(TaskStatus.Faulted, false, because, reasonArgs, "Expected task to be faulted with exception of type {0}{reason}, but it has status {1}.", typeof(TException), Subject.Status); + + public AndConstraint NotBeFaulted(string because = "", params object[] reasonArgs) + => AssertStatus(TaskStatus.Faulted, false, because, reasonArgs, "Expected task not to be faulted{reason}, but it has status {0}.", Subject.Status); + + protected AndConstraint AssertStatus(TaskStatus status, bool hasStatus, string because, object[] reasonArgs, string message, params object[] messageArgs) { + Subject.Should().NotBeNull(); + + Execute.Assertion.ForCondition(status == Subject.Status == hasStatus) + .BecauseOf(because, reasonArgs) + .FailWith(message, messageArgs); + + return new AndConstraint((TAssertions)this); + } + + protected class TimeoutContinuationState { + public TimeoutContinuationState(TArg argument, int timeout, string because, object[] reasonArgs) { + Argument = argument; + Because = because; + ReasonArgs = reasonArgs; + Timeout = timeout; + } + public TArg Argument { get; } + public int Timeout { get; } + public string Because { get; } + public object[] ReasonArgs { get; } + } + } +} diff --git a/src/Analysis/Ast/Impl/Extensions/ModuleResolutionExtensions.cs b/src/UnitTests/Core/Impl/FluentAssertions/TaskAssertionsExtensions.cs similarity index 59% rename from src/Analysis/Ast/Impl/Extensions/ModuleResolutionExtensions.cs rename to src/UnitTests/Core/Impl/FluentAssertions/TaskAssertionsExtensions.cs index b23aecd38..0e8855fb4 100644 --- a/src/Analysis/Ast/Impl/Extensions/ModuleResolutionExtensions.cs +++ b/src/UnitTests/Core/Impl/FluentAssertions/TaskAssertionsExtensions.cs @@ -1,5 +1,4 @@ -// Python Tools for Visual Studio -// Copyright(c) Microsoft Corporation +// Copyright(c) Microsoft Corporation // All rights reserved. // // Licensed under the Apache License, Version 2.0 (the License); you may not use @@ -14,14 +13,16 @@ // See the Apache Version 2.0 License for specific language governing // permissions and limitations under the License. -using Microsoft.Python.Analysis.Modules; -using Microsoft.Python.Analysis.Types; +using System.Threading.Tasks; -namespace Microsoft.Python.Analysis { - public static class ModuleResolutionExtensions { - public static IPythonModule ImportModule(this IModuleResolution m, string name, int timeout) { - m.ImportModuleAsync(name).Wait(timeout); - return m.GetImportedModule(name); +namespace Microsoft.Python.UnitTests.Core.FluentAssertions { + public static class TaskAssertionsExtensions { + public static TaskAssertions Should(this Task task) { + return new TaskAssertions(task); + } + + public static TaskAssertions Should(this Task task) { + return new TaskAssertions(task); } } } diff --git a/src/UnitTests/Core/Impl/TestEnvironmentImpl.cs b/src/UnitTests/Core/Impl/TestEnvironmentImpl.cs index db1d02cc0..be91f43e0 100644 --- a/src/UnitTests/Core/Impl/TestEnvironmentImpl.cs +++ b/src/UnitTests/Core/Impl/TestEnvironmentImpl.cs @@ -122,4 +122,4 @@ private AggregateException RunDisposablesSafe(Stack disposables) { return exceptions.Count > 0 ? new AggregateException(exceptions) : null; } } -} \ No newline at end of file +}