From 98ee6c7f669ea29618577219847716389dfd9d0c Mon Sep 17 00:00:00 2001 From: Stanley Dimant Date: Sat, 20 Aug 2022 14:22:50 +0200 Subject: [PATCH] attempt to use lock in Authentication handler to lock down dictionary access --- .../SecretKeyAuthenticationHandler.cs | 126 +++++++++++------- .../Metrics/LockedProxyCounter.cs | 22 --- .../Metrics/LockedProxyGauge.cs | 47 ------- .../Metrics/MareMetrics.cs | 58 ++++---- 4 files changed, 111 insertions(+), 142 deletions(-) delete mode 100644 MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyCounter.cs delete mode 100644 MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyGauge.cs diff --git a/MareSynchronosServer/MareSynchronosServer/Authentication/SecretKeyAuthenticationHandler.cs b/MareSynchronosServer/MareSynchronosServer/Authentication/SecretKeyAuthenticationHandler.cs index d5a046d..55371fb 100644 --- a/MareSynchronosServer/MareSynchronosServer/Authentication/SecretKeyAuthenticationHandler.cs +++ b/MareSynchronosServer/MareSynchronosServer/Authentication/SecretKeyAuthenticationHandler.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Security.Claims; @@ -9,6 +8,7 @@ using System.Text.Encodings.Web; using System.Threading; using System.Threading.Tasks; using MareSynchronosServer.Data; +using MareSynchronosServer.Metrics; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Http; using Microsoft.EntityFrameworkCore; @@ -23,7 +23,7 @@ namespace MareSynchronosServer.Authentication private int failedAttempts = 1; public int FailedAttempts => failedAttempts; public Task ResetTask { get; set; } - public CancellationTokenSource ResetCts { get; set; } = new(); + public CancellationTokenSource? ResetCts { get; set; } public void Dispose() { @@ -48,61 +48,81 @@ namespace MareSynchronosServer.Authentication private readonly IConfiguration _configuration; public const string AuthScheme = "SecretKeyAuth"; private const string unauthorized = "Unauthorized"; - public static ConcurrentDictionary Authentications = new(); - private static ConcurrentDictionary FailedAuthorizations = new(); - private static SemaphoreSlim dbLockSemaphore = new SemaphoreSlim(20); - private int failedAttemptsForTempBan; - private int tempBanMinutes; + public static readonly Dictionary Authentications = new(); + private static readonly Dictionary FailedAuthorizations = new(); + private static readonly object authDictLock = new(); + private static readonly object failedAuthLock = new(); + private readonly int failedAttemptsForTempBan; + private readonly int tempBanMinutes; public static void ClearUnauthorizedUsers() { - foreach (var item in Authentications.ToArray()) + lock (authDictLock) { - if (item.Value == unauthorized) + foreach (var item in Authentications.ToArray()) { - Authentications[item.Key] = string.Empty; + if (item.Value == unauthorized) + { + Authentications[item.Key] = string.Empty; + } } } } public static void RemoveAuthentication(string uid) { - var auth = Authentications.Where(u => u.Value == uid); - if (auth.Any()) + lock (authDictLock) { - Authentications.Remove(auth.First().Key, out _); + var auth = Authentications.Where(u => u.Value == uid); + if (auth.Any()) + { + Authentications.Remove(auth.First().Key); + } } } protected override async Task HandleAuthenticateAsync() { + MareMetrics.AuthenticationRequests.Inc(); + if (!Request.Headers.ContainsKey("Authorization")) { + MareMetrics.AuthenticationFailures.Inc(); return AuthenticateResult.Fail("Failed Authorization"); } var authHeader = Request.Headers["Authorization"].ToString(); if (string.IsNullOrEmpty(authHeader)) + { + MareMetrics.AuthenticationFailures.Inc(); return AuthenticateResult.Fail("Failed Authorization"); + } var ip = _accessor.GetIpAddress(); - if (FailedAuthorizations.TryGetValue(ip, out var failedAuth)) + lock (failedAuthLock) { - if (failedAuth.FailedAttempts > failedAttemptsForTempBan) + if (FailedAuthorizations.TryGetValue(ip, out var failedAuth) && failedAuth.FailedAttempts > failedAttemptsForTempBan) { - failedAuth.ResetCts.Cancel(); - failedAuth.ResetCts.Dispose(); + MareMetrics.AuthenticationFailures.Inc(); + + failedAuth.ResetCts?.Cancel(); + failedAuth.ResetCts?.Dispose(); failedAuth.ResetCts = new CancellationTokenSource(); var token = failedAuth.ResetCts.Token; failedAuth.ResetTask = Task.Run(async () => { await Task.Delay(TimeSpan.FromMinutes(tempBanMinutes), token); if (token.IsCancellationRequested) return; - FailedAuthorizations.Remove(ip, out var fauth); + FailedAuthorization fauth; + lock (failedAuthLock) + { + FailedAuthorizations.Remove(ip, out fauth); + } fauth.Dispose(); }, token); + Logger.LogWarning("TempBan " + ip + " for authorization spam"); return AuthenticateResult.Fail("Failed Authorization"); } @@ -111,49 +131,61 @@ namespace MareSynchronosServer.Authentication using var sha256 = SHA256.Create(); var hashedHeader = BitConverter.ToString(sha256.ComputeHash(Encoding.UTF8.GetBytes(authHeader))).Replace("-", ""); - if (Authentications.TryGetValue(hashedHeader, out string uid)) + string uid; + lock (authDictLock) { - if (uid == unauthorized) + if (Authentications.TryGetValue(hashedHeader, out uid)) { - Logger.LogWarning("Failed authorization from " + ip); - if (FailedAuthorizations.TryGetValue(ip, out var auth)) + if (uid == unauthorized) { - auth.IncreaseFailedAttempts(); + MareMetrics.AuthenticationFailures.Inc(); + + lock (failedAuthLock) + { + Logger.LogWarning("Failed authorization from " + ip); + if (FailedAuthorizations.TryGetValue(ip, out var auth)) + { + auth.IncreaseFailedAttempts(); + } + else + { + FailedAuthorizations[ip] = new FailedAuthorization(); + } + } + + return AuthenticateResult.Fail("Failed Authorization"); } - else - { - FailedAuthorizations[ip] = new FailedAuthorization(); - } - return AuthenticateResult.Fail("Failed Authorization"); + + MareMetrics.AuthenticationCacheHits.Inc(); } } if (string.IsNullOrEmpty(uid)) { - try - { - await dbLockSemaphore.WaitAsync(); - uid = (await _mareDbContext.Auth.Include("User").AsNoTracking() - .FirstOrDefaultAsync(m => m.HashedKey == hashedHeader))?.UserUID; - } - catch { } - finally - { - dbLockSemaphore.Release(); - } + uid = (await _mareDbContext.Auth.AsNoTracking() + .FirstOrDefaultAsync(m => m.HashedKey == hashedHeader))?.UserUID; if (uid == null) { - Authentications[hashedHeader] = unauthorized; + lock (authDictLock) + { + Authentications[hashedHeader] = unauthorized; + } + Logger.LogWarning("Failed authorization from " + ip); - if (FailedAuthorizations.TryGetValue(ip, out var auth)) + lock (failedAuthLock) { - auth.IncreaseFailedAttempts(); - } - else - { - FailedAuthorizations[ip] = new FailedAuthorization(); + if (FailedAuthorizations.TryGetValue(ip, out var auth)) + { + auth.IncreaseFailedAttempts(); + } + else + { + FailedAuthorizations[ip] = new FailedAuthorization(); + } } + + MareMetrics.AuthenticationFailures.Inc(); return AuthenticateResult.Fail("Failed Authorization"); } else @@ -170,6 +202,8 @@ namespace MareSynchronosServer.Authentication var principal = new ClaimsPrincipal(identity); var ticket = new AuthenticationTicket(principal, Scheme.Name); + MareMetrics.AuthenticationSuccesses.Inc(); + return AuthenticateResult.Success(ticket); } diff --git a/MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyCounter.cs b/MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyCounter.cs deleted file mode 100644 index 24dcbd7..0000000 --- a/MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyCounter.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Prometheus; - -namespace MareSynchronosServer.Metrics -{ - public class LockedProxyCounter - { - private readonly Counter _c; - - public LockedProxyCounter(Counter c) - { - _c = c; - } - - public void Inc(double inc = 1d) - { - //lock (_c) - //{ - _c.Inc(inc); - //} - } - } -} diff --git a/MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyGauge.cs b/MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyGauge.cs deleted file mode 100644 index b62ff38..0000000 --- a/MareSynchronosServer/MareSynchronosServer/Metrics/LockedProxyGauge.cs +++ /dev/null @@ -1,47 +0,0 @@ -using Prometheus; - -namespace MareSynchronosServer.Metrics; - -public class LockedProxyGauge -{ - private readonly Gauge _g; - - public LockedProxyGauge(Gauge g) - { - _g = g; - } - - public void Inc(double inc = 1d) - { - //lock (_g) - //{ - _g.Inc(inc); - //} - } - - public void IncTo(double incTo) - { - //lock (_g) - //{ - _g.IncTo(incTo); - //} - } - - public void Dec(double decBy = 1d) - { - //lock (_g) - //{ - _g.Dec(decBy); - //} - } - - public void Set(double setTo) - { - //lock (_g) - //{ - _g.Set(setTo); - //} - } - - public double Value => _g.Value; -} \ No newline at end of file diff --git a/MareSynchronosServer/MareSynchronosServer/Metrics/MareMetrics.cs b/MareSynchronosServer/MareSynchronosServer/Metrics/MareMetrics.cs index 89c801b..a78798d 100644 --- a/MareSynchronosServer/MareSynchronosServer/Metrics/MareMetrics.cs +++ b/MareSynchronosServer/MareSynchronosServer/Metrics/MareMetrics.cs @@ -8,39 +8,43 @@ namespace MareSynchronosServer.Metrics { public class MareMetrics { - public static readonly LockedProxyCounter InitializedConnections = - new(Prometheus.Metrics.CreateCounter("mare_initialized_connections", "Initialized Connections")); - public static readonly LockedProxyGauge Connections = - new(Prometheus.Metrics.CreateGauge("mare_unauthorized_connections", "Unauthorized Connections")); - public static readonly LockedProxyGauge AuthorizedConnections = - new(Prometheus.Metrics.CreateGauge("mare_authorized_connections", "Authorized Connections")); - public static readonly LockedProxyGauge AvailableWorkerThreads = new(Prometheus.Metrics.CreateGauge("mare_available_threadpool", "Available Threadpool Workers")); - public static readonly LockedProxyGauge AvailableIOWorkerThreads = new(Prometheus.Metrics.CreateGauge("mare_available_threadpool_io", "Available Threadpool IO Workers")); + public static readonly Counter InitializedConnections = + Prometheus.Metrics.CreateCounter("mare_initialized_connections", "Initialized Connections"); + public static readonly Gauge Connections = + Prometheus.Metrics.CreateGauge("mare_unauthorized_connections", "Unauthorized Connections"); + public static readonly Gauge AuthorizedConnections = + Prometheus.Metrics.CreateGauge("mare_authorized_connections", "Authorized Connections"); + public static readonly Gauge AvailableWorkerThreads = Prometheus.Metrics.CreateGauge("mare_available_threadpool", "Available Threadpool Workers"); + public static readonly Gauge AvailableIOWorkerThreads = Prometheus.Metrics.CreateGauge("mare_available_threadpool_io", "Available Threadpool IO Workers"); - public static readonly LockedProxyGauge UsersRegistered = new(Prometheus.Metrics.CreateGauge("mare_users_registered", "Total Registrations")); + public static readonly Gauge UsersRegistered = Prometheus.Metrics.CreateGauge("mare_users_registered", "Total Registrations"); - public static readonly LockedProxyGauge Pairs = new(Prometheus.Metrics.CreateGauge("mare_pairs", "Total Pairs")); - public static readonly LockedProxyGauge PairsPaused = new(Prometheus.Metrics.CreateGauge("mare_pairs_paused", "Total Paused Pairs")); + public static readonly Gauge Pairs = Prometheus.Metrics.CreateGauge("mare_pairs", "Total Pairs"); + public static readonly Gauge PairsPaused = Prometheus.Metrics.CreateGauge("mare_pairs_paused", "Total Paused Pairs"); - public static readonly LockedProxyGauge FilesTotal = new(Prometheus.Metrics.CreateGauge("mare_files", "Total uploaded files")); - public static readonly LockedProxyGauge FilesTotalSize = - new(Prometheus.Metrics.CreateGauge("mare_files_size", "Total uploaded files (bytes)")); + public static readonly Gauge FilesTotal = Prometheus.Metrics.CreateGauge("mare_files", "Total uploaded files"); + public static readonly Gauge FilesTotalSize = + Prometheus.Metrics.CreateGauge("mare_files_size", "Total uploaded files (bytes)"); - public static readonly LockedProxyCounter UserPushData = new(Prometheus.Metrics.CreateCounter("mare_user_push", "Users pushing data")); - public static readonly LockedProxyCounter UserPushDataTo = - new(Prometheus.Metrics.CreateCounter("mare_user_push_to", "Users Receiving Data")); + public static readonly Counter UserPushData = Prometheus.Metrics.CreateCounter("mare_user_push", "Users pushing data"); + public static readonly Counter UserPushDataTo = + Prometheus.Metrics.CreateCounter("mare_user_push_to", "Users Receiving Data"); - public static readonly LockedProxyCounter UserDownloadedFiles = - new(Prometheus.Metrics.CreateCounter("mare_user_downloaded_files", "Total Downloaded Files by Users")); - public static readonly LockedProxyCounter UserDownloadedFilesSize = - new(Prometheus.Metrics.CreateCounter("mare_user_downloaded_files_size", "Total Downloaded Files Size by Users")); + public static readonly Counter UserDownloadedFiles = + Prometheus.Metrics.CreateCounter("mare_user_downloaded_files", "Total Downloaded Files by Users"); + public static readonly Counter UserDownloadedFilesSize = + Prometheus.Metrics.CreateCounter("mare_user_downloaded_files_size", "Total Downloaded Files Size by Users"); - public static readonly LockedProxyGauge - CPUUsage = new(Prometheus.Metrics.CreateGauge("mare_cpu_usage", "Total calculated CPU usage in %")); - public static readonly LockedProxyGauge RAMUsage = - new(Prometheus.Metrics.CreateGauge("mare_ram_usage", "Total calculated RAM usage in bytes for Mare + MSSQL")); - public static readonly LockedProxyGauge NetworkOut = new(Prometheus.Metrics.CreateGauge("mare_network_out", "Network out in byte/s")); - public static readonly LockedProxyGauge NetworkIn = new(Prometheus.Metrics.CreateGauge("mare_network_in", "Network in in byte/s")); + public static readonly Gauge + CPUUsage = Prometheus.Metrics.CreateGauge("mare_cpu_usage", "Total calculated CPU usage in %"); + public static readonly Gauge RAMUsage = + Prometheus.Metrics.CreateGauge("mare_ram_usage", "Total calculated RAM usage in bytes for Mare + MSSQL"); + public static readonly Gauge NetworkOut = Prometheus.Metrics.CreateGauge("mare_network_out", "Network out in byte/s"); + public static readonly Gauge NetworkIn = Prometheus.Metrics.CreateGauge("mare_network_in", "Network in in byte/s"); + public static readonly Counter AuthenticationRequests = Prometheus.Metrics.CreateCounter("mare_auth_requests", "Mare Authentication Requests"); + public static readonly Counter AuthenticationCacheHits = Prometheus.Metrics.CreateCounter("mare_auth_requests_cachehit", "Mare Authentication Requests Cache Hits"); + public static readonly Counter AuthenticationFailures = Prometheus.Metrics.CreateCounter("mare_auth_requests_fail", "Mare Authentication Requests Failed"); + public static readonly Counter AuthenticationSuccesses = Prometheus.Metrics.CreateCounter("mare_auth_requests_success", "Mare Authentication Requests Success"); public static void InitializeMetrics(MareDbContext dbContext, IConfiguration configuration) {