diff --git a/src/CommonLib/AdaptiveTimeout.cs b/src/CommonLib/AdaptiveTimeout.cs new file mode 100644 index 000000000..31aeb9033 --- /dev/null +++ b/src/CommonLib/AdaptiveTimeout.cs @@ -0,0 +1,226 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using SharpHoundRPC.NetAPINative; + +namespace SharpHoundCommonLib; + +public sealed class AdaptiveTimeout : IDisposable { + private readonly ExecutionTimeSampler _sampler; + private readonly ILogger _log; + private readonly TimeSpan _maxTimeout; + private readonly bool _useAdaptiveTimeout; + private readonly int _minSamplesForAdaptiveTimeout; + private int _clearSamplesDecay; + private const int TimeSpikePenalty = 2; + private const int TimeSpikeForgiveness = 1; + private const int ClearSamplesThreshold = 5; + private const int StdDevMultiplier = 5; + + public AdaptiveTimeout(TimeSpan maxTimeout, ILogger log, int sampleCount = 100, int logFrequency = 1000, int minSamplesForAdaptiveTimeout = 30, bool useAdaptiveTimeout = true) { + if (maxTimeout <= TimeSpan.Zero) + throw new ArgumentException("maxTimeout must be positive", nameof(maxTimeout)); + if (sampleCount <= 0) + throw new ArgumentException("sampleCount must be positive", nameof(sampleCount)); + if (logFrequency <= 0) + throw new ArgumentException("logFrequency must be positive", nameof(logFrequency)); + if (minSamplesForAdaptiveTimeout <= 0) + throw new ArgumentException("minSamplesForAdaptiveTimeout must be positive", nameof(minSamplesForAdaptiveTimeout)); + if (log == null) + throw new ArgumentNullException(nameof(log)); + + _sampler = new ExecutionTimeSampler(log, sampleCount, logFrequency); + _log = log; + _maxTimeout = maxTimeout; + _useAdaptiveTimeout = useAdaptiveTimeout; + _minSamplesForAdaptiveTimeout = minSamplesForAdaptiveTimeout; + } + + public void ClearSamples() { + _clearSamplesDecay = 0; + _sampler.ClearSamples(); + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task> ExecuteWithTimeout(Func func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task ExecuteWithTimeout(Action func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task> ExecuteWithTimeout(Func> func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task ExecuteWithTimeout(Func func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task> ExecuteNetAPIWithTimeout(Func> func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteNetAPIWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task> ExecuteRPCWithTimeout(Func> func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteRPCWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + /// + /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller. + /// Logs aggregate execution time data. + /// Manages its own timeout. + /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached. + /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better. + /// DO NOT use a single AdaptiveTimeout for multiple functions. + /// + /// + /// + /// + /// Returns a Fail result if a task runs longer than its budgeted time. + public async Task> ExecuteRPCWithTimeout(Func>> func, CancellationToken parentToken = default) { + var result = await Timeout.ExecuteRPCWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken); + TimeSpikeSafetyValve(result.IsSuccess); + return result; + } + + public void Dispose() { + _sampler.Dispose(); + } + + // Within 5 standard deviations will have a conservative lower bound of catching 96% of executions (1 - 1/5^2), + // regardless of sample shape + // so long as those samples are independent and identically distributed + // (and if they're not, our TimeSpikeSafetyValve should provide us with some adaptability) + // But the effective collection rate is probably closer to 98+% + // (in part because we don't need to filter out "too fast" outliers) + // But we'll cap at configured maximum timeout + // https://modelassist.epixanalytics.com/space/EA/26574957/Tchebysheffs+Rule + // https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables + public TimeSpan GetAdaptiveTimeout() { + if (!UseAdaptiveTimeout()) + return _maxTimeout; + + try { + var stdDev = _sampler.StandardDeviation(); + var adaptiveTimeoutMs = _sampler.Average() + (stdDev * StdDevMultiplier); + var cappedTimeoutMS = Math.Min(adaptiveTimeoutMs, _maxTimeout.TotalMilliseconds); + return TimeSpan.FromMilliseconds(cappedTimeoutMS); + } + catch (Exception ex) { + _log.LogError(ex, "Error calculating adaptive timeout, defaulting to max timeout."); + return _maxTimeout; + } + } + + // AdaptiveTimeout will not respond well to rapid spikes in execution time + // imagine the wrapped function very regularly executes in 10ms + // then suddenly starts taking a regular 100ms + // this is fine (if it fits in our max timeout budget), and we shouldn't block + // so we should create a safety valve in case this happens to reset our data samples + private void TimeSpikeSafetyValve(bool isSuccess) { + if (isSuccess) { + _clearSamplesDecay -= TimeSpikeForgiveness; + _clearSamplesDecay = Math.Max(0, _clearSamplesDecay); + } + else + _clearSamplesDecay += TimeSpikePenalty; + + + if (_clearSamplesDecay >= ClearSamplesThreshold) { + if (UseAdaptiveTimeout()) { + ClearSamples(); + _log.LogTrace("Time spike safety valve event at timeout {CurrentTimeout}.", GetAdaptiveTimeout()); + } + else { + _log.LogWarning("This call is frequently running over the maximum allowed timeout of {MaxTimeout}.", _maxTimeout); + } + } + } + + private bool UseAdaptiveTimeout() { + return _useAdaptiveTimeout && _sampler.Count >= _minSamplesForAdaptiveTimeout; + } +} \ No newline at end of file diff --git a/src/CommonLib/ExecutionTimeSampler.cs b/src/CommonLib/ExecutionTimeSampler.cs new file mode 100644 index 000000000..fc99f7e45 --- /dev/null +++ b/src/CommonLib/ExecutionTimeSampler.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace SharpHoundCommonLib; + +/// +/// Holds a rolling sample of execution times on a function, providing and logging data aggregates. +/// +public class ExecutionTimeSampler : IDisposable { + private readonly ILogger _log; + private readonly int _sampleCount; + private readonly int _logFrequency; + private int _samplesSinceLastLog; + private ConcurrentQueue _samples; + + public int Count => _samples.Count; + + public ExecutionTimeSampler(ILogger log, int sampleCount, int logFrequency) { + _log = log; + _sampleCount = sampleCount; + _logFrequency = logFrequency; + _samplesSinceLastLog = 0; + _samples = new ConcurrentQueue(); + } + + public void ClearSamples() { + Log(flush: true); + _samples = new ConcurrentQueue(); + } + + public double StandardDeviation() { + double average = _samples.Average(); + double sumOfSquaresOfDifferences = _samples.Select(val => (val - average) * (val - average)).Sum(); + double stddiv = Math.Sqrt(sumOfSquaresOfDifferences / _samples.Count); + + return stddiv; + } + + public double Average() => _samples.Average(); + + public async Task SampleExecutionTime(Func> func) { + var stopwatch = Stopwatch.StartNew(); + var result = await func.Invoke(); + stopwatch.Stop(); + AddTimeSample(stopwatch.Elapsed); + + return result; + } + + public async Task SampleExecutionTime(Func func) { + var stopwatch = Stopwatch.StartNew(); + await func.Invoke(); + stopwatch.Stop(); + AddTimeSample(stopwatch.Elapsed); + } + + public T SampleExecutionTime(Func func) { + var stopwatch = Stopwatch.StartNew(); + var result = func.Invoke(); + stopwatch.Stop(); + AddTimeSample(stopwatch.Elapsed); + + return result; + } + + public void SampleExecutionTime(Action func) { + var stopwatch = Stopwatch.StartNew(); + func.Invoke(); + stopwatch.Stop(); + AddTimeSample(stopwatch.Elapsed); + } + + public void Dispose() { + Log(flush: true); + } + + private void AddTimeSample(TimeSpan timeSpan) { + while (_samples.Count >= _sampleCount) { + _samples.TryDequeue(out _); + } + + _samples.Enqueue(timeSpan.TotalMilliseconds); + _samplesSinceLastLog++; + + Log(); + } + + private void Log(bool flush = false) { + if ((flush || _samplesSinceLastLog >= _logFrequency) && _samples.Count > 0) { + try { + _log.LogInformation("Execution time Average: {Average}ms, StdDiv: {StandardDeviation}ms", _samples.Average(), StandardDeviation()); + } + catch (Exception ex) { + _log.LogWarning("Failed to calculate execution time statistics: {Error}", ex.Message); + } + + _samplesSinceLastLog = 0; + } + } +} \ No newline at end of file diff --git a/src/CommonLib/IRegistryKey.cs b/src/CommonLib/IRegistryKey.cs index c61991082..5c1d0a9c7 100644 --- a/src/CommonLib/IRegistryKey.cs +++ b/src/CommonLib/IRegistryKey.cs @@ -10,6 +10,7 @@ public interface IRegistryKey { public class SHRegistryKey : IRegistryKey, IDisposable { private readonly RegistryKey _currentKey; + private static readonly AdaptiveTimeout _adaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromSeconds(10), Logging.LogProvider.CreateLogger(nameof(SHRegistryKey))); private SHRegistryKey(RegistryKey registryKey) { _currentKey = registryKey; @@ -35,7 +36,7 @@ public object GetValue(string subkey, string name) { /// /// public static async Task Connect(RegistryHive hive, string machineName) { - var remoteKey = await Timeout.ExecuteWithTimeout(TimeSpan.FromSeconds(10), (_) => RegistryKey.OpenRemoteBaseKey(hive, machineName)); + var remoteKey = await _adaptiveTimeout.ExecuteWithTimeout((_) => RegistryKey.OpenRemoteBaseKey(hive, machineName)); if (remoteKey.IsSuccess) return new SHRegistryKey(remoteKey.Value); else diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index 3862c68d0..84719438d 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -27,6 +27,10 @@ internal class LdapConnectionPool : IDisposable { private readonly ILogger _log; private readonly IPortScanner _portScanner; private readonly NativeMethods _nativeMethods; + private readonly AdaptiveTimeout _queryAdaptiveTimeout; + private readonly AdaptiveTimeout _pagedQueryAdaptiveTimeout; + private readonly AdaptiveTimeout _rangedRetrievalAdaptiveTimeout; + private readonly AdaptiveTimeout _testConnectionAdaptiveTimeout; private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2); private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20); private const int BackoffDelayMultiplier = 2; @@ -54,6 +58,10 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c _log = log ?? Logging.LogProvider.CreateLogger("LdapConnectionPool"); _portScanner = scanner ?? new PortScanner(); _nativeMethods = nativeMethods ?? new NativeMethods(); + _queryAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapQuery")); + _pagedQueryAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapPagedQuery")); + _rangedRetrievalAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapRangedRetrieval")); + _testConnectionAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("TestLdapConnection")); } private async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection( @@ -100,7 +108,7 @@ public async IAsyncEnumerable> Query(LdapQueryParam try { _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); - response = (SearchResponse)await connectionWrapper.Connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2)); + response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _queryAdaptiveTimeout); if (response != null) { querySuccess = true; @@ -165,6 +173,16 @@ public async IAsyncEnumerable> Query(LdapQueryParam var backoffDelay = GetNextBackoff(busyRetryCount); await Task.Delay(backoffDelay, cancellationToken); } + catch (TimeoutException) when (busyRetryCount < MaxRetries) { + /* + * Treat a timeout as a busy error + */ + busyRetryCount++; + _log.LogDebug("Query - Timeout: Executing busy backoff for query {Info} (Attempt {Count})", + queryParameters.GetQueryInfo(), busyRetryCount); + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } catch (LdapException le) { /* * This is our fallback catch. If our retry counts have been exhausted this will trigger and break us out of our loop @@ -255,7 +273,7 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery SearchResponse response = null; try { _log.LogTrace("Sending paged ldap request - {Info}", queryParameters.GetQueryInfo()); - response = (SearchResponse)await connectionWrapper.Connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2)); + response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _pagedQueryAdaptiveTimeout); if (response != null) { pageResponse = (PageResultResponseControl)response.Controls .Where(x => x is PageResultResponseControl).DefaultIfEmpty(null).FirstOrDefault(); @@ -326,6 +344,16 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery var backoffDelay = GetNextBackoff(busyRetryCount); await Task.Delay(backoffDelay, cancellationToken); } + catch (TimeoutException) when (busyRetryCount < MaxRetries) { + /* + * Treat a timeout as a busy error + */ + busyRetryCount++; + _log.LogDebug("PagedQuery - Timeout: Executing busy backoff for query {Info} (Attempt {Count})", + queryParameters.GetQueryInfo(), busyRetryCount); + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } catch (LdapException le) { tempResult = LdapResult.Fail( $"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", @@ -468,7 +496,7 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish } try { - response = (SearchResponse)await connectionWrapper.Connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2)); + response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _rangedRetrievalAdaptiveTimeout); } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { busyRetryCount++; @@ -477,6 +505,16 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish var backoffDelay = GetNextBackoff(busyRetryCount); await Task.Delay(backoffDelay, cancellationToken); } + catch (TimeoutException) when (busyRetryCount < MaxRetries) { + /* + * Treat a timeout as a busy error + */ + busyRetryCount++; + _log.LogDebug("RangedRetrieval - Timeout: Executing busy backoff for query {Info} (Attempt {Count})", + queryParameters.GetQueryInfo(), busyRetryCount); + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && queryRetryCount < MaxRetries) { queryRetryCount++; @@ -741,7 +779,7 @@ public void Dispose() { } string tempDomainName; - + // Blocking External Call var dsGetDcNameResult = _nativeMethods.CallDsGetDcName(null, _identifier, (uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY | @@ -941,7 +979,7 @@ private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl var searchRequest = CreateSearchRequest("", new LdapFilter().AddAllObjects().GetFilter(), SearchScope.Base, null); - response = (SearchResponse)await connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2)); + response = await SendRequestWithTimeout(connection, searchRequest, _testConnectionAdaptiveTimeout); } catch (LdapException e) { /* @@ -998,5 +1036,17 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope)); return searchRequest; } + + private async Task SendRequestWithTimeout(LdapConnection connection, SearchRequest request, AdaptiveTimeout adaptiveTimeout) { + // Add padding to account for network latency and processing overhead + const int TimeoutPaddingSeconds = 3; + var timeout = adaptiveTimeout.GetAdaptiveTimeout(); + var timeoutWithPadding = timeout + TimeSpan.FromSeconds(TimeoutPaddingSeconds); + var result = await adaptiveTimeout.ExecuteWithTimeout((_) => connection.SendRequestAsync(request, timeoutWithPadding)); + if (result.IsSuccess) + return (SearchResponse)result.Value; + else + throw new TimeoutException($"LDAP {request.Scope} query to '{request.DistinguishedName}' timed out after {timeout.TotalMilliseconds}ms."); + } } } \ No newline at end of file diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index 6d847c31d..14612da12 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -7,7 +7,6 @@ using System.Linq; using System.Net; using System.Net.Sockets; -using System.Security.Cryptography; using System.Security.Principal; using System.Text; using System.Text.RegularExpressions; @@ -38,6 +37,10 @@ public class LdapUtils : ILdapUtils { private static readonly ConcurrentDictionary SeenWellKnownPrincipals = new(); + private static readonly AdaptiveTimeout _requestNetBiosNameAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(1), Logging.LogProvider.CreateLogger(nameof(RequestNETBIOSNameFromComputerAsync))); + + private static readonly AdaptiveTimeout _callNetWkstaGetInfoAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(NativeMethods.CallNetWkstaGetInfo))); + private readonly ConcurrentDictionary _hostResolutionMap = new(StringComparer.OrdinalIgnoreCase); @@ -713,7 +716,7 @@ public bool GetDomain(out Domain domain) { } // Blocking External Call - var result = await Timeout.ExecuteNetAPIWithTimeout(TimeSpan.FromMinutes(2), (_) => _nativeMethods.CallNetWkstaGetInfo(hostname)); + var result = await _callNetWkstaGetInfoAdaptiveTimeout.ExecuteNetAPIWithTimeout((_) => _nativeMethods.CallNetWkstaGetInfo(hostname)); if (result.IsSuccess) return (true, result.Value); @@ -781,7 +784,7 @@ public bool GetDomain(out Domain domain) { } private static async Task<(bool Success, string NetBiosName)> RequestNETBIOSNameFromComputerWithTimeout(string server, string domain) { - var result = await Timeout.ExecuteWithTimeout(TimeSpan.FromMinutes(1), async (timeoutToken) => await RequestNETBIOSNameFromComputerAsync(server, domain, timeoutToken)); + var result = await _requestNetBiosNameAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => await RequestNETBIOSNameFromComputerAsync(server, domain, timeoutToken)); if (result.IsSuccess) return (result.Value.Success, result.Value.NetBiosName); else diff --git a/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs b/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs index e765a11b0..0138a632b 100644 --- a/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs +++ b/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs @@ -16,35 +16,37 @@ namespace SharpHoundCommonLib.Ntlm; public class HttpNtlmAuthenticationService { private readonly ILogger _logger; private readonly IHttpClientFactory _httpClientFactory; + private readonly AdaptiveTimeout _getSupportedNTLMAuthSchemesAdaptiveTimeout; + private readonly AdaptiveTimeout _ntlmAuthAdaptiveTimeout; + private readonly AdaptiveTimeout _authWithChannelBindingAdaptiveTimeout; public HttpNtlmAuthenticationService(IHttpClientFactory httpClientFactory, ILogger logger = null) { _logger = logger ?? Logging.LogProvider.CreateLogger(nameof(HttpNtlmAuthenticationService)); _httpClientFactory = httpClientFactory; + _getSupportedNTLMAuthSchemesAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(GetSupportedNtlmAuthSchemesAsync))); + _ntlmAuthAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(NtlmAuthenticationHandler.PerformNtlmAuthenticationAsync))); + _authWithChannelBindingAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(AuthWithBadChannelBindingsAsync))); } - public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings, TimeSpan timeout = default) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - + public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings) { if (url == null) throw new ArgumentException("Url property is null"); if (useBadChannelBindings == null && url.Scheme == "https") throw new ArgumentException("When using HTTPS, useBadChannelBindings must be set"); - var supportedAuthSchemes = await GetSupportedNtlmAuthSchemesAsync(url, timeout); + var supportedAuthSchemes = await GetSupportedNtlmAuthSchemesAsync(url); _logger.LogDebug($"Supported NTLM auth schemes for {url}: " + string.Join(",", supportedAuthSchemes)); foreach (var authScheme in supportedAuthSchemes) { if (useBadChannelBindings == null) { - await AuthWithBadChannelBindingsAsync(url, authScheme, timeout); + await AuthWithBadChannelBindingsAsync(url, authScheme); } else { if ((bool)useBadChannelBindings) { - await AuthWithBadChannelBindingsAsync(url, authScheme, timeout); + await AuthWithBadChannelBindingsAsync(url, authScheme); } else { - await AuthWithChannelBindingAsync(url, authScheme, timeout); + await AuthWithChannelBindingAsync(url, authScheme); } } @@ -53,11 +55,11 @@ public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings, TimeS } } - private async Task GetSupportedNtlmAuthSchemesAsync(Uri url, TimeSpan timeout) { + private async Task GetSupportedNtlmAuthSchemesAsync(Uri url) { var httpClient = _httpClientFactory.CreateUnauthenticatedClient(); using var getRequest = new HttpRequestMessage(HttpMethod.Get, url); - var result = await Timeout.ExecuteWithTimeout(timeout, async (timeoutToken) => { + var result = await _getSupportedNTLMAuthSchemesAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => { var getResponse = await httpClient.SendAsync(getRequest, timeoutToken); return ExtractAuthSchemes(getResponse); }); @@ -101,12 +103,12 @@ internal string[] ExtractAuthSchemes(HttpResponseMessage response) { return schemes; } - private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, TimeSpan timeout, NtlmAuthenticationHandler ntlmAuth = null) { + private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, NtlmAuthenticationHandler ntlmAuth = null) { var httpClient = _httpClientFactory.CreateUnauthenticatedClient(); var transport = new HttpTransport(httpClient, url, authScheme, _logger); var ntlmAuthHandler = ntlmAuth ?? new NtlmAuthenticationHandler($"HTTP/{url.Host}"); - var result = await Timeout.ExecuteWithTimeout(timeout, (timeoutToken) => ntlmAuthHandler.PerformNtlmAuthenticationAsync(transport, timeoutToken)); + var result = await _ntlmAuthAdaptiveTimeout.ExecuteWithTimeout((timeoutToken) => ntlmAuthHandler.PerformNtlmAuthenticationAsync(transport, timeoutToken)); if (!result.IsSuccess) { throw new TimeoutException($"Timeout during NTLM authentication for {url} with {authScheme}"); @@ -139,7 +141,7 @@ private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, T response.EnsureSuccessStatusCode(); } - private async Task AuthWithChannelBindingAsync(Uri url, string authScheme, TimeSpan timeout) { + private async Task AuthWithChannelBindingAsync(Uri url, string authScheme) { var handler = new HttpClientHandler { ServerCertificateCustomValidationCallback = (httpRequestMessage, cert, cetChain, policyErrors) => true, }; @@ -153,7 +155,7 @@ private async Task AuthWithChannelBindingAsync(Uri url, string authScheme, using var client = new HttpClient(handler); - var result = await Timeout.ExecuteWithTimeout(timeout, async (timeoutToken) => { + var result = await _authWithChannelBindingAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => { try { HttpResponseMessage response = await client.GetAsync(url, timeoutToken); return response.StatusCode == HttpStatusCode.OK; diff --git a/src/CommonLib/Processors/CertAbuseProcessor.cs b/src/CommonLib/Processors/CertAbuseProcessor.cs index b0c0ca088..4da05f191 100644 --- a/src/CommonLib/Processors/CertAbuseProcessor.cs +++ b/src/CommonLib/Processors/CertAbuseProcessor.cs @@ -18,14 +18,17 @@ public class CertAbuseProcessor { private readonly ILogger _log; private readonly ILdapUtils _utils; + private readonly AdaptiveTimeout _getMachineSidAdaptiveTimeout; + private readonly AdaptiveTimeout _openSamServerAdaptiveTimeout; public delegate Task ComputerStatusDelegate(CSVComputerStatus status); public event ComputerStatusDelegate ComputerStatusEvent; - - public CertAbuseProcessor(ILdapUtils utils, ILogger log = null) - { + + public CertAbuseProcessor(ILdapUtils utils, ILogger log = null) { _utils = utils; _log = log ?? Logging.LogProvider.CreateLogger("CAProc"); + _getMachineSidAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.GetMachineSid))); + _openSamServerAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(SAMServer.OpenServer))); } /// @@ -365,7 +368,7 @@ await SendComputerStatus(new CSVComputerStatus } var server = openServerResult.Value; - var getMachineSidResult = await Timeout.ExecuteRPCWithTimeout(TimeSpan.FromMinutes(2), (timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken)); + var getMachineSidResult = await _getMachineSidAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken)); if (getMachineSidResult.IsFailed) { _log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError); @@ -445,7 +448,7 @@ await SendComputerStatus(new CSVComputerStatus public virtual SharpHoundRPC.Result OpenSamServer(string computerName) { - var result = Timeout.ExecuteRPCWithTimeout(TimeSpan.FromMinutes(2), (_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult(); + var result = _openSamServerAdaptiveTimeout.ExecuteRPCWithTimeout((_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult(); if (result.IsFailed) { return SharpHoundRPC.Result.Fail(result.SError); diff --git a/src/CommonLib/Processors/ComputerAvailability.cs b/src/CommonLib/Processors/ComputerAvailability.cs index 6f2c9918c..489ebde1a 100644 --- a/src/CommonLib/Processors/ComputerAvailability.cs +++ b/src/CommonLib/Processors/ComputerAvailability.cs @@ -11,14 +11,12 @@ public class ComputerAvailability { private readonly int _computerExpiryDays; private readonly ILogger _log; private readonly IPortScanner _scanner; - private readonly int _scanTimeout; private readonly bool _skipPasswordCheck; private readonly bool _skipPortScan; public ComputerAvailability(int timeout = 10000, int computerExpiryDays = 60, bool skipPortScan = false, bool skipPasswordCheck = false, ILogger log = null) { - _scanner = new PortScanner(); - _scanTimeout = timeout; + _scanner = new PortScanner(maxTimeout: timeout); _skipPortScan = skipPortScan; _log = log ?? Logging.LogProvider.CreateLogger("CompAvail"); _computerExpiryDays = computerExpiryDays; @@ -28,8 +26,7 @@ public ComputerAvailability(int timeout = 10000, int computerExpiryDays = 60, bo public ComputerAvailability(IPortScanner scanner, int timeout = 500, int computerExpiryDays = 60, bool skipPortScan = false, bool skipPasswordCheck = false, ILogger log = null) { - _scanner = scanner ?? new PortScanner(); - _scanTimeout = timeout; + _scanner = scanner ?? new PortScanner(maxTimeout: timeout); _skipPortScan = skipPortScan; _log = log ?? Logging.LogProvider.CreateLogger("CompAvail"); _computerExpiryDays = computerExpiryDays; @@ -101,7 +98,7 @@ await SendComputerStatus(new CSVComputerStatus { Error = null }; - if (!await _scanner.CheckPort(computerName, timeout: _scanTimeout)) { + if (!await _scanner.CheckPort(computerName)) { _log.LogTrace("{ComputerName} is not available because port 445 is unavailable", computerName); await SendComputerStatus(new CSVComputerStatus { Status = ComputerStatus.PortNotOpen, diff --git a/src/CommonLib/Processors/ComputerSessionProcessor.cs b/src/CommonLib/Processors/ComputerSessionProcessor.cs index f66eaf5b9..2486b3d0e 100644 --- a/src/CommonLib/Processors/ComputerSessionProcessor.cs +++ b/src/CommonLib/Processors/ComputerSessionProcessor.cs @@ -22,6 +22,8 @@ public class ComputerSessionProcessor { private readonly bool _doLocalAdminSessionEnum; private readonly string _localAdminUsername; private readonly string _localAdminPassword; + private readonly AdaptiveTimeout _readUserSessionsAdaptiveTimeout; + private readonly AdaptiveTimeout _readUserSessionsPriviledgedAdaptiveTimeout; public ComputerSessionProcessor(ILdapUtils utils, NativeMethods nativeMethods = null, ILogger log = null, string currentUserName = null, @@ -34,6 +36,8 @@ public ComputerSessionProcessor(ILdapUtils utils, _doLocalAdminSessionEnum = doLocalAdminSessionEnum; _localAdminUsername = localAdminUsername; _localAdminPassword = localAdminPassword; + _readUserSessionsAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ReadUserSessions))); + _readUserSessionsPriviledgedAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ReadUserSessionsPrivileged))); } public event ComputerStatusDelegate ComputerStatusEvent; @@ -48,16 +52,12 @@ public ComputerSessionProcessor(ILdapUtils utils, /// /// public async Task ReadUserSessions(string computerName, string computerSid, - string computerDomain, TimeSpan timeout = default) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - + string computerDomain) { var ret = new SessionAPIResult(); _log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName); - var result = await Timeout.ExecuteNetAPIWithTimeout(timeout, (timeoutToken) => { + var result = await _readUserSessionsAdaptiveTimeout.ExecuteNetAPIWithTimeout((timeoutToken) => { NetAPIResult> result; if (_doLocalAdminSessionEnum) { // If we are authenticating using a local admin, we need to impersonate for this @@ -190,15 +190,11 @@ await SendComputerStatus(new CSVComputerStatus { /// /// public async Task ReadUserSessionsPrivileged(string computerName, - string computerSamAccountName, string computerSid, TimeSpan timeout = default) { + string computerSamAccountName, string computerSid) { var ret = new SessionAPIResult(); - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - _log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName); - var result = await Timeout.ExecuteNetAPIWithTimeout(timeout, (timeoutToken) => { + var result = await _readUserSessionsPriviledgedAdaptiveTimeout.ExecuteNetAPIWithTimeout((timeoutToken) => { NetAPIResult> result; if (_doLocalAdminSessionEnum) { diff --git a/src/CommonLib/Processors/DCLdapProcessor.cs b/src/CommonLib/Processors/DCLdapProcessor.cs index 1272ad013..1fe544355 100644 --- a/src/CommonLib/Processors/DCLdapProcessor.cs +++ b/src/CommonLib/Processors/DCLdapProcessor.cs @@ -22,43 +22,41 @@ public class LdapAuthOptions { public class DCLdapProcessor { private readonly ILogger _log; private readonly IPortScanner _scanner; - private readonly int _portScanTimeout; private readonly int _ldapTimeout; private readonly Uri _ldapEndpoint; private readonly Uri _ldapSslEndpoint; + private readonly AdaptiveTimeout _checkIsNtlmSigningRequiredAdaptiveTimeout; + private readonly AdaptiveTimeout _checkIsChannelBindingDisabledAdaptiveTimeout; public delegate Task ComputerStatusDelegate(CSVComputerStatus status); private readonly string SEC_E_UNSUPPORTED_FUNCTION = "80090302"; private readonly string SEC_E_BAD_BINDINGS = "80090346"; - public DCLdapProcessor(int portScanTimeout, string dcHostname, ILogger log = null) { + public DCLdapProcessor(int connectionTimeoutMs, string dcHostname, ILogger log = null) { _log = log ?? Logging.LogProvider.CreateLogger("DCLdapProcessor"); - _scanner = new PortScanner(); - _portScanTimeout = portScanTimeout; - _ldapTimeout = portScanTimeout / 1000; + _scanner = new PortScanner(maxTimeout: connectionTimeoutMs); + _ldapTimeout = connectionTimeoutMs / 1000; _ldapEndpoint = new Uri($"ldap://{dcHostname}:389"); _ldapSslEndpoint = new Uri($"ldaps://{dcHostname}:636"); + _checkIsNtlmSigningRequiredAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(1), Logging.LogProvider.CreateLogger(nameof(CheckIsNtlmSigningRequired))); + _checkIsChannelBindingDisabledAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(1), Logging.LogProvider.CreateLogger(nameof(CheckIsChannelBindingDisabled))); } public event ComputerStatusDelegate ComputerStatusEvent; - public async Task Scan(string computerName, TimeSpan timeout = default) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - + public async Task Scan(string computerName) { var hasLdap = await TestLdapPort(); var hasLdaps = await TestLdapsPort(); SharpHoundRPC.Result isSigningRequired = new(), isChannelBindingDisabled = new(); if (hasLdap) { - isSigningRequired = await Timeout.ExecuteRPCWithTimeout(timeout, CheckIsNtlmSigningRequired); + isSigningRequired = await _checkIsNtlmSigningRequiredAdaptiveTimeout.ExecuteRPCWithTimeout(CheckIsNtlmSigningRequired); } if (hasLdaps) { - isChannelBindingDisabled = await Timeout.ExecuteRPCWithTimeout(timeout, CheckIsChannelBindingDisabled); + isChannelBindingDisabled = await _checkIsChannelBindingDisabledAdaptiveTimeout.ExecuteRPCWithTimeout(CheckIsChannelBindingDisabled); } if (isSigningRequired.IsFailed) { @@ -117,12 +115,12 @@ await SendComputerStatus(new CSVComputerStatus { /// bool [ExcludeFromCodeCoverage] public virtual async Task TestLdapPort() { - return await _scanner.CheckPort(_ldapEndpoint.Host, _ldapEndpoint.Port, _portScanTimeout); + return await _scanner.CheckPort(_ldapEndpoint.Host, _ldapEndpoint.Port); } [ExcludeFromCodeCoverage] public virtual async Task TestLdapsPort() { - return await _scanner.CheckPort(_ldapSslEndpoint.Host, _ldapSslEndpoint.Port, _portScanTimeout); + return await _scanner.CheckPort(_ldapSslEndpoint.Host, _ldapSslEndpoint.Port); } public async Task> CheckIsNtlmSigningRequired(CancellationToken cancellationToken = default) { diff --git a/src/CommonLib/Processors/LocalGroupProcessor.cs b/src/CommonLib/Processors/LocalGroupProcessor.cs index f68ff781f..984ed53a4 100644 --- a/src/CommonLib/Processors/LocalGroupProcessor.cs +++ b/src/CommonLib/Processors/LocalGroupProcessor.cs @@ -16,22 +16,33 @@ public class LocalGroupProcessor public delegate Task ComputerStatusDelegate(CSVComputerStatus status); private readonly ILogger _log; private readonly ILdapUtils _utils; - - public LocalGroupProcessor(ILdapUtils utils, ILogger log = null) - { + private readonly AdaptiveTimeout _getMachineSidAdaptiveTimeout; + private readonly AdaptiveTimeout _openSamServerAdaptiveTimeout; + private readonly AdaptiveTimeout _getDomainsAdaptiveTimeout; + private readonly AdaptiveTimeout _openDomainAdaptiveTimeout; + private readonly AdaptiveTimeout _getAliasesAdaptiveTimeout; + private readonly AdaptiveTimeout _openAliasAdaptiveTimeout; + private readonly AdaptiveTimeout _getMembersAdaptiveTimeout; + private readonly AdaptiveTimeout _lookupPrincipalBySidAdaptiveTimeout; + + public LocalGroupProcessor(ILdapUtils utils, ILogger log = null) { _utils = utils; _log = log ?? Logging.LogProvider.CreateLogger("LocalGroupProcessor"); + _getMachineSidAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.GetMachineSid))); + _openSamServerAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(SAMServer.OpenServer))); + _getDomainsAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.GetDomains))); + _openDomainAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.OpenDomain))); + _getAliasesAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMDomain.GetAliases))); + _openAliasAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMDomain.OpenAlias))); + _getMembersAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMAlias.GetMembers))); + _lookupPrincipalBySidAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.LookupPrincipalBySid))); } public event ComputerStatusDelegate ComputerStatusEvent; - public virtual SharpHoundRPC.Result OpenSamServer(string computerName, TimeSpan timeout = default) + public virtual SharpHoundRPC.Result OpenSamServer(string computerName) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - - var result = Timeout.ExecuteRPCWithTimeout(timeout, (_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult(); + var result = _openSamServerAdaptiveTimeout.ExecuteRPCWithTimeout((_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult(); if (result.IsFailed) { return SharpHoundRPC.Result.Fail(result.SError); @@ -55,14 +66,10 @@ public IAsyncEnumerable GetLocalGroups(ResolvedSearchResult /// /// public async IAsyncEnumerable GetLocalGroups(string computerName, string computerObjectId, - string computerDomain, bool isDomainController, TimeSpan timeout = default) + string computerDomain, bool isDomainController) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - //Open a handle to the server - var openServerResult = OpenSamServer(computerName, timeout); + var openServerResult = OpenSamServer(computerName); if (openServerResult.IsFailed) { _log.LogTrace("OpenServer failed on {ComputerName}: {Error}", computerName, openServerResult.SError); @@ -81,7 +88,7 @@ await SendComputerStatus(new CSVComputerStatus //Try to get the machine sid for the computer if its not already cached SecurityIdentifier machineSid; if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid)) { - var getMachineSidResult = await Timeout.ExecuteRPCWithTimeout(timeout, (timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken)); + var getMachineSidResult = await _getMachineSidAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken)); if (getMachineSidResult.IsFailed) { _log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError); @@ -105,7 +112,7 @@ await SendComputerStatus(new CSVComputerStatus } //Get all available domains in the server - var getDomainsResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => server.GetDomains()); + var getDomainsResult = await _getDomainsAdaptiveTimeout.ExecuteRPCWithTimeout((_) => server.GetDomains()); if (getDomainsResult.IsFailed) { _log.LogTrace("GetDomains failed on {ComputerName}: {Error}", computerName, getDomainsResult.SError); @@ -126,7 +133,7 @@ await SendComputerStatus(new CSVComputerStatus continue; //Open a handle to the domain - var openDomainResult = await Timeout.ExecuteRPCWithTimeout(timeout, (timeoutToken) => server.OpenDomain(domainResult.Name, cancellationToken: timeoutToken)); + var openDomainResult = await _openDomainAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => server.OpenDomain(domainResult.Name, cancellationToken: timeoutToken)); if (openDomainResult.IsFailed) { _log.LogTrace("Failed to open domain {Domain} on {ComputerName}: {Error}", domainResult.Name, computerName, openDomainResult.SError); @@ -145,7 +152,7 @@ await SendComputerStatus(new CSVComputerStatus var domain = openDomainResult.Value; //Open a handle to the available aliases - var getAliasesResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => domain.GetAliases()); + var getAliasesResult = await _getAliasesAdaptiveTimeout.ExecuteRPCWithTimeout((_) => domain.GetAliases()); if (getAliasesResult.IsFailed) { @@ -178,7 +185,7 @@ await SendComputerStatus(new CSVComputerStatus }; //Open a handle to the alias - var openAliasResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => domain.OpenAlias(alias.Rid)); + var openAliasResult = await _openAliasAdaptiveTimeout.ExecuteRPCWithTimeout((_) => domain.OpenAlias(alias.Rid)); if (openAliasResult.IsFailed) { _log.LogTrace("Failed to open alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error); @@ -199,7 +206,7 @@ await SendComputerStatus(new CSVComputerStatus var localGroup = openAliasResult.Value; //Call GetMembersInAlias to get raw group members - var getMembersResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => localGroup.GetMembers()); + var getMembersResult = await _getMembersAdaptiveTimeout.ExecuteRPCWithTimeout((_) => localGroup.GetMembers()); if (getMembersResult.IsFailed) { _log.LogTrace("Failed to get members in alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error); @@ -274,7 +281,7 @@ await SendComputerStatus(new CSVComputerStatus } //Attempt to lookup the principal in the server directly - var lookupUserResult = await Timeout.ExecuteRPCWithTimeout(timeout, timeoutToken => server.LookupPrincipalBySid(securityIdentifier, timeoutToken)); + var lookupUserResult = await _lookupPrincipalBySidAdaptiveTimeout.ExecuteRPCWithTimeout(timeoutToken => server.LookupPrincipalBySid(securityIdentifier, timeoutToken)); if (lookupUserResult.IsFailed) { _log.LogTrace("Unable to resolve local sid {SID}: {Error}", sidValue, lookupUserResult.SError); diff --git a/src/CommonLib/Processors/PortScanner.cs b/src/CommonLib/Processors/PortScanner.cs index 3a55cfb4a..b294b308e 100644 --- a/src/CommonLib/Processors/PortScanner.cs +++ b/src/CommonLib/Processors/PortScanner.cs @@ -9,13 +9,15 @@ namespace SharpHoundCommonLib.Processors { public class PortScanner : IPortScanner { private static readonly ConcurrentDictionary PortScanCache = new(); private readonly ILogger _log; + private readonly AdaptiveTimeout _adaptiveTimeout; - public PortScanner() { - _log = Logging.LogProvider.CreateLogger("PortScanner"); + public PortScanner() : this(null) { + } - public PortScanner(ILogger log = null) { + public PortScanner(ILogger log = null, int maxTimeout = 10000) { _log = log ?? Logging.LogProvider.CreateLogger("PortScanner"); + _adaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMilliseconds(maxTimeout), _log); } /// @@ -26,7 +28,7 @@ public PortScanner(ILogger log = null) { /// Timeout in milliseconds /// Used as a switch to control if this method should throw exceptions that occur. /// True if port is open, otherwise false - public virtual async Task CheckPort(string hostname, int port = 445, int timeout = 10000, + public virtual async Task CheckPort(string hostname, int port = 445, bool throwError = false) { var key = new PingCacheKey { Port = port, @@ -40,13 +42,11 @@ public virtual async Task CheckPort(string hostname, int port = 445, int t try { using var client = new TcpClient(); - // Blocking External Call - var ca = await Timeout.ExecuteWithTimeout(TimeSpan.FromMilliseconds(timeout), (_) => client.ConnectAsync(hostname, port)); + var ca = await _adaptiveTimeout.ExecuteWithTimeout((_) => client.ConnectAsync(hostname, port)); if (!ca.IsSuccess) { - _log.LogDebug("{HostName} did not respond to scan on port {Port} within {Timeout}ms", hostname, port, - timeout); + _log.LogDebug("{HostName} did not respond to scan on port {Port} within {Timeout}ms", hostname, port, _adaptiveTimeout.GetAdaptiveTimeout()); if (throwError) { - throw new TimeoutException("Timed Out"); + throw new TimeoutException(ca.Error); } PortScanCache.TryAdd(key, false); return false; diff --git a/src/CommonLib/Processors/SmbProcessor.cs b/src/CommonLib/Processors/SmbProcessor.cs index 43862113b..44c2d7116 100644 --- a/src/CommonLib/Processors/SmbProcessor.cs +++ b/src/CommonLib/Processors/SmbProcessor.cs @@ -16,22 +16,17 @@ public class SmbProcessor public delegate Task ComputerStatusDelegate(CSVComputerStatus status); private readonly ILogger _log; private readonly ISmbScanner _smbScanner; - private readonly int _timeoutMs; - - public SmbProcessor(int timeoutMs, ISmbScanner smbScanner = null, ILogger log = null) - { - _timeoutMs = timeoutMs; + private readonly AdaptiveTimeout _scanHostAdaptiveTimeout; + + public SmbProcessor(int timeoutMs, ISmbScanner smbScanner = null, ILogger log = null) { _log = log ?? Logging.LogProvider.CreateLogger("SmbProcessor"); - _smbScanner = smbScanner ?? new SmbScanner(_log) { TimeoutMs = _timeoutMs }; + _smbScanner = smbScanner ?? new SmbScanner(_log) { MaxTimeoutMs = timeoutMs }; + _scanHostAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMilliseconds(timeoutMs), Logging.LogProvider.CreateLogger(nameof(ISmbScanner.ScanHost))); } public event ComputerStatusDelegate ComputerStatusEvent; - public virtual async Task> Scan(string host, TimeSpan timeout = default) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - - var result = await Timeout.ExecuteRPCWithTimeout(timeout, (timeoutToken) => _smbScanner.ScanHost(host, 445, timeoutToken)); + public virtual async Task> Scan(string host) { + var result = await _scanHostAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => _smbScanner.ScanHost(host, 445, timeoutToken)); if (result.IsFailed) { await SendComputerStatus(new CSVComputerStatus { diff --git a/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs b/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs index e23def5b5..3a3d4f03d 100644 --- a/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs +++ b/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs @@ -5,7 +5,6 @@ using Microsoft.Extensions.Logging; using SharpHoundCommonLib.Enums; using SharpHoundCommonLib.OutputTypes; -using SharpHoundRPC; using SharpHoundRPC.Shared; using SharpHoundRPC.Wrappers; @@ -15,10 +14,16 @@ public class UserRightsAssignmentProcessor { private readonly ILogger _log; private readonly ILdapUtils _utils; + private readonly AdaptiveTimeout _openLSAPolicyAdaptiveTimeout; + private readonly AdaptiveTimeout _getLocalDomainInfoAdaptiveTimeout; + private readonly AdaptiveTimeout _getResolvedPrincipalWithPriviledgeAdaptiveTimeout; public UserRightsAssignmentProcessor(ILdapUtils utils, ILogger log = null) { _utils = utils; _log = log ?? Logging.LogProvider.CreateLogger("UserRightsAssignmentProcessor"); + _openLSAPolicyAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(OpenLSAPolicy))); + _getLocalDomainInfoAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ILSAPolicy.GetLocalDomainInformation))); + _getResolvedPrincipalWithPriviledgeAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ILSAPolicy.GetResolvedPrincipalsWithPrivilege))); } public event ComputerStatusDelegate ComputerStatusEvent; @@ -47,13 +52,8 @@ public IAsyncEnumerable GetUserRightsAssignments( /// /// public async IAsyncEnumerable GetUserRightsAssignments(string computerName, - string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null, - TimeSpan timeout = default) { - if (timeout == default) { - timeout = TimeSpan.FromMinutes(2); - } - - var policyOpenResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => OpenLSAPolicy(computerName)); + string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null) { + var policyOpenResult = await _openLSAPolicyAdaptiveTimeout.ExecuteRPCWithTimeout((_) => OpenLSAPolicy(computerName)); if (!policyOpenResult.IsSuccess) { _log.LogDebug("LSAOpenPolicy failed on {ComputerName} with status {Status}", computerName, policyOpenResult.Error); @@ -71,7 +71,7 @@ await SendComputerStatus(new CSVComputerStatus { SecurityIdentifier machineSid; if (!Cache.GetMachineSid(computerObjectId, out var temp)) { var getMachineSidResult = - await Timeout.ExecuteRPCWithTimeout(timeout, (_) => server.GetLocalDomainInformation()); + await _getLocalDomainInfoAdaptiveTimeout.ExecuteRPCWithTimeout((_) => server.GetLocalDomainInformation()); if (getMachineSidResult.IsFailed) { _log.LogWarning("Failed to get machine sid for {Server}: {Status}. Abandoning URA collection", computerName, getMachineSidResult.SError); @@ -98,7 +98,7 @@ await SendComputerStatus(new CSVComputerStatus { }; //Ask for all principals with the specified privilege. - var enumerateAccountsResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => server.GetResolvedPrincipalsWithPrivilege(privilege)); + var enumerateAccountsResult = await _getResolvedPrincipalWithPriviledgeAdaptiveTimeout.ExecuteRPCWithTimeout((_) => server.GetResolvedPrincipalsWithPrivilege(privilege)); if (enumerateAccountsResult.IsFailed) { _log.LogDebug( "LSAEnumerateAccountsWithUserRight failed on {ComputerName} with status {Status} for privilege {Privilege}", diff --git a/src/CommonLib/SMB/SmbScanner.cs b/src/CommonLib/SMB/SmbScanner.cs index cf88813b0..02b694f84 100644 --- a/src/CommonLib/SMB/SmbScanner.cs +++ b/src/CommonLib/SMB/SmbScanner.cs @@ -27,12 +27,14 @@ public class SmbScanner : ISmbScanner { /// /// Timeout value used when connecting to hosts or waiting for a response. /// - public int TimeoutMs { get; set; } = 2000; + public int MaxTimeoutMs { get; set; } = 2000; + public readonly AdaptiveTimeout _adaptiveTimeout; public ILogger _log; public SmbScanner(ILogger log) { - _log = log ?? Logging.LogProvider.CreateLogger("SmbScanner"); ; + _log = log ?? Logging.LogProvider.CreateLogger("SmbScanner"); + _adaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMilliseconds(MaxTimeoutMs), Logging.LogProvider.CreateLogger(nameof(TrySMBNegotiate))); } @@ -125,7 +127,7 @@ private SharpHoundRPC.Result CheckRegistrySigningRequired(string ho negoReqBytes = negotiateRequest.ToBytes(); } - var negoResp = await Timeout.ExecuteWithTimeout(TimeSpan.FromMilliseconds(TimeoutMs), (timeoutToken) => SendAndReceiveData(host, port, negoReqBytes, timeoutToken)); + var negoResp = await _adaptiveTimeout.ExecuteWithTimeout((timeoutToken) => SendAndReceiveData(host, port, negoReqBytes, timeoutToken)); if (!negoResp.IsSuccess) throw new OperationCanceledException("Connection attempt timed out"); diff --git a/src/CommonLib/Timeout.cs b/src/CommonLib/Timeout.cs index 5373918c4..55da9180f 100644 --- a/src/CommonLib/Timeout.cs +++ b/src/CommonLib/Timeout.cs @@ -32,7 +32,7 @@ public static async Task> ExecuteWithTimeout(TimeSpan timeout, Func if (parentToken.IsCancellationRequested) return Result.Fail("Cancellation requested"); else - return Result.Fail("Timeout"); + return Result.Fail($"Timeout after {timeout.TotalMilliseconds} ms"); } /// @@ -61,7 +61,7 @@ public static async Task ExecuteWithTimeout(TimeSpan timeout, Action> ExecuteWithTimeout(TimeSpan timeout, Func if (parentToken.IsCancellationRequested) return Result.Fail("Cancellation requested"); else - return Result.Fail("Timeout"); + return Result.Fail($"Timeout after {timeout.TotalMilliseconds} ms"); } /// @@ -126,7 +126,7 @@ public static async Task ExecuteWithTimeout(TimeSpan timeout, Func @@ -137,8 +137,8 @@ public static async Task ExecuteWithTimeout(TimeSpan timeout, Func /// /// - public static async Task> ExecuteNetAPIWithTimeout(TimeSpan timeout, Func> func) { - var result = await ExecuteWithTimeout(timeout, func); + public static async Task> ExecuteNetAPIWithTimeout(TimeSpan timeout, Func> func, CancellationToken parentToken = default) { + var result = await ExecuteWithTimeout(timeout, func, parentToken); if (result.IsSuccess) return result.Value; else @@ -153,8 +153,8 @@ public static async Task> ExecuteNetAPIWithTimeout(TimeSpan t /// /// /// - public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func> func) { - var result = await ExecuteWithTimeout(timeout, func); + public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func> func, CancellationToken parentToken = default) { + var result = await ExecuteWithTimeout(timeout, func, parentToken); if (result.IsSuccess) return result.Value; else @@ -169,8 +169,8 @@ public static async Task> ExecuteNetAPIWithTimeout(TimeSpan t /// /// /// - public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func>> func) { - var result = await ExecuteWithTimeout(timeout, func); + public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func>> func, CancellationToken parentToken = default) { + var result = await ExecuteWithTimeout(timeout, func, parentToken); if (result.IsSuccess) return result.Value; else diff --git a/src/SharpHoundRPC/PortScanner/IPortScanner.cs b/src/SharpHoundRPC/PortScanner/IPortScanner.cs index 833108078..78624596d 100644 --- a/src/SharpHoundRPC/PortScanner/IPortScanner.cs +++ b/src/SharpHoundRPC/PortScanner/IPortScanner.cs @@ -2,6 +2,6 @@ namespace SharpHoundRPC.PortScanner { public interface IPortScanner { - Task CheckPort(string hostname, int port = 445, int timeout = 10000, bool throwError = false); + Task CheckPort(string hostname, int port = 445, bool throwError = false); } } \ No newline at end of file diff --git a/test/unit/AdaptiveTimeoutTest.cs b/test/unit/AdaptiveTimeoutTest.cs new file mode 100644 index 000000000..c86158daa --- /dev/null +++ b/test/unit/AdaptiveTimeoutTest.cs @@ -0,0 +1,69 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using SharpHoundCommonLib; +using Xunit; +using Xunit.Abstractions; + +namespace CommonLibTest; + +public class AdaptiveTimeoutTest { + private readonly ITestOutputHelper _testOutputHelper; + + public AdaptiveTimeoutTest(ITestOutputHelper testOutputHelper) { + _testOutputHelper = testOutputHelper; + } + + [Fact] + public async Task AdaptiveTimeout_GetAdaptiveTimeout_NotEnoughSamplesAsync() { + var maxTimeout = TimeSpan.FromSeconds(1); + var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3); + + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + + var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); + Assert.Equal(maxTimeout, adaptiveTimeoutResult); + } + + [Fact] + public async Task AdaptiveTimeout_GetAdaptiveTimeout_AdaptiveDisabled() { + var maxTimeout = TimeSpan.FromSeconds(1); + var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3, false); + + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + + var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); + Assert.Equal(maxTimeout, adaptiveTimeoutResult); + } + + [Fact] + public async Task AdaptiveTimeout_GetAdaptiveTimeout() { + var maxTimeout = TimeSpan.FromSeconds(1); + var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3); + + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(40)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(60)); + + var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); + Assert.True(adaptiveTimeoutResult < maxTimeout); + } + + [Fact] + public async Task AdaptiveTimeout_GetAdaptiveTimeout_TimeSpikeSafetyValve() { + var maxTimeout = TimeSpan.FromSeconds(1); + var numSamples = 100; + var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), numSamples, 1000, 10); + + for (int i = 0; i < numSamples; i++) + await adaptiveTimeout.ExecuteWithTimeout((_) => Thread.Sleep(10)); + + for (int i = 0; i < 6; i++) + await adaptiveTimeout.ExecuteWithTimeout((_) => Thread.Sleep(200)); + + var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); + Assert.Equal(maxTimeout, adaptiveTimeoutResult); + } +} \ No newline at end of file diff --git a/test/unit/ComputerAvailabilityTests.cs b/test/unit/ComputerAvailabilityTests.cs index 53127eb50..cd28f8849 100644 --- a/test/unit/ComputerAvailabilityTests.cs +++ b/test/unit/ComputerAvailabilityTests.cs @@ -18,12 +18,12 @@ public ComputerAvailabilityTests(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; var m = new Mock(); - m.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + m.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(Task.FromResult(false)); _falsePortScanner = m.Object; var m2 = new Mock(); - m2.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + m2.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(Task.FromResult(true)); _truePortScanner = m2.Object; } diff --git a/test/unit/ComputerSessionProcessorTest.cs b/test/unit/ComputerSessionProcessorTest.cs index af4d4ec97..9c772ff6b 100644 --- a/test/unit/ComputerSessionProcessorTest.cs +++ b/test/unit/ComputerSessionProcessorTest.cs @@ -216,45 +216,45 @@ public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_FilteringW Assert.Equal(expected, test.Results); } - [Fact] - public async Task ComputerSessionProcessor_TestTimeout() { - var nativeMethods = new Mock(); - nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(() => { - Task.Delay(1000).Wait(); - return Array.Empty(); - }); - var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,""); - var receivedStatus = new List(); - var machineDomainSid = $"{Consts.MockDomainSid}-1000"; - processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; }; - var results = await processor.ReadUserSessions("primary.testlab.local", machineDomainSid, "testlab.local", - TimeSpan.FromMilliseconds(1)); - Assert.Empty(results.Results); - Assert.Single(receivedStatus); - var status = receivedStatus[0]; - Assert.Equal("Timeout", status.Status); - } - - [Fact] - public async Task ComputerSessionProcessor_TestTimeoutPrivileged() { - var nativeMethods = new Mock(); - nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Returns(() => { - Task.Delay(1000).Wait(); - return Array.Empty(); - }); - var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,""); - var receivedStatus = new List(); - var machineDomainSid = $"{Consts.MockDomainSid}-1000"; - processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; }; - - var results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid, - "testlab.local", - TimeSpan.FromMilliseconds(1)); - Assert.Empty(results.Results); - Assert.Single(receivedStatus); - var status = receivedStatus[0]; - Assert.Equal("Timeout", status.Status); - } + // Obsolete by AdaptiveTimeout + // [Fact] + // public async Task ComputerSessionProcessor_TestTimeout() { + // var nativeMethods = new Mock(); + // nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(() => { + // Task.Delay(1000).Wait(); + // return Array.Empty(); + // }); + // var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,""); + // var receivedStatus = new List(); + // var machineDomainSid = $"{Consts.MockDomainSid}-1000"; + // processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; }; + // var results = await processor.ReadUserSessions("primary.testlab.local", machineDomainSid, "testlab.local"); + // Assert.Empty(results.Results); + // Assert.Single(receivedStatus); + // var status = receivedStatus[0]; + // Assert.Equal("Timeout", status.Status); + // } + + // Obsolete by AdaptiveTimeout + // [Fact] + // public async Task ComputerSessionProcessor_TestTimeoutPrivileged() { + // var nativeMethods = new Mock(); + // nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Returns(() => { + // Task.Delay(1000).Wait(); + // return Array.Empty(); + // }); + // var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,""); + // var receivedStatus = new List(); + // var machineDomainSid = $"{Consts.MockDomainSid}-1000"; + // processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; }; + + // var results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid, + // "testlab.local"); + // Assert.Empty(results.Results); + // Assert.Single(receivedStatus); + // var status = receivedStatus[0]; + // Assert.Equal("Timeout", status.Status); + // } [Fact] public async Task ComputerSessionProcessor_ReadUserSessionSendsComputerStatus() diff --git a/test/unit/DCLdapProcessorTest.cs b/test/unit/DCLdapProcessorTest.cs index 418d5ec95..e8ae0a653 100644 --- a/test/unit/DCLdapProcessorTest.cs +++ b/test/unit/DCLdapProcessorTest.cs @@ -30,7 +30,7 @@ public void Dispose() { [Fact] public async Task DCLdapProcessor_Scan() { - var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null); + var mockProcessor = new Mock(10, "primary.testlab.local", null); mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).ReturnsAsync(false); @@ -39,10 +39,11 @@ public async Task DCLdapProcessor_Scan() { var processor = mockProcessor.Object; var receivedStatus = new List(); - processor.ComputerStatusEvent += async status => { + processor.ComputerStatusEvent += status => { receivedStatus.Add(status); + return Task.CompletedTask; }; - var results = await processor.Scan("primary.testlab.local", TimeSpan.FromMinutes(2)); + var results = await processor.Scan("primary.testlab.local"); Assert.Equal(2, receivedStatus.Count); var status = receivedStatus[0]; @@ -57,7 +58,7 @@ public async Task DCLdapProcessor_Scan() { [Fact] public async Task DCLdapProcessor_Scan_Failed() { - var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null); + var mockProcessor = new Mock(10, "primary.testlab.local", null); mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Throws(new Exception("Error")); @@ -66,10 +67,11 @@ public async Task DCLdapProcessor_Scan_Failed() { var processor = mockProcessor.Object; var receivedStatus = new List(); - processor.ComputerStatusEvent += async status => { + processor.ComputerStatusEvent += status => { receivedStatus.Add(status); + return Task.CompletedTask; }; - var results = await processor.Scan("primary.testlab.local", TimeSpan.FromMinutes(2)); + var results = await processor.Scan("primary.testlab.local"); Assert.Equal(2, receivedStatus.Count); var status = receivedStatus[0]; @@ -84,38 +86,39 @@ public async Task DCLdapProcessor_Scan_Failed() { Assert.False(results.IsChannelBindingDisabled.Collected); } - [Fact] - public async Task DCLdapProcessor_CheckScan_Timeout() { - var mockProcessor = new Mock(2, "primary.testlab.local", null); - - mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Returns(async () => { - await Task.Delay(100); - return false; - }); - - mockProcessor.Setup(x => x.TestLdapPort()).ReturnsAsync(true); - mockProcessor.Setup(x => x.TestLdapsPort()).ReturnsAsync(true); - - var processor = mockProcessor.Object; - var receivedStatus = new List(); - processor.ComputerStatusEvent += status => { - receivedStatus.Add(status); - return Task.CompletedTask; - }; - var results = await processor.Scan("primary.testlab.local", TimeSpan.FromMilliseconds(1)); - - Assert.Equal(2, receivedStatus.Count); - var status = receivedStatus[0]; - Assert.Equal("Timeout", status.Status); - status = receivedStatus[1]; - Assert.Equal("Timeout", status.Status); - Assert.Equal("Timeout", results.IsSigningRequired.FailureReason); - Assert.Equal("Timeout", results.IsChannelBindingDisabled.FailureReason); - } + // Obsolete by AdaptiveTimeout + // [Fact] + // public async Task DCLdapProcessor_CheckScan_Timeout() { + // var mockProcessor = new Mock(2, "primary.testlab.local", null); + + // mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Returns(async () => { + // await Task.Delay(100); + // return false; + // }); + + // mockProcessor.Setup(x => x.TestLdapPort()).ReturnsAsync(true); + // mockProcessor.Setup(x => x.TestLdapsPort()).ReturnsAsync(true); + + // var processor = mockProcessor.Object; + // var receivedStatus = new List(); + // processor.ComputerStatusEvent += status => { + // receivedStatus.Add(status); + // return Task.CompletedTask; + // }; + // var results = await processor.Scan("primary.testlab.local"); + + // Assert.Equal(2, receivedStatus.Count); + // var status = receivedStatus[0]; + // Assert.Equal("Timeout", status.Status); + // status = receivedStatus[1]; + // Assert.Equal("Timeout", status.Status); + // Assert.Equal("Timeout", results.IsSigningRequired.FailureReason); + // Assert.Equal("Timeout", results.IsChannelBindingDisabled.FailureReason); + // } [Fact] public async Task DCLdapProcessor_CheckIsNtlmSigningRequired() { - var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null); + var mockProcessor = new Mock(10, "primary.testlab.local", null); mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).ReturnsAsync(false); var processor = mockProcessor.Object; var result = await processor.CheckIsNtlmSigningRequired(); @@ -125,7 +128,7 @@ public async Task DCLdapProcessor_CheckIsNtlmSigningRequired() { [Fact] public async Task DCLdapProcessor_CheckIsNtlmSigningRequired_Exception() { - var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null); + var mockProcessor = new Mock(10, "primary.testlab.local", null); mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Throws(new Exception("Error")); var processor = mockProcessor.Object; var result = await processor.CheckIsNtlmSigningRequired(); @@ -142,7 +145,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidCredentialsException_SEC_E var mockLogger = new Mock>(); var mockLdapTransport = new Mock(null, It.IsAny()); mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException("Error", (int)LdapErrorCodes.InvalidCredentials, SEC_E_UNSUPPORTED_FUNCTION)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLogContains(LogLevel.Debug, expected); @@ -157,7 +160,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidCredentialsException_SEC_E var mockLogger = new Mock>(); var mockLdapTransport = new Mock(null, It.IsAny()); mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException("Error", (int)LdapErrorCodes.InvalidCredentials, SEC_E_BAD_BINDINGS)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLogContains(LogLevel.Debug, expected); @@ -172,7 +175,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidCredentialsException_Unhan var mockLogger = new Mock>(); var mockLdapTransport = new Mock(null, It.IsAny()); mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.InvalidCredentials, "80090347")); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLogContains(LogLevel.Error, expected); @@ -187,7 +190,7 @@ public async Task DCLdapProcessor_Authenticate_StrongAuthRequiredException() { var mockLogger = new Mock>(); var mockLdapTransport = new Mock(null, It.IsAny()); mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.StrongAuthRequired, null)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLog(LogLevel.Debug, expected); @@ -202,7 +205,7 @@ public async Task DCLdapProcessor_Authenticate_ServerDownException() { var mockLogger = new Mock>(); var mockLdapTransport = new Mock(null, It.IsAny()); mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.ServerDown)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLog(LogLevel.Debug, expected); @@ -217,7 +220,7 @@ public async Task DCLdapProcessor_Authenticate_LdapUnhandledException() { var mockLogger = new Mock>(); var mockLdapTransport = new Mock(null, It.IsAny()); mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.LocalError)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLogContains(LogLevel.Error, expected); @@ -235,7 +238,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidOperationException() { mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Verifiable(); mockAuthenticator.Setup(x => x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny())) .Throws(new InvalidOperationException(exception)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), mockAuthenticator.Object, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLog(LogLevel.Debug, expected); @@ -253,7 +256,7 @@ public async Task DCLdapProcessor_Authenticate_UnhandledException() { mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Verifiable(); mockAuthenticator.Setup(x => x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny())) .Throws(new Exception(exception)); - var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object); + var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object); var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), mockAuthenticator.Object, mockLdapTransport.Object); Assert.False(result); mockLogger.VerifyLogContains(LogLevel.Error, expected); diff --git a/test/unit/HttpNtlmAuthenticationServiceTest.cs b/test/unit/HttpNtlmAuthenticationServiceTest.cs index e3273de9d..e78d654df 100644 --- a/test/unit/HttpNtlmAuthenticationServiceTest.cs +++ b/test/unit/HttpNtlmAuthenticationServiceTest.cs @@ -77,39 +77,41 @@ public void HttpNtlmAuthenticationService_ExtractAuthSchemes_Success() { Assert.Equal("Negotiate", result[1]); } - [Fact] - public void HttpNtlmAuthenticationService_EnsureRequiresAuth_GetSupportedNtlmAuthSchemesAsync_Timeout() { - var url = new Uri("http://primary.testlab.local/"); - var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null); - var ex = Assert.ThrowsAsync(() => - service.EnsureRequiresAuth(url, true, TimeSpan.FromMilliseconds(1))); - Assert.Equal($"Timeout getting supported NTLM auth schemes for {url}", ex.Result.Message); + // Obsolete by AdaptiveTimeout + // [Fact] + // public void HttpNtlmAuthenticationService_EnsureRequiresAuth_GetSupportedNtlmAuthSchemesAsync_Timeout() { + // var url = new Uri("http://primary.testlab.local/"); + // var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null); + // var ex = Assert.ThrowsAsync(() => + // service.EnsureRequiresAuth(url, true)); + // Assert.Equal($"Timeout getting supported NTLM auth schemes for {url}", ex.Result.Message); - } + // } - [Fact] - public void HttpNtlmAuthenticationService_AuthWithBadChannelBindingsAsync_Timeout() { - var url = new Uri("http://primary.testlab.local/"); - var authScheme = "NTLM"; - var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null); - var httpResponseMessage = new HttpResponseMessage { - StatusCode = HttpStatusCode.InternalServerError, - }; - var mockAuthenticator = new Mock(It.IsAny(), null); - mockAuthenticator.Setup(x => - x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny())).Returns(async () => { - await Task.Delay(1000); - return httpResponseMessage; - }); - - var ex = Assert.ThrowsAsync(async () => await TestPrivateMethod.InstanceMethod(service, - "AuthWithBadChannelBindingsAsync", - [ - url, authScheme, TimeSpan.FromMilliseconds(1), mockAuthenticator.Object - ])); - Assert.Equal($"Timeout during NTLM authentication for {url} with {authScheme}", ex.Result.Message); + // Obsolete by AdaptiveTimeout + // [Fact] + // public void HttpNtlmAuthenticationService_AuthWithBadChannelBindingsAsync_Timeout() { + // var url = new Uri("http://primary.testlab.local/"); + // var authScheme = "NTLM"; + // var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null); + // var httpResponseMessage = new HttpResponseMessage { + // StatusCode = HttpStatusCode.InternalServerError, + // }; + // var mockAuthenticator = new Mock(It.IsAny(), null); + // mockAuthenticator.Setup(x => + // x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny())).Returns(async () => { + // await Task.Delay(1000); + // return httpResponseMessage; + // }); - } + // var ex = Assert.ThrowsAsync(async () => await TestPrivateMethod.InstanceMethod(service, + // "AuthWithBadChannelBindingsAsync", + // [ + // url, authScheme, TimeSpan.FromMilliseconds(1), mockAuthenticator.Object + // ])); + // Assert.Equal($"Timeout during NTLM authentication for {url} with {authScheme}", ex.Result.Message); + + // } //// Throws "no such host is known" exception // [Fact] diff --git a/test/unit/LocalGroupProcessorTest.cs b/test/unit/LocalGroupProcessorTest.cs index 51ded05f9..41e54c12a 100644 --- a/test/unit/LocalGroupProcessorTest.cs +++ b/test/unit/LocalGroupProcessorTest.cs @@ -29,7 +29,7 @@ public void Dispose() { public async Task LocalGroupProcessor_TestWorkstation() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockWorkstationSAMServer(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1001"; var results = await processor.GetLocalGroups("win10.testlab.local", machineDomainSid, "TESTLAB.LOCAL", false) @@ -57,7 +57,7 @@ public async Task LocalGroupProcessor_TestWorkstation() { public async Task LocalGroupProcessor_TestDomainController() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockDCSAMServer(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; @@ -162,7 +162,7 @@ public async Task LocalGroupProcessor_TestTimeout() { var mockUtils = new Mock(); var mockProcessor = new Mock(mockUtils.Object, null); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(() => { + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(() => { return SharpHoundRPC.Result.Fail("Timeout"); }); var processor = mockProcessor.Object; @@ -172,7 +172,7 @@ public async Task LocalGroupProcessor_TestTimeout() { receivedStatus.Add(status); return Task.CompletedTask; }; - var results = await processor.GetLocalGroups("primary.testlab.local", machineDomainSid, "testlab.local", true, TimeSpan.FromMilliseconds(10)) + var results = await processor.GetLocalGroups("primary.testlab.local", machineDomainSid, "testlab.local", true) .ToArrayAsync(); Assert.Empty(results); Assert.Single(receivedStatus); @@ -184,7 +184,7 @@ public async Task LocalGroupProcessor_TestTimeout() { public async Task LocalGroupProcessor_GetLocalGroups_GetMachineSidResultFailed() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_GetMachineSid(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -205,7 +205,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_GetMachineSidResultFailed() public async Task LocalGroupProcessor_GetLocalGroups_GetDomainsResultFailed() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_GetDomains(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -226,7 +226,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_GetDomainsResultFailed() { public async Task LocalGroupProcessor_GetLocalGroups_OpenDomainResultFailed() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_OpenDomain(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -247,7 +247,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_OpenDomainResultFailed() { public async Task LocalGroupProcessor_GetLocalGroups_GetAliasesFailed() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_GetAliases(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -268,7 +268,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_GetAliasesFailed() { public async Task LocalGroupProcessor_GetLocalGroups_OpenAliasFailed() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_OpenAlias(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -291,7 +291,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_OpenAliasFailed() { public async Task LocalGroupProcessor_GetLocalGroups_GetMembersFailed() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_GetMembers(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -314,7 +314,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_GetMembersFailed() { public async Task LocalGroupProcessor_GetLocalGroups_LookupPrincipalBySid() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_LookupPrincipalBySid(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000"; var receivedStatus = new List(); @@ -343,7 +343,7 @@ public async Task LocalGroupProcessor_GetLocalGroups_LookupPrincipalBySid() { public async Task LocalGroupProcessor_GetLocalGroups_PreviouslyCached() { var mockProcessor = new Mock(new MockLdapUtils(), null); var mockSamServer = new MockFailSAMServer_PreviouslyCached(); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer); + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1001"; var results = await processor.GetLocalGroups("win10.testlab.local", machineDomainSid, "TESTLAB.LOCAL", false) diff --git a/test/unit/SmbProcessorTest.cs b/test/unit/SmbProcessorTest.cs index 8ba0fb5a2..4b037330b 100644 --- a/test/unit/SmbProcessorTest.cs +++ b/test/unit/SmbProcessorTest.cs @@ -37,12 +37,15 @@ public async Task SmbProcessor_TestTimeout() { var mockProcessor = new SmbProcessor(2, mockSmbScanner.Object); var receivedStatus = new List(); - mockProcessor.ComputerStatusEvent += async status => receivedStatus.Add(status); - var results = await mockProcessor.Scan("primary.testlab.local",TimeSpan.FromMilliseconds(1)); + mockProcessor.ComputerStatusEvent += status => { + receivedStatus.Add(status); + return Task.CompletedTask; + }; + var results = await mockProcessor.Scan("primary.testlab.local"); Assert.Single(receivedStatus); var status = receivedStatus[0]; - Assert.Equal("Timeout", status.Status); + Assert.StartsWith("Timeout", status.Status); } } } \ No newline at end of file diff --git a/test/unit/TimeoutTests.cs b/test/unit/TimeoutTests.cs index b108847ea..d5b06a204 100644 --- a/test/unit/TimeoutTests.cs +++ b/test/unit/TimeoutTests.cs @@ -26,7 +26,7 @@ public async Task ExecuteWithTimeout_Timeout() { }; var result = await SharpHoundCommonLib.Timeout.ExecuteWithTimeout(timeout, func); Assert.False(result.IsSuccess); - Assert.Equal("Timeout", result.Error); + Assert.StartsWith("Timeout", result.Error); } [Fact] @@ -85,7 +85,7 @@ public async Task ExecuteWithTimeout_T_Timeout() { }; var result = await SharpHoundCommonLib.Timeout.ExecuteWithTimeout(timeout, func); Assert.False(result.IsSuccess); - Assert.Equal("Timeout", result.Error); + Assert.StartsWith("Timeout", result.Error); } [Fact] @@ -147,7 +147,7 @@ public async Task ExecuteWithTimeout_Task_Timeout() { }; var result = await SharpHoundCommonLib.Timeout.ExecuteWithTimeout(timeout, func); Assert.False(result.IsSuccess); - Assert.Equal("Timeout", result.Error); + Assert.StartsWith("Timeout", result.Error); } [Fact] @@ -207,7 +207,7 @@ public async Task ExecuteWithTimeout_Task_T_Timeout() { }; var result = await SharpHoundCommonLib.Timeout.ExecuteWithTimeout(timeout, func); Assert.False(result.IsSuccess); - Assert.Equal("Timeout", result.Error); + Assert.StartsWith("Timeout", result.Error); } [Fact] diff --git a/test/unit/UserRightsAssignmentProcessorTest.cs b/test/unit/UserRightsAssignmentProcessorTest.cs index 0f4f80cf5..459c06bf4 100644 --- a/test/unit/UserRightsAssignmentProcessorTest.cs +++ b/test/unit/UserRightsAssignmentProcessorTest.cs @@ -68,27 +68,29 @@ public async Task UserRightsAssignmentProcessor_TestDC() Assert.Equal(Label.Group, adminResult.ObjectType); } - [Fact] - public async Task UserRightsAssignmentProcessor_TestTimeout() { - var mockProcessor = new Mock(new MockLdapUtils(), null); - mockProcessor.Setup(x => x.OpenLSAPolicy(It.IsAny())).Returns(()=> { - Task.Delay(100).Wait(); - return NtStatus.StatusAccessDenied; - }); - var processor = mockProcessor.Object; - var machineDomainSid = $"{Consts.MockDomainSid}-1000"; - var receivedStatus = new List(); - processor.ComputerStatusEvent += async status => { - receivedStatus.Add(status); - }; - var results = await processor.GetUserRightsAssignments("primary.testlab.local", machineDomainSid, "testlab.local", true, null,TimeSpan.FromMilliseconds(1)) - .ToArrayAsync(); - Assert.Empty(results); - Assert.Single(receivedStatus); - var status = receivedStatus[0]; - Assert.Equal("Timeout", status.Status); - } - + // Obsolete by AdaptiveTimeout + // [Fact] + // public async Task UserRightsAssignmentProcessor_TestTimeout() { + // var mockProcessor = new Mock(new MockLdapUtils(), null); + // mockProcessor.Setup(x => x.OpenLSAPolicy(It.IsAny())).Returns(()=> { + // Task.Delay(100).Wait(); + // return NtStatus.StatusAccessDenied; + // }); + // var processor = mockProcessor.Object; + // var machineDomainSid = $"{Consts.MockDomainSid}-1000"; + // var receivedStatus = new List(); + // processor.ComputerStatusEvent += status => { + // receivedStatus.Add(status); + // return Task.CompletedTask; + // }; + // var results = await processor.GetUserRightsAssignments("primary.testlab.local", machineDomainSid, "testlab.local", true, null) + // .ToArrayAsync(); + // Assert.Empty(results); + // Assert.Single(receivedStatus); + // var status = receivedStatus[0]; + // Assert.Equal("Timeout", status.Status); + // } + [WindowsOnlyFact] public async Task UserRightsAssignmentProcessor_TestGetLocalDomainInformationFail() {