From d871002d0155e6c5ba68ba24270b61812157c34c Mon Sep 17 00:00:00 2001 From: Loporrit <141286461+loporrit@users.noreply.github.com> Date: Fri, 1 Sep 2023 19:10:27 +0000 Subject: [PATCH] Account registration endpoint --- .../AccountRegistrationService.cs | 153 ++++++++++++++++++ .../Controllers/JwtController.cs | 12 ++ .../MareSynchronosServer/Startup.cs | 2 + .../Metrics/MetricsAPI.cs | 1 + .../Utils/MareConfigurationAuthBase.cs | 6 + 5 files changed, 174 insertions(+) create mode 100644 MareSynchronosServer/MareSynchronosServer/Authentication/AccountRegistrationService.cs diff --git a/MareSynchronosServer/MareSynchronosServer/Authentication/AccountRegistrationService.cs b/MareSynchronosServer/MareSynchronosServer/Authentication/AccountRegistrationService.cs new file mode 100644 index 0000000..fd1576c --- /dev/null +++ b/MareSynchronosServer/MareSynchronosServer/Authentication/AccountRegistrationService.cs @@ -0,0 +1,153 @@ +using System.Collections.Concurrent; +using MareSynchronos.API.Dto.Account; +using MareSynchronosShared.Data; +using MareSynchronosShared.Metrics; +using MareSynchronosShared.Services; +using MareSynchronosShared.Utils; +using Microsoft.EntityFrameworkCore; +using System.Text.RegularExpressions; +using MareSynchronosShared.Models; + +namespace MareSynchronosServer.Authentication; + +internal record IpRegistrationCount +{ + private int count = 1; + public int Count => count; + public Task ResetTask { get; set; } + public CancellationTokenSource ResetTaskCts { get; set; } + public void IncreaseCount() + { + Interlocked.Increment(ref count); + } +} + +public class AccountRegistrationService +{ + private readonly MareMetrics _metrics; + private readonly MareDbContext _mareDbContext; + private readonly IServiceScopeFactory _serviceScopeFactory; + private readonly IConfigurationService _configurationService; + private readonly ILogger _logger; + private readonly ConcurrentDictionary _registrationsPerIp = new(StringComparer.Ordinal); + + private Regex _registrationUserAgentRegex = new Regex(@"^MareSynchronos/", RegexOptions.Compiled); + + public AccountRegistrationService(MareMetrics metrics, MareDbContext mareDbContext, + IServiceScopeFactory serviceScopeFactory, IConfigurationService configuration, + ILogger logger) + { + _mareDbContext = mareDbContext; + _logger = logger; + _configurationService = configuration; + _metrics = metrics; + _serviceScopeFactory = serviceScopeFactory; + } + + public async Task RegisterAccountAsync(string ua, string ip) + { + var reply = new RegisterReplyDto(); + + if (!_registrationUserAgentRegex.Match(ua).Success) + { + reply.ErrorMessage = "User-Agent not allowed"; + return reply; + } + + if (_registrationsPerIp.TryGetValue(ip, out var registrationCount) + && registrationCount.Count >= _configurationService.GetValueOrDefault(nameof(MareConfigurationAuthBase.RegisterIpLimit), 3)) + { + _logger.LogWarning("Rejecting {ip} for registration spam", ip); + + if (registrationCount.ResetTask == null) + { + registrationCount.ResetTaskCts = new CancellationTokenSource(); + + if (registrationCount.ResetTaskCts != null) + registrationCount.ResetTaskCts.Cancel(); + + registrationCount.ResetTask = Task.Run(async () => + { + await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(MareConfigurationAuthBase.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false); + + }).ContinueWith((t) => + { + _registrationsPerIp.Remove(ip, out _); + }, registrationCount.ResetTaskCts.Token); + } + reply.ErrorMessage = "Too many registrations from this IP. Please try again later."; + return reply; + } + + var user = new User(); + + var hasValidUid = false; + while (!hasValidUid) + { + var uid = StringUtils.GenerateRandomString(7); + if (_mareDbContext.Users.Any(u => u.UID == uid || u.Alias == uid)) continue; + user.UID = uid; + hasValidUid = true; + } + + // make the first registered user on the service to admin + if (!await _mareDbContext.Users.AnyAsync().ConfigureAwait(false)) + { + user.IsAdmin = true; + } + + user.LastLoggedIn = DateTime.UtcNow; + + var computedHash = StringUtils.Sha256String(StringUtils.GenerateRandomString(64) + DateTime.UtcNow.ToString()); + var auth = new Auth() + { + HashedKey = StringUtils.Sha256String(computedHash), + User = user, + }; + + await _mareDbContext.Users.AddAsync(user).ConfigureAwait(false); + await _mareDbContext.Auth.AddAsync(auth).ConfigureAwait(false); + await _mareDbContext.SaveChangesAsync().ConfigureAwait(false); + + _logger.LogInformation("User registered: {userUID} from IP {ip}", user.UID, ip); + _metrics.IncCounter(MetricsAPI.CounterAuthenticationRequests); + + reply.Success = true; + reply.UID = user.UID; + reply.SecretKey = computedHash; + + RecordIpRegistration(ip); + + return reply; + } + + private void RecordIpRegistration(string ip) + { + var whitelisted = _configurationService.GetValueOrDefault(nameof(MareConfigurationAuthBase.WhitelistedIps), new List()); + if (!whitelisted.Any(w => ip.Contains(w, StringComparison.OrdinalIgnoreCase))) + { + if (_registrationsPerIp.TryGetValue(ip, out var count)) + { + count.IncreaseCount(); + } + else + { + count = _registrationsPerIp[ip] = new IpRegistrationCount(); + + if (count.ResetTaskCts != null) + count.ResetTaskCts.Cancel(); + + count.ResetTaskCts = new CancellationTokenSource(); + + count.ResetTask = Task.Run(async () => + { + await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(MareConfigurationAuthBase.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false); + + }).ContinueWith((t) => + { + _registrationsPerIp.Remove(ip, out _); + }, count.ResetTaskCts.Token); + } + } + } +} diff --git a/MareSynchronosServer/MareSynchronosServer/Controllers/JwtController.cs b/MareSynchronosServer/MareSynchronosServer/Controllers/JwtController.cs index 2810617..242d547 100644 --- a/MareSynchronosServer/MareSynchronosServer/Controllers/JwtController.cs +++ b/MareSynchronosServer/MareSynchronosServer/Controllers/JwtController.cs @@ -24,10 +24,12 @@ public class JwtController : Controller private readonly IRedisDatabase _redis; private readonly MareDbContext _mareDbContext; private readonly SecretKeyAuthenticatorService _secretKeyAuthenticatorService; + private readonly AccountRegistrationService _accountRegistrationService; private readonly IConfigurationService _configuration; public JwtController(IHttpContextAccessor accessor, MareDbContext mareDbContext, SecretKeyAuthenticatorService secretKeyAuthenticatorService, + AccountRegistrationService accountRegistrationService, IConfigurationService configuration, IRedisDatabase redisDb) { @@ -35,6 +37,7 @@ public class JwtController : Controller _redis = redisDb; _mareDbContext = mareDbContext; _secretKeyAuthenticatorService = secretKeyAuthenticatorService; + _accountRegistrationService = accountRegistrationService; _configuration = configuration; } @@ -114,6 +117,15 @@ public class JwtController : Controller return Content(token.RawData); } + [AllowAnonymous] + [HttpPost(MareAuth.Auth_Register)] + public async Task Register() + { + var ua = HttpContext.Request.Headers["User-Agent"][0] ?? "-"; + var ip = _accessor.GetIpAddress(); + return Json(await _accountRegistrationService.RegisterAccountAsync(ua, ip)); + } + private JwtSecurityToken CreateToken(IEnumerable authClaims) { var authSigningKey = new SymmetricSecurityKey(Encoding.ASCII.GetBytes(_configuration.GetValue(nameof(MareConfigurationAuthBase.Jwt)))); diff --git a/MareSynchronosServer/MareSynchronosServer/Startup.cs b/MareSynchronosServer/MareSynchronosServer/Startup.cs index 509357c..b45da78 100644 --- a/MareSynchronosServer/MareSynchronosServer/Startup.cs +++ b/MareSynchronosServer/MareSynchronosServer/Startup.cs @@ -184,6 +184,7 @@ public class Startup private static void ConfigureAuthorization(IServiceCollection services) { services.AddSingleton(); + services.AddSingleton(); services.AddTransient(); services.AddOptions(JwtBearerDefaults.AuthenticationScheme) @@ -257,6 +258,7 @@ public class Startup MetricsAPI.CounterAuthenticationFailures, MetricsAPI.CounterAuthenticationRequests, MetricsAPI.CounterAuthenticationSuccesses, + MetricsAPI.CounterAccountsCreated, }, new List { MetricsAPI.GaugeAuthorizedConnections, diff --git a/MareSynchronosServer/MareSynchronosShared/Metrics/MetricsAPI.cs b/MareSynchronosServer/MareSynchronosShared/Metrics/MetricsAPI.cs index 7c41bbc..3ca9592 100644 --- a/MareSynchronosServer/MareSynchronosShared/Metrics/MetricsAPI.cs +++ b/MareSynchronosServer/MareSynchronosShared/Metrics/MetricsAPI.cs @@ -33,4 +33,5 @@ public class MetricsAPI public const string GaugeDownloadQueue = "mare_download_queue"; public const string CounterFileRequests = "mare_files_requests"; public const string CounterFileRequestSize = "mare_files_request_size"; + public const string CounterAccountsCreated = "mare_accounts_created"; } \ No newline at end of file diff --git a/MareSynchronosServer/MareSynchronosShared/Utils/MareConfigurationAuthBase.cs b/MareSynchronosServer/MareSynchronosShared/Utils/MareConfigurationAuthBase.cs index 966f956..b9683d0 100644 --- a/MareSynchronosServer/MareSynchronosShared/Utils/MareConfigurationAuthBase.cs +++ b/MareSynchronosServer/MareSynchronosShared/Utils/MareConfigurationAuthBase.cs @@ -9,6 +9,10 @@ public class MareConfigurationAuthBase : MareConfigurationBase [RemoteConfiguration] public int TempBanDurationInMinutes { get; set; } = 5; [RemoteConfiguration] + public int RegisterIpLimit { get; set; } = 3; + [RemoteConfiguration] + public int RegisterIpDurationInMinutes { get; set; } = 10; + [RemoteConfiguration] public List WhitelistedIps { get; set; } = new(); public override string ToString() @@ -17,6 +21,8 @@ public class MareConfigurationAuthBase : MareConfigurationBase sb.AppendLine(base.ToString()); sb.AppendLine($"{nameof(FailedAuthForTempBan)} => {FailedAuthForTempBan}"); sb.AppendLine($"{nameof(TempBanDurationInMinutes)} => {TempBanDurationInMinutes}"); + sb.AppendLine($"{nameof(RegisterIpLimit)} => {RegisterIpLimit}"); + sb.AppendLine($"{nameof(RegisterIpDurationInMinutes)} => {RegisterIpDurationInMinutes}"); sb.AppendLine($"{nameof(Jwt)} => {Jwt}"); sb.AppendLine($"{nameof(WhitelistedIps)} => {string.Join(", ", WhitelistedIps)}"); return sb.ToString();