adjust authentication handler

This commit is contained in:
rootdarkarchon
2022-11-05 14:51:59 +01:00
parent efec37db34
commit 2a1b04214b
2 changed files with 62 additions and 114 deletions

View File

@@ -4,23 +4,11 @@ using System.Threading.Tasks;
namespace MareSynchronosServices.Authentication; namespace MareSynchronosServices.Authentication;
internal class FailedAuthorization : IDisposable internal class FailedAuthorization
{ {
private int failedAttempts = 1; private int failedAttempts = 1;
public int FailedAttempts => failedAttempts; public int FailedAttempts => failedAttempts;
public Task ResetTask { get; set; } public Task ResetTask { get; set; }
public CancellationTokenSource? ResetCts { get; set; }
public void Dispose()
{
try
{
ResetCts?.Cancel();
ResetCts?.Dispose();
}
catch { }
}
public void IncreaseFailedAttempts() public void IncreaseFailedAttempts()
{ {
Interlocked.Increment(ref failedAttempts); Interlocked.Increment(ref failedAttempts);

View File

@@ -1,9 +1,9 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MareSynchronosShared.Data; using MareSynchronosShared.Data;
using MareSynchronosShared.Metrics; using MareSynchronosShared.Metrics;
@@ -16,164 +16,124 @@ namespace MareSynchronosServices.Authentication;
public class SecretKeyAuthenticationHandler public class SecretKeyAuthenticationHandler
{ {
private readonly ILogger<SecretKeyAuthenticationHandler> logger; private readonly ILogger<SecretKeyAuthenticationHandler> _logger;
private readonly MareMetrics metrics; private readonly MareMetrics _metrics;
private const string Unauthorized = "Unauthorized"; private const string Unauthorized = "Unauthorized";
private readonly Dictionary<string, string> authorizations = new(); private readonly ConcurrentDictionary<string, string> _cachedAuthorizations = new();
private readonly Dictionary<string, FailedAuthorization?> failedAuthorizations = new(); private readonly ConcurrentDictionary<string, FailedAuthorization?> _failedAuthorizations = new();
private readonly object authDictLock = new();
private readonly object failedAuthLock = new();
private readonly int _failedAttemptsForTempBan; private readonly int _failedAttemptsForTempBan;
private readonly int _tempBanMinutes; private readonly int _tempBanMinutes;
private List<string> _whitelistedIps = new(); private readonly List<string> _whitelistedIps = new();
public void ClearUnauthorizedUsers() public void ClearUnauthorizedUsers()
{ {
lock (authDictLock) foreach (var item in _cachedAuthorizations.ToArray())
{ {
foreach (var item in authorizations.ToArray()) if (item.Value == Unauthorized)
{ {
if (item.Value == Unauthorized) _cachedAuthorizations[item.Key] = string.Empty;
{
authorizations[item.Key] = string.Empty;
}
} }
} }
} }
public void RemoveAuthentication(string uid) public void RemoveAuthentication(string uid)
{ {
lock (authDictLock) var authorization = _cachedAuthorizations.Where(u => u.Value == uid);
if (authorization.Any())
{ {
var authorization = authorizations.Where(u => u.Value == uid); _cachedAuthorizations.Remove(authorization.First().Key, out _);
if (authorization.Any())
{
authorizations.Remove(authorization.First().Key);
}
} }
} }
public async Task<AuthReply> AuthenticateAsync(MareDbContext mareDbContext, string ip, string secretKey) public async Task<AuthReply> AuthenticateAsync(MareDbContext mareDbContext, string ip, string secretKey)
{ {
metrics.IncCounter(MetricsAPI.CounterAuthenticationRequests); _metrics.IncCounter(MetricsAPI.CounterAuthenticationRequests);
if (string.IsNullOrEmpty(secretKey)) if (string.IsNullOrEmpty(secretKey))
{ {
metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures); _metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures);
return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } }; return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } };
} }
lock (failedAuthLock) if (_failedAuthorizations.TryGetValue(ip, out var existingFailedAuthorization) && existingFailedAuthorization.FailedAttempts > _failedAttemptsForTempBan)
{ {
if (failedAuthorizations.TryGetValue(ip, out var existingFailedAuthorization) && existingFailedAuthorization.FailedAttempts > _failedAttemptsForTempBan) _metrics.IncCounter(MetricsAPI.CounterAuthenticationCacheHits);
{ _metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures);
metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures);
if (existingFailedAuthorization.ResetTask == null)
{
_logger.LogWarning("TempBan {ip} for authorization spam", ip);
existingFailedAuthorization.ResetCts?.Cancel();
existingFailedAuthorization.ResetCts?.Dispose();
existingFailedAuthorization.ResetCts = new CancellationTokenSource();
var token = existingFailedAuthorization.ResetCts.Token;
existingFailedAuthorization.ResetTask = Task.Run(async () => existingFailedAuthorization.ResetTask = Task.Run(async () =>
{ {
await Task.Delay(TimeSpan.FromMinutes(_tempBanMinutes), token).ConfigureAwait(false); await Task.Delay(TimeSpan.FromMinutes(_tempBanMinutes)).ConfigureAwait(false);
if (token.IsCancellationRequested) return;
FailedAuthorization? failedAuthorization;
lock (failedAuthLock)
{
failedAuthorizations.Remove(ip, out failedAuthorization);
}
failedAuthorization?.Dispose();
}, token);
logger.LogWarning("TempBan {ip} for authorization spam", ip); }).ContinueWith((t) =>
return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } }; {
_failedAuthorizations.Remove(ip, out _);
});
} }
return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } };
} }
using var sha256 = SHA256.Create(); using var sha256 = SHA256.Create();
var hashedHeader = BitConverter.ToString(sha256.ComputeHash(Encoding.UTF8.GetBytes(secretKey))).Replace("-", ""); var hashedHeader = BitConverter.ToString(sha256.ComputeHash(Encoding.UTF8.GetBytes(secretKey))).Replace("-", "");
string uid; bool fromCache = _cachedAuthorizations.TryGetValue(hashedHeader, out string uid);
lock (authDictLock)
if (fromCache)
{ {
if (authorizations.TryGetValue(hashedHeader, out uid)) _metrics.IncCounter(MetricsAPI.CounterAuthenticationCacheHits);
if (uid == Unauthorized)
{ {
if (uid == Unauthorized) return AuthenticationFailure(ip);
{
metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures);
logger.LogWarning("Failed authorization from {ip}", ip);
lock (failedAuthLock)
{
if (!_whitelistedIps.Any(w => ip.Contains(w)))
{
if (failedAuthorizations.TryGetValue(ip, out var auth))
{
auth.IncreaseFailedAttempts();
}
else
{
failedAuthorizations[ip] = new FailedAuthorization();
}
}
}
return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } };
}
metrics.IncCounter(MetricsAPI.CounterAuthenticationCacheHits);
} }
} }
else
if (string.IsNullOrEmpty(uid))
{ {
uid = (await mareDbContext.Auth.AsNoTracking() uid = (await mareDbContext.Auth.AsNoTracking()
.FirstOrDefaultAsync(m => m.HashedKey == hashedHeader).ConfigureAwait(false))?.UserUID; .FirstOrDefaultAsync(m => m.HashedKey == hashedHeader).ConfigureAwait(false))?.UserUID;
if (uid == null) if (uid == null)
{ {
lock (authDictLock) _cachedAuthorizations[hashedHeader] = Unauthorized;
{
authorizations[hashedHeader] = Unauthorized;
}
logger.LogWarning("Failed authorization from {ip}", ip); return AuthenticationFailure(ip);
lock (failedAuthLock)
{
if (!_whitelistedIps.Any(w => ip.Contains(w)))
{
if (failedAuthorizations.TryGetValue(ip, out var auth))
{
auth.IncreaseFailedAttempts();
}
else
{
failedAuthorizations[ip] = new FailedAuthorization();
}
}
}
metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures);
return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } };
} }
lock (authDictLock) _cachedAuthorizations[hashedHeader] = uid;
{
authorizations[hashedHeader] = uid;
}
} }
metrics.IncCounter(MetricsAPI.CounterAuthenticationSuccesses); _metrics.IncCounter(MetricsAPI.CounterAuthenticationSuccesses);
return new AuthReply() { Success = true, Uid = new UidMessage() { Uid = uid } }; return new AuthReply() { Success = true, Uid = new UidMessage() { Uid = uid } };
} }
private AuthReply AuthenticationFailure(string ip)
{
_metrics.IncCounter(MetricsAPI.CounterAuthenticationFailures);
_logger.LogWarning("Failed authorization from {ip}", ip);
if (!_whitelistedIps.Any(w => ip.Contains(w)))
{
if (_failedAuthorizations.TryGetValue(ip, out var auth))
{
auth.IncreaseFailedAttempts();
}
else
{
_failedAuthorizations[ip] = new FailedAuthorization();
}
}
return new AuthReply() { Success = false, Uid = new UidMessage() { Uid = string.Empty } };
}
public SecretKeyAuthenticationHandler(IConfiguration configuration, ILogger<SecretKeyAuthenticationHandler> logger, MareMetrics metrics) public SecretKeyAuthenticationHandler(IConfiguration configuration, ILogger<SecretKeyAuthenticationHandler> logger, MareMetrics metrics)
{ {
this.logger = logger; this._logger = logger;
this.metrics = metrics; this._metrics = metrics;
var config = configuration.GetRequiredSection("MareSynchronos"); var config = configuration.GetRequiredSection("MareSynchronos");
_failedAttemptsForTempBan = config.GetValue<int>("FailedAuthForTempBan", 5); _failedAttemptsForTempBan = config.GetValue<int>("FailedAuthForTempBan", 5);
logger.LogInformation("FailedAuthForTempBan: {num}", _failedAttemptsForTempBan); logger.LogInformation("FailedAuthForTempBan: {num}", _failedAttemptsForTempBan);