Add keyed message subscribers

This commit is contained in:
Loporrit
2025-02-20 12:09:15 +00:00
parent 99c293d89f
commit 7918b54c92
2 changed files with 37 additions and 12 deletions

View File

@@ -16,9 +16,9 @@ public sealed class MareMediator : IHostedService
private readonly ConcurrentQueue<MessageBase> _messageQueue = new();
private readonly PerformanceCollectorService _performanceCollector;
private readonly MareConfigService _mareConfigService;
private readonly ConcurrentDictionary<Type, HashSet<SubscriberAction>> _subscriberDict = [];
private readonly ConcurrentDictionary<(Type, string?), HashSet<SubscriberAction>> _subscriberDict = [];
private bool _processQueue = false;
private readonly ConcurrentDictionary<Type, MethodInfo?> _genericExecuteMethods = new();
private readonly ConcurrentDictionary<(Type, string?), MethodInfo?> _genericExecuteMethods = new();
public MareMediator(ILogger<MareMediator> logger, PerformanceCollectorService performanceCollector, MareConfigService mareConfigService)
{
_logger = logger;
@@ -36,7 +36,10 @@ public sealed class MareMediator : IHostedService
sb.Append("=> ");
foreach (var item in _subscriberDict.Where(item => item.Value.Any(v => v.Subscriber == subscriber)).ToList())
{
sb.Append(item.Key.Name).Append(", ");
sb.Append(item.Key.Item1.Name);
if (item.Key.Item2 != null)
sb.Append($":{item.Key.Item2!}");
sb.Append(", ");
}
if (!string.Equals(sb.ToString(), "=> ", StringComparison.Ordinal))
@@ -99,9 +102,9 @@ public sealed class MareMediator : IHostedService
{
lock (_addRemoveLock)
{
_subscriberDict.TryAdd(typeof(T), []);
_subscriberDict.TryAdd((typeof(T), null), []);
if (!_subscriberDict[typeof(T)].Add(new(subscriber, action)))
if (!_subscriberDict[(typeof(T), null)].Add(new(subscriber, action)))
{
throw new InvalidOperationException("Already subscribed");
}
@@ -110,13 +113,28 @@ public sealed class MareMediator : IHostedService
}
}
public void SubscribeKeyed<T>(IMediatorSubscriber subscriber, string key, Action<T> action) where T : MessageBase
{
lock (_addRemoveLock)
{
_subscriberDict.TryAdd((typeof(T), key), []);
if (!_subscriberDict[(typeof(T), key)].Add(new(subscriber, action)))
{
throw new InvalidOperationException("Already subscribed");
}
_logger.LogDebug("Subscriber added for message {message}:{key}: {sub}", typeof(T).Name, key, subscriber.GetType().Name);
}
}
public void Unsubscribe<T>(IMediatorSubscriber subscriber) where T : MessageBase
{
lock (_addRemoveLock)
{
if (_subscriberDict.ContainsKey(typeof(T)))
if (_subscriberDict.ContainsKey((typeof(T), null)))
{
_subscriberDict[typeof(T)].RemoveWhere(p => p.Subscriber == subscriber);
_subscriberDict[(typeof(T), null)].RemoveWhere(p => p.Subscriber == subscriber);
}
}
}
@@ -125,12 +143,12 @@ public sealed class MareMediator : IHostedService
{
lock (_addRemoveLock)
{
foreach (Type kvp in _subscriberDict.Select(k => k.Key))
foreach (var kvp in _subscriberDict.Select(k => k.Key))
{
int unSubbed = _subscriberDict[kvp]?.RemoveWhere(p => p.Subscriber == subscriber) ?? 0;
if (unSubbed > 0)
{
_logger.LogDebug("{sub} unsubscribed from {msg}", subscriber.GetType().Name, kvp.Name);
_logger.LogDebug("{sub} unsubscribed from {msg}", subscriber.GetType().Name, kvp.Item1.Name);
}
}
}
@@ -138,7 +156,7 @@ public sealed class MareMediator : IHostedService
private void ExecuteMessage(MessageBase message)
{
if (!_subscriberDict.TryGetValue(message.GetType(), out HashSet<SubscriberAction>? subscribers) || subscribers == null || !subscribers.Any()) return;
if (!_subscriberDict.TryGetValue((message.GetType(), message.SubscriberKey), out HashSet<SubscriberAction>? subscribers) || subscribers == null || !subscribers.Any()) return;
List<SubscriberAction> subscribersCopy = [];
lock (_addRemoveLock)
@@ -148,9 +166,9 @@ public sealed class MareMediator : IHostedService
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
var msgType = message.GetType();
if (!_genericExecuteMethods.TryGetValue(msgType, out var methodInfo))
if (!_genericExecuteMethods.TryGetValue((msgType, message.SubscriberKey), out var methodInfo))
{
_genericExecuteMethods[msgType] = methodInfo = GetType()
_genericExecuteMethods[(msgType, message.SubscriberKey)] = methodInfo = GetType()
.GetMethod(nameof(ExecuteReflected), System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?
.MakeGenericMethod(msgType);
}

View File

@@ -4,10 +4,17 @@
public abstract record MessageBase
{
public virtual bool KeepThreadContext => false;
public virtual string? SubscriberKey => null;
}
public record SameThreadMessage : MessageBase
{
public override bool KeepThreadContext => true;
}
public record KeyedMessage(string MessageKey, bool SameThread = false) : MessageBase
{
public override string? SubscriberKey => MessageKey;
public override bool KeepThreadContext => SameThread;
}
#pragma warning restore MA0048