try onconnected rate limiting

This commit is contained in:
Stanley Dimant
2022-08-03 23:34:52 +02:00
parent 831029a244
commit 74efb5eb6a

View File

@@ -1,4 +1,5 @@
using AspNetCoreRateLimit; using AspNetCoreRateLimit;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using System; using System;
@@ -8,19 +9,19 @@ namespace MareSynchronosServer.Throttling;
public class SignalRLimitFilter : IHubFilter public class SignalRLimitFilter : IHubFilter
{ {
private readonly IRateLimitProcessor _processor; private readonly IRateLimitProcessor _processor;
private readonly IHttpContextAccessor accessor;
public SignalRLimitFilter( public SignalRLimitFilter(
IOptions<IpRateLimitOptions> options, IProcessingStrategy processing, IRateLimitCounterStore counterStore, IOptions<IpRateLimitOptions> options, IProcessingStrategy processing, IIpPolicyStore policyStore, IHttpContextAccessor accessor)
IRateLimitConfiguration rateLimitConfiguration, IIpPolicyStore policyStore)
{ {
_processor = new IpRateLimitProcessor(options?.Value, policyStore, processing); _processor = new IpRateLimitProcessor(options?.Value, policyStore, processing);
this.accessor = accessor;
} }
public async ValueTask<object> InvokeMethodAsync( public async ValueTask<object> InvokeMethodAsync(
HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next) HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{ {
var httpContext = invocationContext.Context.GetHttpContext(); var ip = accessor.GetIpAddress();
var ip = httpContext.Connection.RemoteIpAddress.ToString();
var client = new ClientRequestIdentity var client = new ClientRequestIdentity
{ {
ClientIp = ip, ClientIp = ip,
@@ -44,9 +45,27 @@ public class SignalRLimitFilter : IHubFilter
} }
// Optional method // Optional method
public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next) public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{ {
return next(context); var ip = accessor.GetIpAddress();
var client = new ClientRequestIdentity
{
ClientIp = ip,
Path = "Connect",
HttpVerb = "ws",
};
foreach (var rule in await _processor.GetMatchingRulesAsync(client))
{
var counter = await _processor.ProcessRequestAsync(client, rule);
Console.WriteLine("time: {0}, count: {1}", counter.Timestamp, counter.Count);
if (counter.Count > rule.Limit)
{
var retry = counter.Timestamp.RetryAfterFrom(rule);
throw new HubException($"rate limit {retry}");
}
}
await next(context);
} }
// Optional method // Optional method