From 74efb5eb6a9b62b353fdd688e50f87af305414d0 Mon Sep 17 00:00:00 2001 From: Stanley Dimant Date: Wed, 3 Aug 2022 23:34:52 +0200 Subject: [PATCH] try onconnected rate limiting --- .../Throttling/SignalRLimitFilter.cs | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs b/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs index ac0f161..59cb7b6 100644 --- a/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs +++ b/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs @@ -1,4 +1,5 @@ using AspNetCoreRateLimit; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Options; using System; @@ -8,19 +9,19 @@ namespace MareSynchronosServer.Throttling; public class SignalRLimitFilter : IHubFilter { private readonly IRateLimitProcessor _processor; + private readonly IHttpContextAccessor accessor; public SignalRLimitFilter( - IOptions options, IProcessingStrategy processing, IRateLimitCounterStore counterStore, - IRateLimitConfiguration rateLimitConfiguration, IIpPolicyStore policyStore) + IOptions options, IProcessingStrategy processing, IIpPolicyStore policyStore, IHttpContextAccessor accessor) { _processor = new IpRateLimitProcessor(options?.Value, policyStore, processing); + this.accessor = accessor; } public async ValueTask InvokeMethodAsync( HubInvocationContext invocationContext, Func> next) { - var httpContext = invocationContext.Context.GetHttpContext(); - var ip = httpContext.Connection.RemoteIpAddress.ToString(); + var ip = accessor.GetIpAddress(); var client = new ClientRequestIdentity { ClientIp = ip, @@ -44,9 +45,27 @@ public class SignalRLimitFilter : IHubFilter } // Optional method - public Task OnConnectedAsync(HubLifetimeContext context, Func next) + public async Task OnConnectedAsync(HubLifetimeContext context, Func next) { - return next(context); + var ip = accessor.GetIpAddress(); + var client = new ClientRequestIdentity + { + ClientIp = ip, + Path = "Connect", + HttpVerb = "ws", + }; + foreach (var rule in await _processor.GetMatchingRulesAsync(client)) + { + var counter = await _processor.ProcessRequestAsync(client, rule); + Console.WriteLine("time: {0}, count: {1}", counter.Timestamp, counter.Count); + if (counter.Count > rule.Limit) + { + var retry = counter.Timestamp.RetryAfterFrom(rule); + throw new HubException($"rate limit {retry}"); + } + } + + await next(context); } // Optional method