diff --git a/MareSynchronosServer/MareSynchronosServer/Startup.cs b/MareSynchronosServer/MareSynchronosServer/Startup.cs index 3c16ff2..153976c 100644 --- a/MareSynchronosServer/MareSynchronosServer/Startup.cs +++ b/MareSynchronosServer/MareSynchronosServer/Startup.cs @@ -68,6 +68,7 @@ namespace MareSynchronosServer .AddScheme(SecretKeyAuthenticationHandler.AuthScheme, options => { }); services.AddAuthorization(options => options.FallbackPolicy = new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build()); + services.AddSingleton(); services.AddSignalR(hubOptions => { @@ -75,9 +76,8 @@ namespace MareSynchronosServer hubOptions.EnableDetailedErrors = true; hubOptions.MaximumParallelInvocationsPerClient = 10; hubOptions.StreamBufferCapacity = 200; + hubOptions.AddFilter(); }); - - services.AddSingleton(); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. diff --git a/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs b/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs new file mode 100644 index 0000000..ac0f161 --- /dev/null +++ b/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs @@ -0,0 +1,58 @@ +using AspNetCoreRateLimit; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Options; +using System; +using System.Threading.Tasks; + +namespace MareSynchronosServer.Throttling; +public class SignalRLimitFilter : IHubFilter +{ + private readonly IRateLimitProcessor _processor; + + public SignalRLimitFilter( + IOptions options, IProcessingStrategy processing, IRateLimitCounterStore counterStore, + IRateLimitConfiguration rateLimitConfiguration, IIpPolicyStore policyStore) + { + _processor = new IpRateLimitProcessor(options?.Value, policyStore, processing); + } + + public async ValueTask InvokeMethodAsync( + HubInvocationContext invocationContext, Func> next) + { + var httpContext = invocationContext.Context.GetHttpContext(); + var ip = httpContext.Connection.RemoteIpAddress.ToString(); + var client = new ClientRequestIdentity + { + ClientIp = ip, + Path = invocationContext.HubMethodName, + HttpVerb = "ws", + ClientId = invocationContext.Context.UserIdentifier + }; + foreach (var rule in await _processor.GetMatchingRulesAsync(client)) + { + var counter = await _processor.ProcessRequestAsync(client, rule); + Console.WriteLine("time: {0}, count: {1}", counter.Timestamp, counter.Count); + if (counter.Count > rule.Limit) + { + var retry = counter.Timestamp.RetryAfterFrom(rule); + throw new HubException($"call limit {retry}"); + } + } + + Console.WriteLine($"Calling hub method '{invocationContext.HubMethodName}'"); + return await next(invocationContext); + } + + // Optional method + public Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + return next(context); + } + + // Optional method + public Task OnDisconnectedAsync( + HubLifetimeContext context, Exception exception, Func next) + { + return next(context, exception); + } +}