diff --git a/src/CommonLib/AdaptiveTimeout.cs b/src/CommonLib/AdaptiveTimeout.cs
new file mode 100644
index 000000000..31aeb9033
--- /dev/null
+++ b/src/CommonLib/AdaptiveTimeout.cs
@@ -0,0 +1,226 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
+using SharpHoundRPC.NetAPINative;
+
+namespace SharpHoundCommonLib;
+
+public sealed class AdaptiveTimeout : IDisposable {
+ private readonly ExecutionTimeSampler _sampler;
+ private readonly ILogger _log;
+ private readonly TimeSpan _maxTimeout;
+ private readonly bool _useAdaptiveTimeout;
+ private readonly int _minSamplesForAdaptiveTimeout;
+ private int _clearSamplesDecay;
+ private const int TimeSpikePenalty = 2;
+ private const int TimeSpikeForgiveness = 1;
+ private const int ClearSamplesThreshold = 5;
+ private const int StdDevMultiplier = 5;
+
+ public AdaptiveTimeout(TimeSpan maxTimeout, ILogger log, int sampleCount = 100, int logFrequency = 1000, int minSamplesForAdaptiveTimeout = 30, bool useAdaptiveTimeout = true) {
+ if (maxTimeout <= TimeSpan.Zero)
+ throw new ArgumentException("maxTimeout must be positive", nameof(maxTimeout));
+ if (sampleCount <= 0)
+ throw new ArgumentException("sampleCount must be positive", nameof(sampleCount));
+ if (logFrequency <= 0)
+ throw new ArgumentException("logFrequency must be positive", nameof(logFrequency));
+ if (minSamplesForAdaptiveTimeout <= 0)
+ throw new ArgumentException("minSamplesForAdaptiveTimeout must be positive", nameof(minSamplesForAdaptiveTimeout));
+ if (log == null)
+ throw new ArgumentNullException(nameof(log));
+
+ _sampler = new ExecutionTimeSampler(log, sampleCount, logFrequency);
+ _log = log;
+ _maxTimeout = maxTimeout;
+ _useAdaptiveTimeout = useAdaptiveTimeout;
+ _minSamplesForAdaptiveTimeout = minSamplesForAdaptiveTimeout;
+ }
+
+ public void ClearSamples() {
+ _clearSamplesDecay = 0;
+ _sampler.ClearSamples();
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task> ExecuteWithTimeout(Func func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task ExecuteWithTimeout(Action func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task> ExecuteWithTimeout(Func> func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task ExecuteWithTimeout(Func func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task> ExecuteNetAPIWithTimeout(Func> func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteNetAPIWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task> ExecuteRPCWithTimeout(Func> func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteRPCWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ ///
+ /// Ignores the result of a function if it runs longer than a budgeted time, unblocking the caller.
+ /// Logs aggregate execution time data.
+ /// Manages its own timeout.
+ /// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
+ /// Please don't wrap a cached function in this timeout if adaptive timeouts enabled, normal distributions are better.
+ /// DO NOT use a single AdaptiveTimeout for multiple functions.
+ ///
+ ///
+ ///
+ ///
+ /// Returns a Fail result if a task runs longer than its budgeted time.
+ public async Task> ExecuteRPCWithTimeout(Func>> func, CancellationToken parentToken = default) {
+ var result = await Timeout.ExecuteRPCWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => func(timeoutToken)), parentToken);
+ TimeSpikeSafetyValve(result.IsSuccess);
+ return result;
+ }
+
+ public void Dispose() {
+ _sampler.Dispose();
+ }
+
+ // Within 5 standard deviations will have a conservative lower bound of catching 96% of executions (1 - 1/5^2),
+ // regardless of sample shape
+ // so long as those samples are independent and identically distributed
+ // (and if they're not, our TimeSpikeSafetyValve should provide us with some adaptability)
+ // But the effective collection rate is probably closer to 98+%
+ // (in part because we don't need to filter out "too fast" outliers)
+ // But we'll cap at configured maximum timeout
+ // https://modelassist.epixanalytics.com/space/EA/26574957/Tchebysheffs+Rule
+ // https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables
+ public TimeSpan GetAdaptiveTimeout() {
+ if (!UseAdaptiveTimeout())
+ return _maxTimeout;
+
+ try {
+ var stdDev = _sampler.StandardDeviation();
+ var adaptiveTimeoutMs = _sampler.Average() + (stdDev * StdDevMultiplier);
+ var cappedTimeoutMS = Math.Min(adaptiveTimeoutMs, _maxTimeout.TotalMilliseconds);
+ return TimeSpan.FromMilliseconds(cappedTimeoutMS);
+ }
+ catch (Exception ex) {
+ _log.LogError(ex, "Error calculating adaptive timeout, defaulting to max timeout.");
+ return _maxTimeout;
+ }
+ }
+
+ // AdaptiveTimeout will not respond well to rapid spikes in execution time
+ // imagine the wrapped function very regularly executes in 10ms
+ // then suddenly starts taking a regular 100ms
+ // this is fine (if it fits in our max timeout budget), and we shouldn't block
+ // so we should create a safety valve in case this happens to reset our data samples
+ private void TimeSpikeSafetyValve(bool isSuccess) {
+ if (isSuccess) {
+ _clearSamplesDecay -= TimeSpikeForgiveness;
+ _clearSamplesDecay = Math.Max(0, _clearSamplesDecay);
+ }
+ else
+ _clearSamplesDecay += TimeSpikePenalty;
+
+
+ if (_clearSamplesDecay >= ClearSamplesThreshold) {
+ if (UseAdaptiveTimeout()) {
+ ClearSamples();
+ _log.LogTrace("Time spike safety valve event at timeout {CurrentTimeout}.", GetAdaptiveTimeout());
+ }
+ else {
+ _log.LogWarning("This call is frequently running over the maximum allowed timeout of {MaxTimeout}.", _maxTimeout);
+ }
+ }
+ }
+
+ private bool UseAdaptiveTimeout() {
+ return _useAdaptiveTimeout && _sampler.Count >= _minSamplesForAdaptiveTimeout;
+ }
+}
\ No newline at end of file
diff --git a/src/CommonLib/ExecutionTimeSampler.cs b/src/CommonLib/ExecutionTimeSampler.cs
new file mode 100644
index 000000000..fc99f7e45
--- /dev/null
+++ b/src/CommonLib/ExecutionTimeSampler.cs
@@ -0,0 +1,104 @@
+using System;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
+
+namespace SharpHoundCommonLib;
+
+///
+/// Holds a rolling sample of execution times on a function, providing and logging data aggregates.
+///
+public class ExecutionTimeSampler : IDisposable {
+ private readonly ILogger _log;
+ private readonly int _sampleCount;
+ private readonly int _logFrequency;
+ private int _samplesSinceLastLog;
+ private ConcurrentQueue _samples;
+
+ public int Count => _samples.Count;
+
+ public ExecutionTimeSampler(ILogger log, int sampleCount, int logFrequency) {
+ _log = log;
+ _sampleCount = sampleCount;
+ _logFrequency = logFrequency;
+ _samplesSinceLastLog = 0;
+ _samples = new ConcurrentQueue();
+ }
+
+ public void ClearSamples() {
+ Log(flush: true);
+ _samples = new ConcurrentQueue();
+ }
+
+ public double StandardDeviation() {
+ double average = _samples.Average();
+ double sumOfSquaresOfDifferences = _samples.Select(val => (val - average) * (val - average)).Sum();
+ double stddiv = Math.Sqrt(sumOfSquaresOfDifferences / _samples.Count);
+
+ return stddiv;
+ }
+
+ public double Average() => _samples.Average();
+
+ public async Task SampleExecutionTime(Func> func) {
+ var stopwatch = Stopwatch.StartNew();
+ var result = await func.Invoke();
+ stopwatch.Stop();
+ AddTimeSample(stopwatch.Elapsed);
+
+ return result;
+ }
+
+ public async Task SampleExecutionTime(Func func) {
+ var stopwatch = Stopwatch.StartNew();
+ await func.Invoke();
+ stopwatch.Stop();
+ AddTimeSample(stopwatch.Elapsed);
+ }
+
+ public T SampleExecutionTime(Func func) {
+ var stopwatch = Stopwatch.StartNew();
+ var result = func.Invoke();
+ stopwatch.Stop();
+ AddTimeSample(stopwatch.Elapsed);
+
+ return result;
+ }
+
+ public void SampleExecutionTime(Action func) {
+ var stopwatch = Stopwatch.StartNew();
+ func.Invoke();
+ stopwatch.Stop();
+ AddTimeSample(stopwatch.Elapsed);
+ }
+
+ public void Dispose() {
+ Log(flush: true);
+ }
+
+ private void AddTimeSample(TimeSpan timeSpan) {
+ while (_samples.Count >= _sampleCount) {
+ _samples.TryDequeue(out _);
+ }
+
+ _samples.Enqueue(timeSpan.TotalMilliseconds);
+ _samplesSinceLastLog++;
+
+ Log();
+ }
+
+ private void Log(bool flush = false) {
+ if ((flush || _samplesSinceLastLog >= _logFrequency) && _samples.Count > 0) {
+ try {
+ _log.LogInformation("Execution time Average: {Average}ms, StdDiv: {StandardDeviation}ms", _samples.Average(), StandardDeviation());
+ }
+ catch (Exception ex) {
+ _log.LogWarning("Failed to calculate execution time statistics: {Error}", ex.Message);
+ }
+
+ _samplesSinceLastLog = 0;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/CommonLib/IRegistryKey.cs b/src/CommonLib/IRegistryKey.cs
index c61991082..5c1d0a9c7 100644
--- a/src/CommonLib/IRegistryKey.cs
+++ b/src/CommonLib/IRegistryKey.cs
@@ -10,6 +10,7 @@ public interface IRegistryKey {
public class SHRegistryKey : IRegistryKey, IDisposable {
private readonly RegistryKey _currentKey;
+ private static readonly AdaptiveTimeout _adaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromSeconds(10), Logging.LogProvider.CreateLogger(nameof(SHRegistryKey)));
private SHRegistryKey(RegistryKey registryKey) {
_currentKey = registryKey;
@@ -35,7 +36,7 @@ public object GetValue(string subkey, string name) {
///
///
public static async Task Connect(RegistryHive hive, string machineName) {
- var remoteKey = await Timeout.ExecuteWithTimeout(TimeSpan.FromSeconds(10), (_) => RegistryKey.OpenRemoteBaseKey(hive, machineName));
+ var remoteKey = await _adaptiveTimeout.ExecuteWithTimeout((_) => RegistryKey.OpenRemoteBaseKey(hive, machineName));
if (remoteKey.IsSuccess)
return new SHRegistryKey(remoteKey.Value);
else
diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs
index 3862c68d0..84719438d 100644
--- a/src/CommonLib/LdapConnectionPool.cs
+++ b/src/CommonLib/LdapConnectionPool.cs
@@ -27,6 +27,10 @@ internal class LdapConnectionPool : IDisposable {
private readonly ILogger _log;
private readonly IPortScanner _portScanner;
private readonly NativeMethods _nativeMethods;
+ private readonly AdaptiveTimeout _queryAdaptiveTimeout;
+ private readonly AdaptiveTimeout _pagedQueryAdaptiveTimeout;
+ private readonly AdaptiveTimeout _rangedRetrievalAdaptiveTimeout;
+ private readonly AdaptiveTimeout _testConnectionAdaptiveTimeout;
private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2);
private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20);
private const int BackoffDelayMultiplier = 2;
@@ -54,6 +58,10 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c
_log = log ?? Logging.LogProvider.CreateLogger("LdapConnectionPool");
_portScanner = scanner ?? new PortScanner();
_nativeMethods = nativeMethods ?? new NativeMethods();
+ _queryAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapQuery"));
+ _pagedQueryAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapPagedQuery"));
+ _rangedRetrievalAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapRangedRetrieval"));
+ _testConnectionAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("TestLdapConnection"));
}
private async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
@@ -100,7 +108,7 @@ public async IAsyncEnumerable> Query(LdapQueryParam
try {
_log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo());
- response = (SearchResponse)await connectionWrapper.Connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2));
+ response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _queryAdaptiveTimeout);
if (response != null) {
querySuccess = true;
@@ -165,6 +173,16 @@ public async IAsyncEnumerable> Query(LdapQueryParam
var backoffDelay = GetNextBackoff(busyRetryCount);
await Task.Delay(backoffDelay, cancellationToken);
}
+ catch (TimeoutException) when (busyRetryCount < MaxRetries) {
+ /*
+ * Treat a timeout as a busy error
+ */
+ busyRetryCount++;
+ _log.LogDebug("Query - Timeout: Executing busy backoff for query {Info} (Attempt {Count})",
+ queryParameters.GetQueryInfo(), busyRetryCount);
+ var backoffDelay = GetNextBackoff(busyRetryCount);
+ await Task.Delay(backoffDelay, cancellationToken);
+ }
catch (LdapException le) {
/*
* This is our fallback catch. If our retry counts have been exhausted this will trigger and break us out of our loop
@@ -255,7 +273,7 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery
SearchResponse response = null;
try {
_log.LogTrace("Sending paged ldap request - {Info}", queryParameters.GetQueryInfo());
- response = (SearchResponse)await connectionWrapper.Connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2));
+ response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _pagedQueryAdaptiveTimeout);
if (response != null) {
pageResponse = (PageResultResponseControl)response.Controls
.Where(x => x is PageResultResponseControl).DefaultIfEmpty(null).FirstOrDefault();
@@ -326,6 +344,16 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery
var backoffDelay = GetNextBackoff(busyRetryCount);
await Task.Delay(backoffDelay, cancellationToken);
}
+ catch (TimeoutException) when (busyRetryCount < MaxRetries) {
+ /*
+ * Treat a timeout as a busy error
+ */
+ busyRetryCount++;
+ _log.LogDebug("PagedQuery - Timeout: Executing busy backoff for query {Info} (Attempt {Count})",
+ queryParameters.GetQueryInfo(), busyRetryCount);
+ var backoffDelay = GetNextBackoff(busyRetryCount);
+ await Task.Delay(backoffDelay, cancellationToken);
+ }
catch (LdapException le) {
tempResult = LdapResult.Fail(
$"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})",
@@ -468,7 +496,7 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish
}
try {
- response = (SearchResponse)await connectionWrapper.Connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2));
+ response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _rangedRetrievalAdaptiveTimeout);
}
catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) {
busyRetryCount++;
@@ -477,6 +505,16 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish
var backoffDelay = GetNextBackoff(busyRetryCount);
await Task.Delay(backoffDelay, cancellationToken);
}
+ catch (TimeoutException) when (busyRetryCount < MaxRetries) {
+ /*
+ * Treat a timeout as a busy error
+ */
+ busyRetryCount++;
+ _log.LogDebug("RangedRetrieval - Timeout: Executing busy backoff for query {Info} (Attempt {Count})",
+ queryParameters.GetQueryInfo(), busyRetryCount);
+ var backoffDelay = GetNextBackoff(busyRetryCount);
+ await Task.Delay(backoffDelay, cancellationToken);
+ }
catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown &&
queryRetryCount < MaxRetries) {
queryRetryCount++;
@@ -741,7 +779,7 @@ public void Dispose() {
}
string tempDomainName;
-
+
// Blocking External Call
var dsGetDcNameResult = _nativeMethods.CallDsGetDcName(null, _identifier,
(uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY |
@@ -941,7 +979,7 @@ private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl
var searchRequest = CreateSearchRequest("", new LdapFilter().AddAllObjects().GetFilter(),
SearchScope.Base, null);
- response = (SearchResponse)await connection.SendRequestAsync(searchRequest, TimeSpan.FromMinutes(2));
+ response = await SendRequestWithTimeout(connection, searchRequest, _testConnectionAdaptiveTimeout);
}
catch (LdapException e) {
/*
@@ -998,5 +1036,17 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF
searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope));
return searchRequest;
}
+
+ private async Task SendRequestWithTimeout(LdapConnection connection, SearchRequest request, AdaptiveTimeout adaptiveTimeout) {
+ // Add padding to account for network latency and processing overhead
+ const int TimeoutPaddingSeconds = 3;
+ var timeout = adaptiveTimeout.GetAdaptiveTimeout();
+ var timeoutWithPadding = timeout + TimeSpan.FromSeconds(TimeoutPaddingSeconds);
+ var result = await adaptiveTimeout.ExecuteWithTimeout((_) => connection.SendRequestAsync(request, timeoutWithPadding));
+ if (result.IsSuccess)
+ return (SearchResponse)result.Value;
+ else
+ throw new TimeoutException($"LDAP {request.Scope} query to '{request.DistinguishedName}' timed out after {timeout.TotalMilliseconds}ms.");
+ }
}
}
\ No newline at end of file
diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs
index 6d847c31d..14612da12 100644
--- a/src/CommonLib/LdapUtils.cs
+++ b/src/CommonLib/LdapUtils.cs
@@ -7,7 +7,6 @@
using System.Linq;
using System.Net;
using System.Net.Sockets;
-using System.Security.Cryptography;
using System.Security.Principal;
using System.Text;
using System.Text.RegularExpressions;
@@ -38,6 +37,10 @@ public class LdapUtils : ILdapUtils {
private static readonly ConcurrentDictionary
SeenWellKnownPrincipals = new();
+ private static readonly AdaptiveTimeout _requestNetBiosNameAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(1), Logging.LogProvider.CreateLogger(nameof(RequestNETBIOSNameFromComputerAsync)));
+
+ private static readonly AdaptiveTimeout _callNetWkstaGetInfoAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(NativeMethods.CallNetWkstaGetInfo)));
+
private readonly ConcurrentDictionary
_hostResolutionMap = new(StringComparer.OrdinalIgnoreCase);
@@ -713,7 +716,7 @@ public bool GetDomain(out Domain domain) {
}
// Blocking External Call
- var result = await Timeout.ExecuteNetAPIWithTimeout(TimeSpan.FromMinutes(2), (_) => _nativeMethods.CallNetWkstaGetInfo(hostname));
+ var result = await _callNetWkstaGetInfoAdaptiveTimeout.ExecuteNetAPIWithTimeout((_) => _nativeMethods.CallNetWkstaGetInfo(hostname));
if (result.IsSuccess)
return (true, result.Value);
@@ -781,7 +784,7 @@ public bool GetDomain(out Domain domain) {
}
private static async Task<(bool Success, string NetBiosName)> RequestNETBIOSNameFromComputerWithTimeout(string server, string domain) {
- var result = await Timeout.ExecuteWithTimeout(TimeSpan.FromMinutes(1), async (timeoutToken) => await RequestNETBIOSNameFromComputerAsync(server, domain, timeoutToken));
+ var result = await _requestNetBiosNameAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => await RequestNETBIOSNameFromComputerAsync(server, domain, timeoutToken));
if (result.IsSuccess)
return (result.Value.Success, result.Value.NetBiosName);
else
diff --git a/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs b/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs
index e765a11b0..0138a632b 100644
--- a/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs
+++ b/src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs
@@ -16,35 +16,37 @@ namespace SharpHoundCommonLib.Ntlm;
public class HttpNtlmAuthenticationService {
private readonly ILogger _logger;
private readonly IHttpClientFactory _httpClientFactory;
+ private readonly AdaptiveTimeout _getSupportedNTLMAuthSchemesAdaptiveTimeout;
+ private readonly AdaptiveTimeout _ntlmAuthAdaptiveTimeout;
+ private readonly AdaptiveTimeout _authWithChannelBindingAdaptiveTimeout;
public HttpNtlmAuthenticationService(IHttpClientFactory httpClientFactory, ILogger logger = null) {
_logger = logger ?? Logging.LogProvider.CreateLogger(nameof(HttpNtlmAuthenticationService));
_httpClientFactory = httpClientFactory;
+ _getSupportedNTLMAuthSchemesAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(GetSupportedNtlmAuthSchemesAsync)));
+ _ntlmAuthAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(NtlmAuthenticationHandler.PerformNtlmAuthenticationAsync)));
+ _authWithChannelBindingAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(AuthWithBadChannelBindingsAsync)));
}
- public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings, TimeSpan timeout = default) {
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
+ public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings) {
if (url == null)
throw new ArgumentException("Url property is null");
if (useBadChannelBindings == null && url.Scheme == "https")
throw new ArgumentException("When using HTTPS, useBadChannelBindings must be set");
- var supportedAuthSchemes = await GetSupportedNtlmAuthSchemesAsync(url, timeout);
+ var supportedAuthSchemes = await GetSupportedNtlmAuthSchemesAsync(url);
_logger.LogDebug($"Supported NTLM auth schemes for {url}: " + string.Join(",", supportedAuthSchemes));
foreach (var authScheme in supportedAuthSchemes) {
if (useBadChannelBindings == null) {
- await AuthWithBadChannelBindingsAsync(url, authScheme, timeout);
+ await AuthWithBadChannelBindingsAsync(url, authScheme);
} else {
if ((bool)useBadChannelBindings) {
- await AuthWithBadChannelBindingsAsync(url, authScheme, timeout);
+ await AuthWithBadChannelBindingsAsync(url, authScheme);
} else {
- await AuthWithChannelBindingAsync(url, authScheme, timeout);
+ await AuthWithChannelBindingAsync(url, authScheme);
}
}
@@ -53,11 +55,11 @@ public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings, TimeS
}
}
- private async Task GetSupportedNtlmAuthSchemesAsync(Uri url, TimeSpan timeout) {
+ private async Task GetSupportedNtlmAuthSchemesAsync(Uri url) {
var httpClient = _httpClientFactory.CreateUnauthenticatedClient();
using var getRequest = new HttpRequestMessage(HttpMethod.Get, url);
- var result = await Timeout.ExecuteWithTimeout(timeout, async (timeoutToken) => {
+ var result = await _getSupportedNTLMAuthSchemesAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => {
var getResponse = await httpClient.SendAsync(getRequest, timeoutToken);
return ExtractAuthSchemes(getResponse);
});
@@ -101,12 +103,12 @@ internal string[] ExtractAuthSchemes(HttpResponseMessage response) {
return schemes;
}
- private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, TimeSpan timeout, NtlmAuthenticationHandler ntlmAuth = null) {
+ private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, NtlmAuthenticationHandler ntlmAuth = null) {
var httpClient = _httpClientFactory.CreateUnauthenticatedClient();
var transport = new HttpTransport(httpClient, url, authScheme, _logger);
var ntlmAuthHandler = ntlmAuth ?? new NtlmAuthenticationHandler($"HTTP/{url.Host}");
- var result = await Timeout.ExecuteWithTimeout(timeout, (timeoutToken) => ntlmAuthHandler.PerformNtlmAuthenticationAsync(transport, timeoutToken));
+ var result = await _ntlmAuthAdaptiveTimeout.ExecuteWithTimeout((timeoutToken) => ntlmAuthHandler.PerformNtlmAuthenticationAsync(transport, timeoutToken));
if (!result.IsSuccess) {
throw new TimeoutException($"Timeout during NTLM authentication for {url} with {authScheme}");
@@ -139,7 +141,7 @@ private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, T
response.EnsureSuccessStatusCode();
}
- private async Task AuthWithChannelBindingAsync(Uri url, string authScheme, TimeSpan timeout) {
+ private async Task AuthWithChannelBindingAsync(Uri url, string authScheme) {
var handler = new HttpClientHandler {
ServerCertificateCustomValidationCallback = (httpRequestMessage, cert, cetChain, policyErrors) => true,
};
@@ -153,7 +155,7 @@ private async Task AuthWithChannelBindingAsync(Uri url, string authScheme,
using var client = new HttpClient(handler);
- var result = await Timeout.ExecuteWithTimeout(timeout, async (timeoutToken) => {
+ var result = await _authWithChannelBindingAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => {
try {
HttpResponseMessage response = await client.GetAsync(url, timeoutToken);
return response.StatusCode == HttpStatusCode.OK;
diff --git a/src/CommonLib/Processors/CertAbuseProcessor.cs b/src/CommonLib/Processors/CertAbuseProcessor.cs
index b0c0ca088..4da05f191 100644
--- a/src/CommonLib/Processors/CertAbuseProcessor.cs
+++ b/src/CommonLib/Processors/CertAbuseProcessor.cs
@@ -18,14 +18,17 @@ public class CertAbuseProcessor
{
private readonly ILogger _log;
private readonly ILdapUtils _utils;
+ private readonly AdaptiveTimeout _getMachineSidAdaptiveTimeout;
+ private readonly AdaptiveTimeout _openSamServerAdaptiveTimeout;
public delegate Task ComputerStatusDelegate(CSVComputerStatus status);
public event ComputerStatusDelegate ComputerStatusEvent;
-
- public CertAbuseProcessor(ILdapUtils utils, ILogger log = null)
- {
+
+ public CertAbuseProcessor(ILdapUtils utils, ILogger log = null) {
_utils = utils;
_log = log ?? Logging.LogProvider.CreateLogger("CAProc");
+ _getMachineSidAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.GetMachineSid)));
+ _openSamServerAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(SAMServer.OpenServer)));
}
///
@@ -365,7 +368,7 @@ await SendComputerStatus(new CSVComputerStatus
}
var server = openServerResult.Value;
- var getMachineSidResult = await Timeout.ExecuteRPCWithTimeout(TimeSpan.FromMinutes(2), (timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken));
+ var getMachineSidResult = await _getMachineSidAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken));
if (getMachineSidResult.IsFailed)
{
_log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError);
@@ -445,7 +448,7 @@ await SendComputerStatus(new CSVComputerStatus
public virtual SharpHoundRPC.Result OpenSamServer(string computerName)
{
- var result = Timeout.ExecuteRPCWithTimeout(TimeSpan.FromMinutes(2), (_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult();
+ var result = _openSamServerAdaptiveTimeout.ExecuteRPCWithTimeout((_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult();
if (result.IsFailed)
{
return SharpHoundRPC.Result.Fail(result.SError);
diff --git a/src/CommonLib/Processors/ComputerAvailability.cs b/src/CommonLib/Processors/ComputerAvailability.cs
index 6f2c9918c..489ebde1a 100644
--- a/src/CommonLib/Processors/ComputerAvailability.cs
+++ b/src/CommonLib/Processors/ComputerAvailability.cs
@@ -11,14 +11,12 @@ public class ComputerAvailability {
private readonly int _computerExpiryDays;
private readonly ILogger _log;
private readonly IPortScanner _scanner;
- private readonly int _scanTimeout;
private readonly bool _skipPasswordCheck;
private readonly bool _skipPortScan;
public ComputerAvailability(int timeout = 10000, int computerExpiryDays = 60, bool skipPortScan = false,
bool skipPasswordCheck = false, ILogger log = null) {
- _scanner = new PortScanner();
- _scanTimeout = timeout;
+ _scanner = new PortScanner(maxTimeout: timeout);
_skipPortScan = skipPortScan;
_log = log ?? Logging.LogProvider.CreateLogger("CompAvail");
_computerExpiryDays = computerExpiryDays;
@@ -28,8 +26,7 @@ public ComputerAvailability(int timeout = 10000, int computerExpiryDays = 60, bo
public ComputerAvailability(IPortScanner scanner, int timeout = 500, int computerExpiryDays = 60,
bool skipPortScan = false, bool skipPasswordCheck = false,
ILogger log = null) {
- _scanner = scanner ?? new PortScanner();
- _scanTimeout = timeout;
+ _scanner = scanner ?? new PortScanner(maxTimeout: timeout);
_skipPortScan = skipPortScan;
_log = log ?? Logging.LogProvider.CreateLogger("CompAvail");
_computerExpiryDays = computerExpiryDays;
@@ -101,7 +98,7 @@ await SendComputerStatus(new CSVComputerStatus {
Error = null
};
- if (!await _scanner.CheckPort(computerName, timeout: _scanTimeout)) {
+ if (!await _scanner.CheckPort(computerName)) {
_log.LogTrace("{ComputerName} is not available because port 445 is unavailable", computerName);
await SendComputerStatus(new CSVComputerStatus {
Status = ComputerStatus.PortNotOpen,
diff --git a/src/CommonLib/Processors/ComputerSessionProcessor.cs b/src/CommonLib/Processors/ComputerSessionProcessor.cs
index f66eaf5b9..2486b3d0e 100644
--- a/src/CommonLib/Processors/ComputerSessionProcessor.cs
+++ b/src/CommonLib/Processors/ComputerSessionProcessor.cs
@@ -22,6 +22,8 @@ public class ComputerSessionProcessor {
private readonly bool _doLocalAdminSessionEnum;
private readonly string _localAdminUsername;
private readonly string _localAdminPassword;
+ private readonly AdaptiveTimeout _readUserSessionsAdaptiveTimeout;
+ private readonly AdaptiveTimeout _readUserSessionsPriviledgedAdaptiveTimeout;
public ComputerSessionProcessor(ILdapUtils utils,
NativeMethods nativeMethods = null, ILogger log = null, string currentUserName = null,
@@ -34,6 +36,8 @@ public ComputerSessionProcessor(ILdapUtils utils,
_doLocalAdminSessionEnum = doLocalAdminSessionEnum;
_localAdminUsername = localAdminUsername;
_localAdminPassword = localAdminPassword;
+ _readUserSessionsAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ReadUserSessions)));
+ _readUserSessionsPriviledgedAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ReadUserSessionsPrivileged)));
}
public event ComputerStatusDelegate ComputerStatusEvent;
@@ -48,16 +52,12 @@ public ComputerSessionProcessor(ILdapUtils utils,
///
///
public async Task ReadUserSessions(string computerName, string computerSid,
- string computerDomain, TimeSpan timeout = default) {
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
+ string computerDomain) {
var ret = new SessionAPIResult();
_log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName);
- var result = await Timeout.ExecuteNetAPIWithTimeout(timeout, (timeoutToken) => {
+ var result = await _readUserSessionsAdaptiveTimeout.ExecuteNetAPIWithTimeout((timeoutToken) => {
NetAPIResult> result;
if (_doLocalAdminSessionEnum) {
// If we are authenticating using a local admin, we need to impersonate for this
@@ -190,15 +190,11 @@ await SendComputerStatus(new CSVComputerStatus {
///
///
public async Task ReadUserSessionsPrivileged(string computerName,
- string computerSamAccountName, string computerSid, TimeSpan timeout = default) {
+ string computerSamAccountName, string computerSid) {
var ret = new SessionAPIResult();
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
_log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName);
- var result = await Timeout.ExecuteNetAPIWithTimeout(timeout, (timeoutToken) => {
+ var result = await _readUserSessionsPriviledgedAdaptiveTimeout.ExecuteNetAPIWithTimeout((timeoutToken) => {
NetAPIResult>
result;
if (_doLocalAdminSessionEnum) {
diff --git a/src/CommonLib/Processors/DCLdapProcessor.cs b/src/CommonLib/Processors/DCLdapProcessor.cs
index 1272ad013..1fe544355 100644
--- a/src/CommonLib/Processors/DCLdapProcessor.cs
+++ b/src/CommonLib/Processors/DCLdapProcessor.cs
@@ -22,43 +22,41 @@ public class LdapAuthOptions {
public class DCLdapProcessor {
private readonly ILogger _log;
private readonly IPortScanner _scanner;
- private readonly int _portScanTimeout;
private readonly int _ldapTimeout;
private readonly Uri _ldapEndpoint;
private readonly Uri _ldapSslEndpoint;
+ private readonly AdaptiveTimeout _checkIsNtlmSigningRequiredAdaptiveTimeout;
+ private readonly AdaptiveTimeout _checkIsChannelBindingDisabledAdaptiveTimeout;
public delegate Task ComputerStatusDelegate(CSVComputerStatus status);
private readonly string SEC_E_UNSUPPORTED_FUNCTION = "80090302";
private readonly string SEC_E_BAD_BINDINGS = "80090346";
- public DCLdapProcessor(int portScanTimeout, string dcHostname, ILogger log = null) {
+ public DCLdapProcessor(int connectionTimeoutMs, string dcHostname, ILogger log = null) {
_log = log ?? Logging.LogProvider.CreateLogger("DCLdapProcessor");
- _scanner = new PortScanner();
- _portScanTimeout = portScanTimeout;
- _ldapTimeout = portScanTimeout / 1000;
+ _scanner = new PortScanner(maxTimeout: connectionTimeoutMs);
+ _ldapTimeout = connectionTimeoutMs / 1000;
_ldapEndpoint = new Uri($"ldap://{dcHostname}:389");
_ldapSslEndpoint = new Uri($"ldaps://{dcHostname}:636");
+ _checkIsNtlmSigningRequiredAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(1), Logging.LogProvider.CreateLogger(nameof(CheckIsNtlmSigningRequired)));
+ _checkIsChannelBindingDisabledAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(1), Logging.LogProvider.CreateLogger(nameof(CheckIsChannelBindingDisabled)));
}
public event ComputerStatusDelegate ComputerStatusEvent;
- public async Task Scan(string computerName, TimeSpan timeout = default) {
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
+ public async Task Scan(string computerName) {
var hasLdap = await TestLdapPort();
var hasLdaps = await TestLdapsPort();
SharpHoundRPC.Result isSigningRequired = new(),
isChannelBindingDisabled = new();
if (hasLdap) {
- isSigningRequired = await Timeout.ExecuteRPCWithTimeout(timeout, CheckIsNtlmSigningRequired);
+ isSigningRequired = await _checkIsNtlmSigningRequiredAdaptiveTimeout.ExecuteRPCWithTimeout(CheckIsNtlmSigningRequired);
}
if (hasLdaps) {
- isChannelBindingDisabled = await Timeout.ExecuteRPCWithTimeout(timeout, CheckIsChannelBindingDisabled);
+ isChannelBindingDisabled = await _checkIsChannelBindingDisabledAdaptiveTimeout.ExecuteRPCWithTimeout(CheckIsChannelBindingDisabled);
}
if (isSigningRequired.IsFailed) {
@@ -117,12 +115,12 @@ await SendComputerStatus(new CSVComputerStatus {
/// bool
[ExcludeFromCodeCoverage]
public virtual async Task TestLdapPort() {
- return await _scanner.CheckPort(_ldapEndpoint.Host, _ldapEndpoint.Port, _portScanTimeout);
+ return await _scanner.CheckPort(_ldapEndpoint.Host, _ldapEndpoint.Port);
}
[ExcludeFromCodeCoverage]
public virtual async Task TestLdapsPort() {
- return await _scanner.CheckPort(_ldapSslEndpoint.Host, _ldapSslEndpoint.Port, _portScanTimeout);
+ return await _scanner.CheckPort(_ldapSslEndpoint.Host, _ldapSslEndpoint.Port);
}
public async Task> CheckIsNtlmSigningRequired(CancellationToken cancellationToken = default) {
diff --git a/src/CommonLib/Processors/LocalGroupProcessor.cs b/src/CommonLib/Processors/LocalGroupProcessor.cs
index f68ff781f..984ed53a4 100644
--- a/src/CommonLib/Processors/LocalGroupProcessor.cs
+++ b/src/CommonLib/Processors/LocalGroupProcessor.cs
@@ -16,22 +16,33 @@ public class LocalGroupProcessor
public delegate Task ComputerStatusDelegate(CSVComputerStatus status);
private readonly ILogger _log;
private readonly ILdapUtils _utils;
-
- public LocalGroupProcessor(ILdapUtils utils, ILogger log = null)
- {
+ private readonly AdaptiveTimeout _getMachineSidAdaptiveTimeout;
+ private readonly AdaptiveTimeout _openSamServerAdaptiveTimeout;
+ private readonly AdaptiveTimeout _getDomainsAdaptiveTimeout;
+ private readonly AdaptiveTimeout _openDomainAdaptiveTimeout;
+ private readonly AdaptiveTimeout _getAliasesAdaptiveTimeout;
+ private readonly AdaptiveTimeout _openAliasAdaptiveTimeout;
+ private readonly AdaptiveTimeout _getMembersAdaptiveTimeout;
+ private readonly AdaptiveTimeout _lookupPrincipalBySidAdaptiveTimeout;
+
+ public LocalGroupProcessor(ILdapUtils utils, ILogger log = null) {
_utils = utils;
_log = log ?? Logging.LogProvider.CreateLogger("LocalGroupProcessor");
+ _getMachineSidAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.GetMachineSid)));
+ _openSamServerAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(SAMServer.OpenServer)));
+ _getDomainsAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.GetDomains)));
+ _openDomainAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.OpenDomain)));
+ _getAliasesAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMDomain.GetAliases)));
+ _openAliasAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMDomain.OpenAlias)));
+ _getMembersAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMAlias.GetMembers)));
+ _lookupPrincipalBySidAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ISAMServer.LookupPrincipalBySid)));
}
public event ComputerStatusDelegate ComputerStatusEvent;
- public virtual SharpHoundRPC.Result OpenSamServer(string computerName, TimeSpan timeout = default)
+ public virtual SharpHoundRPC.Result OpenSamServer(string computerName)
{
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
- var result = Timeout.ExecuteRPCWithTimeout(timeout, (_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult();
+ var result = _openSamServerAdaptiveTimeout.ExecuteRPCWithTimeout((_) => SAMServer.OpenServer(computerName)).GetAwaiter().GetResult();
if (result.IsFailed)
{
return SharpHoundRPC.Result.Fail(result.SError);
@@ -55,14 +66,10 @@ public IAsyncEnumerable GetLocalGroups(ResolvedSearchResult
///
///
public async IAsyncEnumerable GetLocalGroups(string computerName, string computerObjectId,
- string computerDomain, bool isDomainController, TimeSpan timeout = default)
+ string computerDomain, bool isDomainController)
{
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
//Open a handle to the server
- var openServerResult = OpenSamServer(computerName, timeout);
+ var openServerResult = OpenSamServer(computerName);
if (openServerResult.IsFailed)
{
_log.LogTrace("OpenServer failed on {ComputerName}: {Error}", computerName, openServerResult.SError);
@@ -81,7 +88,7 @@ await SendComputerStatus(new CSVComputerStatus
//Try to get the machine sid for the computer if its not already cached
SecurityIdentifier machineSid;
if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid)) {
- var getMachineSidResult = await Timeout.ExecuteRPCWithTimeout(timeout, (timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken));
+ var getMachineSidResult = await _getMachineSidAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => server.GetMachineSid(cancellationToken: timeoutToken));
if (getMachineSidResult.IsFailed)
{
_log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError);
@@ -105,7 +112,7 @@ await SendComputerStatus(new CSVComputerStatus
}
//Get all available domains in the server
- var getDomainsResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => server.GetDomains());
+ var getDomainsResult = await _getDomainsAdaptiveTimeout.ExecuteRPCWithTimeout((_) => server.GetDomains());
if (getDomainsResult.IsFailed)
{
_log.LogTrace("GetDomains failed on {ComputerName}: {Error}", computerName, getDomainsResult.SError);
@@ -126,7 +133,7 @@ await SendComputerStatus(new CSVComputerStatus
continue;
//Open a handle to the domain
- var openDomainResult = await Timeout.ExecuteRPCWithTimeout(timeout, (timeoutToken) => server.OpenDomain(domainResult.Name, cancellationToken: timeoutToken));
+ var openDomainResult = await _openDomainAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => server.OpenDomain(domainResult.Name, cancellationToken: timeoutToken));
if (openDomainResult.IsFailed)
{
_log.LogTrace("Failed to open domain {Domain} on {ComputerName}: {Error}", domainResult.Name, computerName, openDomainResult.SError);
@@ -145,7 +152,7 @@ await SendComputerStatus(new CSVComputerStatus
var domain = openDomainResult.Value;
//Open a handle to the available aliases
- var getAliasesResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => domain.GetAliases());
+ var getAliasesResult = await _getAliasesAdaptiveTimeout.ExecuteRPCWithTimeout((_) => domain.GetAliases());
if (getAliasesResult.IsFailed)
{
@@ -178,7 +185,7 @@ await SendComputerStatus(new CSVComputerStatus
};
//Open a handle to the alias
- var openAliasResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => domain.OpenAlias(alias.Rid));
+ var openAliasResult = await _openAliasAdaptiveTimeout.ExecuteRPCWithTimeout((_) => domain.OpenAlias(alias.Rid));
if (openAliasResult.IsFailed)
{
_log.LogTrace("Failed to open alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error);
@@ -199,7 +206,7 @@ await SendComputerStatus(new CSVComputerStatus
var localGroup = openAliasResult.Value;
//Call GetMembersInAlias to get raw group members
- var getMembersResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => localGroup.GetMembers());
+ var getMembersResult = await _getMembersAdaptiveTimeout.ExecuteRPCWithTimeout((_) => localGroup.GetMembers());
if (getMembersResult.IsFailed)
{
_log.LogTrace("Failed to get members in alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error);
@@ -274,7 +281,7 @@ await SendComputerStatus(new CSVComputerStatus
}
//Attempt to lookup the principal in the server directly
- var lookupUserResult = await Timeout.ExecuteRPCWithTimeout(timeout, timeoutToken => server.LookupPrincipalBySid(securityIdentifier, timeoutToken));
+ var lookupUserResult = await _lookupPrincipalBySidAdaptiveTimeout.ExecuteRPCWithTimeout(timeoutToken => server.LookupPrincipalBySid(securityIdentifier, timeoutToken));
if (lookupUserResult.IsFailed)
{
_log.LogTrace("Unable to resolve local sid {SID}: {Error}", sidValue, lookupUserResult.SError);
diff --git a/src/CommonLib/Processors/PortScanner.cs b/src/CommonLib/Processors/PortScanner.cs
index 3a55cfb4a..b294b308e 100644
--- a/src/CommonLib/Processors/PortScanner.cs
+++ b/src/CommonLib/Processors/PortScanner.cs
@@ -9,13 +9,15 @@ namespace SharpHoundCommonLib.Processors {
public class PortScanner : IPortScanner {
private static readonly ConcurrentDictionary PortScanCache = new();
private readonly ILogger _log;
+ private readonly AdaptiveTimeout _adaptiveTimeout;
- public PortScanner() {
- _log = Logging.LogProvider.CreateLogger("PortScanner");
+ public PortScanner() : this(null) {
+
}
- public PortScanner(ILogger log = null) {
+ public PortScanner(ILogger log = null, int maxTimeout = 10000) {
_log = log ?? Logging.LogProvider.CreateLogger("PortScanner");
+ _adaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMilliseconds(maxTimeout), _log);
}
///
@@ -26,7 +28,7 @@ public PortScanner(ILogger log = null) {
/// Timeout in milliseconds
/// Used as a switch to control if this method should throw exceptions that occur.
/// True if port is open, otherwise false
- public virtual async Task CheckPort(string hostname, int port = 445, int timeout = 10000,
+ public virtual async Task CheckPort(string hostname, int port = 445,
bool throwError = false) {
var key = new PingCacheKey {
Port = port,
@@ -40,13 +42,11 @@ public virtual async Task CheckPort(string hostname, int port = 445, int t
try {
using var client = new TcpClient();
- // Blocking External Call
- var ca = await Timeout.ExecuteWithTimeout(TimeSpan.FromMilliseconds(timeout), (_) => client.ConnectAsync(hostname, port));
+ var ca = await _adaptiveTimeout.ExecuteWithTimeout((_) => client.ConnectAsync(hostname, port));
if (!ca.IsSuccess) {
- _log.LogDebug("{HostName} did not respond to scan on port {Port} within {Timeout}ms", hostname, port,
- timeout);
+ _log.LogDebug("{HostName} did not respond to scan on port {Port} within {Timeout}ms", hostname, port, _adaptiveTimeout.GetAdaptiveTimeout());
if (throwError) {
- throw new TimeoutException("Timed Out");
+ throw new TimeoutException(ca.Error);
}
PortScanCache.TryAdd(key, false);
return false;
diff --git a/src/CommonLib/Processors/SmbProcessor.cs b/src/CommonLib/Processors/SmbProcessor.cs
index 43862113b..44c2d7116 100644
--- a/src/CommonLib/Processors/SmbProcessor.cs
+++ b/src/CommonLib/Processors/SmbProcessor.cs
@@ -16,22 +16,17 @@ public class SmbProcessor
public delegate Task ComputerStatusDelegate(CSVComputerStatus status);
private readonly ILogger _log;
private readonly ISmbScanner _smbScanner;
- private readonly int _timeoutMs;
-
- public SmbProcessor(int timeoutMs, ISmbScanner smbScanner = null, ILogger log = null)
- {
- _timeoutMs = timeoutMs;
+ private readonly AdaptiveTimeout _scanHostAdaptiveTimeout;
+
+ public SmbProcessor(int timeoutMs, ISmbScanner smbScanner = null, ILogger log = null) {
_log = log ?? Logging.LogProvider.CreateLogger("SmbProcessor");
- _smbScanner = smbScanner ?? new SmbScanner(_log) { TimeoutMs = _timeoutMs };
+ _smbScanner = smbScanner ?? new SmbScanner(_log) { MaxTimeoutMs = timeoutMs };
+ _scanHostAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMilliseconds(timeoutMs), Logging.LogProvider.CreateLogger(nameof(ISmbScanner.ScanHost)));
}
public event ComputerStatusDelegate ComputerStatusEvent;
- public virtual async Task> Scan(string host, TimeSpan timeout = default) {
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
- var result = await Timeout.ExecuteRPCWithTimeout(timeout, (timeoutToken) => _smbScanner.ScanHost(host, 445, timeoutToken));
+ public virtual async Task> Scan(string host) {
+ var result = await _scanHostAdaptiveTimeout.ExecuteRPCWithTimeout((timeoutToken) => _smbScanner.ScanHost(host, 445, timeoutToken));
if (result.IsFailed) {
await SendComputerStatus(new CSVComputerStatus {
diff --git a/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs b/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs
index e23def5b5..3a3d4f03d 100644
--- a/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs
+++ b/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs
@@ -5,7 +5,6 @@
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;
using SharpHoundCommonLib.OutputTypes;
-using SharpHoundRPC;
using SharpHoundRPC.Shared;
using SharpHoundRPC.Wrappers;
@@ -15,10 +14,16 @@ public class UserRightsAssignmentProcessor {
private readonly ILogger _log;
private readonly ILdapUtils _utils;
+ private readonly AdaptiveTimeout _openLSAPolicyAdaptiveTimeout;
+ private readonly AdaptiveTimeout _getLocalDomainInfoAdaptiveTimeout;
+ private readonly AdaptiveTimeout _getResolvedPrincipalWithPriviledgeAdaptiveTimeout;
public UserRightsAssignmentProcessor(ILdapUtils utils, ILogger log = null) {
_utils = utils;
_log = log ?? Logging.LogProvider.CreateLogger("UserRightsAssignmentProcessor");
+ _openLSAPolicyAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(OpenLSAPolicy)));
+ _getLocalDomainInfoAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ILSAPolicy.GetLocalDomainInformation)));
+ _getResolvedPrincipalWithPriviledgeAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(ILSAPolicy.GetResolvedPrincipalsWithPrivilege)));
}
public event ComputerStatusDelegate ComputerStatusEvent;
@@ -47,13 +52,8 @@ public IAsyncEnumerable GetUserRightsAssignments(
///
///
public async IAsyncEnumerable GetUserRightsAssignments(string computerName,
- string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null,
- TimeSpan timeout = default) {
- if (timeout == default) {
- timeout = TimeSpan.FromMinutes(2);
- }
-
- var policyOpenResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => OpenLSAPolicy(computerName));
+ string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null) {
+ var policyOpenResult = await _openLSAPolicyAdaptiveTimeout.ExecuteRPCWithTimeout((_) => OpenLSAPolicy(computerName));
if (!policyOpenResult.IsSuccess) {
_log.LogDebug("LSAOpenPolicy failed on {ComputerName} with status {Status}", computerName,
policyOpenResult.Error);
@@ -71,7 +71,7 @@ await SendComputerStatus(new CSVComputerStatus {
SecurityIdentifier machineSid;
if (!Cache.GetMachineSid(computerObjectId, out var temp)) {
var getMachineSidResult =
- await Timeout.ExecuteRPCWithTimeout(timeout, (_) => server.GetLocalDomainInformation());
+ await _getLocalDomainInfoAdaptiveTimeout.ExecuteRPCWithTimeout((_) => server.GetLocalDomainInformation());
if (getMachineSidResult.IsFailed) {
_log.LogWarning("Failed to get machine sid for {Server}: {Status}. Abandoning URA collection",
computerName, getMachineSidResult.SError);
@@ -98,7 +98,7 @@ await SendComputerStatus(new CSVComputerStatus {
};
//Ask for all principals with the specified privilege.
- var enumerateAccountsResult = await Timeout.ExecuteRPCWithTimeout(timeout, (_) => server.GetResolvedPrincipalsWithPrivilege(privilege));
+ var enumerateAccountsResult = await _getResolvedPrincipalWithPriviledgeAdaptiveTimeout.ExecuteRPCWithTimeout((_) => server.GetResolvedPrincipalsWithPrivilege(privilege));
if (enumerateAccountsResult.IsFailed) {
_log.LogDebug(
"LSAEnumerateAccountsWithUserRight failed on {ComputerName} with status {Status} for privilege {Privilege}",
diff --git a/src/CommonLib/SMB/SmbScanner.cs b/src/CommonLib/SMB/SmbScanner.cs
index cf88813b0..02b694f84 100644
--- a/src/CommonLib/SMB/SmbScanner.cs
+++ b/src/CommonLib/SMB/SmbScanner.cs
@@ -27,12 +27,14 @@ public class SmbScanner : ISmbScanner {
///
/// Timeout value used when connecting to hosts or waiting for a response.
///
- public int TimeoutMs { get; set; } = 2000;
+ public int MaxTimeoutMs { get; set; } = 2000;
+ public readonly AdaptiveTimeout _adaptiveTimeout;
public ILogger _log;
public SmbScanner(ILogger log) {
- _log = log ?? Logging.LogProvider.CreateLogger("SmbScanner"); ;
+ _log = log ?? Logging.LogProvider.CreateLogger("SmbScanner");
+ _adaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMilliseconds(MaxTimeoutMs), Logging.LogProvider.CreateLogger(nameof(TrySMBNegotiate)));
}
@@ -125,7 +127,7 @@ private SharpHoundRPC.Result CheckRegistrySigningRequired(string ho
negoReqBytes = negotiateRequest.ToBytes();
}
- var negoResp = await Timeout.ExecuteWithTimeout(TimeSpan.FromMilliseconds(TimeoutMs), (timeoutToken) => SendAndReceiveData(host, port, negoReqBytes, timeoutToken));
+ var negoResp = await _adaptiveTimeout.ExecuteWithTimeout((timeoutToken) => SendAndReceiveData(host, port, negoReqBytes, timeoutToken));
if (!negoResp.IsSuccess)
throw new OperationCanceledException("Connection attempt timed out");
diff --git a/src/CommonLib/Timeout.cs b/src/CommonLib/Timeout.cs
index 5373918c4..55da9180f 100644
--- a/src/CommonLib/Timeout.cs
+++ b/src/CommonLib/Timeout.cs
@@ -32,7 +32,7 @@ public static async Task> ExecuteWithTimeout(TimeSpan timeout, Func
if (parentToken.IsCancellationRequested)
return Result.Fail("Cancellation requested");
else
- return Result.Fail("Timeout");
+ return Result.Fail($"Timeout after {timeout.TotalMilliseconds} ms");
}
///
@@ -61,7 +61,7 @@ public static async Task ExecuteWithTimeout(TimeSpan timeout, Action> ExecuteWithTimeout(TimeSpan timeout, Func
if (parentToken.IsCancellationRequested)
return Result.Fail("Cancellation requested");
else
- return Result.Fail("Timeout");
+ return Result.Fail($"Timeout after {timeout.TotalMilliseconds} ms");
}
///
@@ -126,7 +126,7 @@ public static async Task ExecuteWithTimeout(TimeSpan timeout, Func
@@ -137,8 +137,8 @@ public static async Task ExecuteWithTimeout(TimeSpan timeout, Func
///
///
- public static async Task> ExecuteNetAPIWithTimeout(TimeSpan timeout, Func> func) {
- var result = await ExecuteWithTimeout(timeout, func);
+ public static async Task> ExecuteNetAPIWithTimeout(TimeSpan timeout, Func> func, CancellationToken parentToken = default) {
+ var result = await ExecuteWithTimeout(timeout, func, parentToken);
if (result.IsSuccess)
return result.Value;
else
@@ -153,8 +153,8 @@ public static async Task> ExecuteNetAPIWithTimeout(TimeSpan t
///
///
///
- public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func> func) {
- var result = await ExecuteWithTimeout(timeout, func);
+ public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func> func, CancellationToken parentToken = default) {
+ var result = await ExecuteWithTimeout(timeout, func, parentToken);
if (result.IsSuccess)
return result.Value;
else
@@ -169,8 +169,8 @@ public static async Task> ExecuteNetAPIWithTimeout(TimeSpan t
///
///
///
- public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func>> func) {
- var result = await ExecuteWithTimeout(timeout, func);
+ public static async Task> ExecuteRPCWithTimeout(TimeSpan timeout, Func>> func, CancellationToken parentToken = default) {
+ var result = await ExecuteWithTimeout(timeout, func, parentToken);
if (result.IsSuccess)
return result.Value;
else
diff --git a/src/SharpHoundRPC/PortScanner/IPortScanner.cs b/src/SharpHoundRPC/PortScanner/IPortScanner.cs
index 833108078..78624596d 100644
--- a/src/SharpHoundRPC/PortScanner/IPortScanner.cs
+++ b/src/SharpHoundRPC/PortScanner/IPortScanner.cs
@@ -2,6 +2,6 @@
namespace SharpHoundRPC.PortScanner {
public interface IPortScanner {
- Task CheckPort(string hostname, int port = 445, int timeout = 10000, bool throwError = false);
+ Task CheckPort(string hostname, int port = 445, bool throwError = false);
}
}
\ No newline at end of file
diff --git a/test/unit/AdaptiveTimeoutTest.cs b/test/unit/AdaptiveTimeoutTest.cs
new file mode 100644
index 000000000..c86158daa
--- /dev/null
+++ b/test/unit/AdaptiveTimeoutTest.cs
@@ -0,0 +1,69 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using SharpHoundCommonLib;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace CommonLibTest;
+
+public class AdaptiveTimeoutTest {
+ private readonly ITestOutputHelper _testOutputHelper;
+
+ public AdaptiveTimeoutTest(ITestOutputHelper testOutputHelper) {
+ _testOutputHelper = testOutputHelper;
+ }
+
+ [Fact]
+ public async Task AdaptiveTimeout_GetAdaptiveTimeout_NotEnoughSamplesAsync() {
+ var maxTimeout = TimeSpan.FromSeconds(1);
+ var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3);
+
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50));
+
+ var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout();
+ Assert.Equal(maxTimeout, adaptiveTimeoutResult);
+ }
+
+ [Fact]
+ public async Task AdaptiveTimeout_GetAdaptiveTimeout_AdaptiveDisabled() {
+ var maxTimeout = TimeSpan.FromSeconds(1);
+ var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3, false);
+
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50));
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50));
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50));
+
+ var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout();
+ Assert.Equal(maxTimeout, adaptiveTimeoutResult);
+ }
+
+ [Fact]
+ public async Task AdaptiveTimeout_GetAdaptiveTimeout() {
+ var maxTimeout = TimeSpan.FromSeconds(1);
+ var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3);
+
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(40));
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50));
+ await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(60));
+
+ var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout();
+ Assert.True(adaptiveTimeoutResult < maxTimeout);
+ }
+
+ [Fact]
+ public async Task AdaptiveTimeout_GetAdaptiveTimeout_TimeSpikeSafetyValve() {
+ var maxTimeout = TimeSpan.FromSeconds(1);
+ var numSamples = 100;
+ var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), numSamples, 1000, 10);
+
+ for (int i = 0; i < numSamples; i++)
+ await adaptiveTimeout.ExecuteWithTimeout((_) => Thread.Sleep(10));
+
+ for (int i = 0; i < 6; i++)
+ await adaptiveTimeout.ExecuteWithTimeout((_) => Thread.Sleep(200));
+
+ var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout();
+ Assert.Equal(maxTimeout, adaptiveTimeoutResult);
+ }
+}
\ No newline at end of file
diff --git a/test/unit/ComputerAvailabilityTests.cs b/test/unit/ComputerAvailabilityTests.cs
index 53127eb50..cd28f8849 100644
--- a/test/unit/ComputerAvailabilityTests.cs
+++ b/test/unit/ComputerAvailabilityTests.cs
@@ -18,12 +18,12 @@ public ComputerAvailabilityTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
var m = new Mock();
- m.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()))
+ m.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny()))
.Returns(Task.FromResult(false));
_falsePortScanner = m.Object;
var m2 = new Mock();
- m2.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()))
+ m2.Setup(x => x.CheckPort(It.IsAny(), It.IsAny(), It.IsAny()))
.Returns(Task.FromResult(true));
_truePortScanner = m2.Object;
}
diff --git a/test/unit/ComputerSessionProcessorTest.cs b/test/unit/ComputerSessionProcessorTest.cs
index af4d4ec97..9c772ff6b 100644
--- a/test/unit/ComputerSessionProcessorTest.cs
+++ b/test/unit/ComputerSessionProcessorTest.cs
@@ -216,45 +216,45 @@ public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_FilteringW
Assert.Equal(expected, test.Results);
}
- [Fact]
- public async Task ComputerSessionProcessor_TestTimeout() {
- var nativeMethods = new Mock();
- nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(() => {
- Task.Delay(1000).Wait();
- return Array.Empty();
- });
- var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,"");
- var receivedStatus = new List();
- var machineDomainSid = $"{Consts.MockDomainSid}-1000";
- processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; };
- var results = await processor.ReadUserSessions("primary.testlab.local", machineDomainSid, "testlab.local",
- TimeSpan.FromMilliseconds(1));
- Assert.Empty(results.Results);
- Assert.Single(receivedStatus);
- var status = receivedStatus[0];
- Assert.Equal("Timeout", status.Status);
- }
-
- [Fact]
- public async Task ComputerSessionProcessor_TestTimeoutPrivileged() {
- var nativeMethods = new Mock();
- nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Returns(() => {
- Task.Delay(1000).Wait();
- return Array.Empty();
- });
- var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,"");
- var receivedStatus = new List();
- var machineDomainSid = $"{Consts.MockDomainSid}-1000";
- processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; };
-
- var results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid,
- "testlab.local",
- TimeSpan.FromMilliseconds(1));
- Assert.Empty(results.Results);
- Assert.Single(receivedStatus);
- var status = receivedStatus[0];
- Assert.Equal("Timeout", status.Status);
- }
+ // Obsolete by AdaptiveTimeout
+ // [Fact]
+ // public async Task ComputerSessionProcessor_TestTimeout() {
+ // var nativeMethods = new Mock();
+ // nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(() => {
+ // Task.Delay(1000).Wait();
+ // return Array.Empty();
+ // });
+ // var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,"");
+ // var receivedStatus = new List();
+ // var machineDomainSid = $"{Consts.MockDomainSid}-1000";
+ // processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; };
+ // var results = await processor.ReadUserSessions("primary.testlab.local", machineDomainSid, "testlab.local");
+ // Assert.Empty(results.Results);
+ // Assert.Single(receivedStatus);
+ // var status = receivedStatus[0];
+ // Assert.Equal("Timeout", status.Status);
+ // }
+
+ // Obsolete by AdaptiveTimeout
+ // [Fact]
+ // public async Task ComputerSessionProcessor_TestTimeoutPrivileged() {
+ // var nativeMethods = new Mock();
+ // nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Returns(() => {
+ // Task.Delay(1000).Wait();
+ // return Array.Empty();
+ // });
+ // var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,"");
+ // var receivedStatus = new List();
+ // var machineDomainSid = $"{Consts.MockDomainSid}-1000";
+ // processor.ComputerStatusEvent += status => { receivedStatus.Add(status); return Task.CompletedTask; };
+
+ // var results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid,
+ // "testlab.local");
+ // Assert.Empty(results.Results);
+ // Assert.Single(receivedStatus);
+ // var status = receivedStatus[0];
+ // Assert.Equal("Timeout", status.Status);
+ // }
[Fact]
public async Task ComputerSessionProcessor_ReadUserSessionSendsComputerStatus()
diff --git a/test/unit/DCLdapProcessorTest.cs b/test/unit/DCLdapProcessorTest.cs
index 418d5ec95..e8ae0a653 100644
--- a/test/unit/DCLdapProcessorTest.cs
+++ b/test/unit/DCLdapProcessorTest.cs
@@ -30,7 +30,7 @@ public void Dispose() {
[Fact]
public async Task DCLdapProcessor_Scan() {
- var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null);
+ var mockProcessor = new Mock(10, "primary.testlab.local", null);
mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).ReturnsAsync(false);
@@ -39,10 +39,11 @@ public async Task DCLdapProcessor_Scan() {
var processor = mockProcessor.Object;
var receivedStatus = new List();
- processor.ComputerStatusEvent += async status => {
+ processor.ComputerStatusEvent += status => {
receivedStatus.Add(status);
+ return Task.CompletedTask;
};
- var results = await processor.Scan("primary.testlab.local", TimeSpan.FromMinutes(2));
+ var results = await processor.Scan("primary.testlab.local");
Assert.Equal(2, receivedStatus.Count);
var status = receivedStatus[0];
@@ -57,7 +58,7 @@ public async Task DCLdapProcessor_Scan() {
[Fact]
public async Task DCLdapProcessor_Scan_Failed() {
- var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null);
+ var mockProcessor = new Mock(10, "primary.testlab.local", null);
mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Throws(new Exception("Error"));
@@ -66,10 +67,11 @@ public async Task DCLdapProcessor_Scan_Failed() {
var processor = mockProcessor.Object;
var receivedStatus = new List();
- processor.ComputerStatusEvent += async status => {
+ processor.ComputerStatusEvent += status => {
receivedStatus.Add(status);
+ return Task.CompletedTask;
};
- var results = await processor.Scan("primary.testlab.local", TimeSpan.FromMinutes(2));
+ var results = await processor.Scan("primary.testlab.local");
Assert.Equal(2, receivedStatus.Count);
var status = receivedStatus[0];
@@ -84,38 +86,39 @@ public async Task DCLdapProcessor_Scan_Failed() {
Assert.False(results.IsChannelBindingDisabled.Collected);
}
- [Fact]
- public async Task DCLdapProcessor_CheckScan_Timeout() {
- var mockProcessor = new Mock(2, "primary.testlab.local", null);
-
- mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Returns(async () => {
- await Task.Delay(100);
- return false;
- });
-
- mockProcessor.Setup(x => x.TestLdapPort()).ReturnsAsync(true);
- mockProcessor.Setup(x => x.TestLdapsPort()).ReturnsAsync(true);
-
- var processor = mockProcessor.Object;
- var receivedStatus = new List();
- processor.ComputerStatusEvent += status => {
- receivedStatus.Add(status);
- return Task.CompletedTask;
- };
- var results = await processor.Scan("primary.testlab.local", TimeSpan.FromMilliseconds(1));
-
- Assert.Equal(2, receivedStatus.Count);
- var status = receivedStatus[0];
- Assert.Equal("Timeout", status.Status);
- status = receivedStatus[1];
- Assert.Equal("Timeout", status.Status);
- Assert.Equal("Timeout", results.IsSigningRequired.FailureReason);
- Assert.Equal("Timeout", results.IsChannelBindingDisabled.FailureReason);
- }
+ // Obsolete by AdaptiveTimeout
+ // [Fact]
+ // public async Task DCLdapProcessor_CheckScan_Timeout() {
+ // var mockProcessor = new Mock(2, "primary.testlab.local", null);
+
+ // mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Returns(async () => {
+ // await Task.Delay(100);
+ // return false;
+ // });
+
+ // mockProcessor.Setup(x => x.TestLdapPort()).ReturnsAsync(true);
+ // mockProcessor.Setup(x => x.TestLdapsPort()).ReturnsAsync(true);
+
+ // var processor = mockProcessor.Object;
+ // var receivedStatus = new List();
+ // processor.ComputerStatusEvent += status => {
+ // receivedStatus.Add(status);
+ // return Task.CompletedTask;
+ // };
+ // var results = await processor.Scan("primary.testlab.local");
+
+ // Assert.Equal(2, receivedStatus.Count);
+ // var status = receivedStatus[0];
+ // Assert.Equal("Timeout", status.Status);
+ // status = receivedStatus[1];
+ // Assert.Equal("Timeout", status.Status);
+ // Assert.Equal("Timeout", results.IsSigningRequired.FailureReason);
+ // Assert.Equal("Timeout", results.IsChannelBindingDisabled.FailureReason);
+ // }
[Fact]
public async Task DCLdapProcessor_CheckIsNtlmSigningRequired() {
- var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null);
+ var mockProcessor = new Mock(10, "primary.testlab.local", null);
mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).ReturnsAsync(false);
var processor = mockProcessor.Object;
var result = await processor.CheckIsNtlmSigningRequired();
@@ -125,7 +128,7 @@ public async Task DCLdapProcessor_CheckIsNtlmSigningRequired() {
[Fact]
public async Task DCLdapProcessor_CheckIsNtlmSigningRequired_Exception() {
- var mockProcessor = new Mock(It.IsAny(), "primary.testlab.local", null);
+ var mockProcessor = new Mock(10, "primary.testlab.local", null);
mockProcessor.Setup(x => x.Authenticate(It.IsAny(), It.IsAny(), null, null, It.IsAny())).Throws(new Exception("Error"));
var processor = mockProcessor.Object;
var result = await processor.CheckIsNtlmSigningRequired();
@@ -142,7 +145,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidCredentialsException_SEC_E
var mockLogger = new Mock>();
var mockLdapTransport = new Mock(null, It.IsAny());
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException("Error", (int)LdapErrorCodes.InvalidCredentials, SEC_E_UNSUPPORTED_FUNCTION));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLogContains(LogLevel.Debug, expected);
@@ -157,7 +160,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidCredentialsException_SEC_E
var mockLogger = new Mock>();
var mockLdapTransport = new Mock(null, It.IsAny());
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException("Error", (int)LdapErrorCodes.InvalidCredentials, SEC_E_BAD_BINDINGS));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLogContains(LogLevel.Debug, expected);
@@ -172,7 +175,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidCredentialsException_Unhan
var mockLogger = new Mock>();
var mockLdapTransport = new Mock(null, It.IsAny());
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.InvalidCredentials, "80090347"));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLogContains(LogLevel.Error, expected);
@@ -187,7 +190,7 @@ public async Task DCLdapProcessor_Authenticate_StrongAuthRequiredException() {
var mockLogger = new Mock>();
var mockLdapTransport = new Mock(null, It.IsAny());
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.StrongAuthRequired, null));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLog(LogLevel.Debug, expected);
@@ -202,7 +205,7 @@ public async Task DCLdapProcessor_Authenticate_ServerDownException() {
var mockLogger = new Mock>();
var mockLdapTransport = new Mock(null, It.IsAny());
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.ServerDown));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLog(LogLevel.Debug, expected);
@@ -217,7 +220,7 @@ public async Task DCLdapProcessor_Authenticate_LdapUnhandledException() {
var mockLogger = new Mock>();
var mockLdapTransport = new Mock(null, It.IsAny());
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Throws(new LdapNativeException(exception, (int)LdapErrorCodes.LocalError));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), null, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLogContains(LogLevel.Error, expected);
@@ -235,7 +238,7 @@ public async Task DCLdapProcessor_Authenticate_InvalidOperationException() {
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Verifiable();
mockAuthenticator.Setup(x => x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny()))
.Throws(new InvalidOperationException(exception));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), mockAuthenticator.Object, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLog(LogLevel.Debug, expected);
@@ -253,7 +256,7 @@ public async Task DCLdapProcessor_Authenticate_UnhandledException() {
mockLdapTransport.Setup(x => x.InitializeConnectionAsync(It.IsAny())).Verifiable();
mockAuthenticator.Setup(x => x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny()))
.Throws(new Exception(exception));
- var processor = new DCLdapProcessor(It.IsAny(), "primary.testlab.local", mockLogger.Object);
+ var processor = new DCLdapProcessor(10, "primary.testlab.local", mockLogger.Object);
var result = await processor.Authenticate(new Uri(endpoint), It.IsAny(), mockAuthenticator.Object, mockLdapTransport.Object);
Assert.False(result);
mockLogger.VerifyLogContains(LogLevel.Error, expected);
diff --git a/test/unit/HttpNtlmAuthenticationServiceTest.cs b/test/unit/HttpNtlmAuthenticationServiceTest.cs
index e3273de9d..e78d654df 100644
--- a/test/unit/HttpNtlmAuthenticationServiceTest.cs
+++ b/test/unit/HttpNtlmAuthenticationServiceTest.cs
@@ -77,39 +77,41 @@ public void HttpNtlmAuthenticationService_ExtractAuthSchemes_Success() {
Assert.Equal("Negotiate", result[1]);
}
- [Fact]
- public void HttpNtlmAuthenticationService_EnsureRequiresAuth_GetSupportedNtlmAuthSchemesAsync_Timeout() {
- var url = new Uri("http://primary.testlab.local/");
- var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
- var ex = Assert.ThrowsAsync(() =>
- service.EnsureRequiresAuth(url, true, TimeSpan.FromMilliseconds(1)));
- Assert.Equal($"Timeout getting supported NTLM auth schemes for {url}", ex.Result.Message);
+ // Obsolete by AdaptiveTimeout
+ // [Fact]
+ // public void HttpNtlmAuthenticationService_EnsureRequiresAuth_GetSupportedNtlmAuthSchemesAsync_Timeout() {
+ // var url = new Uri("http://primary.testlab.local/");
+ // var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
+ // var ex = Assert.ThrowsAsync(() =>
+ // service.EnsureRequiresAuth(url, true));
+ // Assert.Equal($"Timeout getting supported NTLM auth schemes for {url}", ex.Result.Message);
- }
+ // }
- [Fact]
- public void HttpNtlmAuthenticationService_AuthWithBadChannelBindingsAsync_Timeout() {
- var url = new Uri("http://primary.testlab.local/");
- var authScheme = "NTLM";
- var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
- var httpResponseMessage = new HttpResponseMessage {
- StatusCode = HttpStatusCode.InternalServerError,
- };
- var mockAuthenticator = new Mock(It.IsAny(), null);
- mockAuthenticator.Setup(x =>
- x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny())).Returns(async () => {
- await Task.Delay(1000);
- return httpResponseMessage;
- });
-
- var ex = Assert.ThrowsAsync(async () => await TestPrivateMethod.InstanceMethod(service,
- "AuthWithBadChannelBindingsAsync",
- [
- url, authScheme, TimeSpan.FromMilliseconds(1), mockAuthenticator.Object
- ]));
- Assert.Equal($"Timeout during NTLM authentication for {url} with {authScheme}", ex.Result.Message);
+ // Obsolete by AdaptiveTimeout
+ // [Fact]
+ // public void HttpNtlmAuthenticationService_AuthWithBadChannelBindingsAsync_Timeout() {
+ // var url = new Uri("http://primary.testlab.local/");
+ // var authScheme = "NTLM";
+ // var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
+ // var httpResponseMessage = new HttpResponseMessage {
+ // StatusCode = HttpStatusCode.InternalServerError,
+ // };
+ // var mockAuthenticator = new Mock(It.IsAny(), null);
+ // mockAuthenticator.Setup(x =>
+ // x.PerformNtlmAuthenticationAsync(It.IsAny(), It.IsAny())).Returns(async () => {
+ // await Task.Delay(1000);
+ // return httpResponseMessage;
+ // });
- }
+ // var ex = Assert.ThrowsAsync(async () => await TestPrivateMethod.InstanceMethod(service,
+ // "AuthWithBadChannelBindingsAsync",
+ // [
+ // url, authScheme, TimeSpan.FromMilliseconds(1), mockAuthenticator.Object
+ // ]));
+ // Assert.Equal($"Timeout during NTLM authentication for {url} with {authScheme}", ex.Result.Message);
+
+ // }
//// Throws "no such host is known" exception
// [Fact]
diff --git a/test/unit/LocalGroupProcessorTest.cs b/test/unit/LocalGroupProcessorTest.cs
index 51ded05f9..41e54c12a 100644
--- a/test/unit/LocalGroupProcessorTest.cs
+++ b/test/unit/LocalGroupProcessorTest.cs
@@ -29,7 +29,7 @@ public void Dispose() {
public async Task LocalGroupProcessor_TestWorkstation() {
var mockProcessor = new Mock(new MockLdapUtils(), null);
var mockSamServer = new MockWorkstationSAMServer();
- mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer);
+ mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer);
var processor = mockProcessor.Object;
var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1001";
var results = await processor.GetLocalGroups("win10.testlab.local", machineDomainSid, "TESTLAB.LOCAL", false)
@@ -57,7 +57,7 @@ public async Task LocalGroupProcessor_TestWorkstation() {
public async Task LocalGroupProcessor_TestDomainController() {
var mockProcessor = new Mock(new MockLdapUtils(), null);
var mockSamServer = new MockDCSAMServer();
- mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(mockSamServer);
+ mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(mockSamServer);
var processor = mockProcessor.Object;
var machineDomainSid = $"{Consts.MockWorkstationMachineSid}-1000";
@@ -162,7 +162,7 @@ public async Task LocalGroupProcessor_TestTimeout() {
var mockUtils = new Mock();
var mockProcessor = new Mock(mockUtils.Object, null);
- mockProcessor.Setup(x => x.OpenSamServer(It.IsAny(), It.IsAny())).Returns(() => {
+ mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(() => {
return SharpHoundRPC.Result