diff --git a/MareSynchronosServer/MareSynchronosServer/Startup.cs b/MareSynchronosServer/MareSynchronosServer/Startup.cs index efdbe3b..05dcbf0 100644 --- a/MareSynchronosServer/MareSynchronosServer/Startup.cs +++ b/MareSynchronosServer/MareSynchronosServer/Startup.cs @@ -18,6 +18,7 @@ using Microsoft.Extensions.FileProviders; using Microsoft.AspNetCore.Authorization; using MareSynchronosServer.Discord; using AspNetCoreRateLimit; +using MareSynchronosServer.Throttling; namespace MareSynchronosServer { @@ -37,14 +38,6 @@ namespace MareSynchronosServer { services.AddHttpContextAccessor(); - services.AddSignalR(hubOptions => - { - hubOptions.MaximumReceiveMessageSize = long.MaxValue; - hubOptions.EnableDetailedErrors = true; - hubOptions.MaximumParallelInvocationsPerClient = 10; - hubOptions.StreamBufferCapacity = 200; - }); - services.AddMemoryCache(); services.Configure(Configuration.GetSection("IpRateLimiting")); @@ -77,6 +70,15 @@ namespace MareSynchronosServer services.AddAuthorization(options => options.FallbackPolicy = new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build()); services.AddSingleton(); + + services.AddSignalR(hubOptions => + { + hubOptions.MaximumReceiveMessageSize = long.MaxValue; + hubOptions.EnableDetailedErrors = true; + hubOptions.MaximumParallelInvocationsPerClient = 10; + hubOptions.StreamBufferCapacity = 200; + hubOptions.AddFilter(); + }); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. diff --git a/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs b/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs new file mode 100644 index 0000000..ac0f161 --- /dev/null +++ b/MareSynchronosServer/MareSynchronosServer/Throttling/SignalRLimitFilter.cs @@ -0,0 +1,58 @@ +using AspNetCoreRateLimit; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Options; +using System; +using System.Threading.Tasks; + +namespace MareSynchronosServer.Throttling; +public class SignalRLimitFilter : IHubFilter +{ + private readonly IRateLimitProcessor _processor; + + public SignalRLimitFilter( + IOptions options, IProcessingStrategy processing, IRateLimitCounterStore counterStore, + IRateLimitConfiguration rateLimitConfiguration, IIpPolicyStore policyStore) + { + _processor = new IpRateLimitProcessor(options?.Value, policyStore, processing); + } + + public async ValueTask InvokeMethodAsync( + HubInvocationContext invocationContext, Func> next) + { + var httpContext = invocationContext.Context.GetHttpContext(); + var ip = httpContext.Connection.RemoteIpAddress.ToString(); + var client = new ClientRequestIdentity + { + ClientIp = ip, + Path = invocationContext.HubMethodName, + HttpVerb = "ws", + ClientId = invocationContext.Context.UserIdentifier + }; + foreach (var rule in await _processor.GetMatchingRulesAsync(client)) + { + var counter = await _processor.ProcessRequestAsync(client, rule); + Console.WriteLine("time: {0}, count: {1}", counter.Timestamp, counter.Count); + if (counter.Count > rule.Limit) + { + var retry = counter.Timestamp.RetryAfterFrom(rule); + throw new HubException($"call limit {retry}"); + } + } + + Console.WriteLine($"Calling hub method '{invocationContext.HubMethodName}'"); + return await next(invocationContext); + } + + // Optional method + public Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + return next(context); + } + + // Optional method + public Task OnDisconnectedAsync( + HubLifetimeContext context, Exception exception, Func next) + { + return next(context, exception); + } +}