From 7918b54c92c5ca629d470662fb019f2f6841493b Mon Sep 17 00:00:00 2001 From: Loporrit <141286461+loporrit@users.noreply.github.com> Date: Thu, 20 Feb 2025 12:09:15 +0000 Subject: [PATCH] Add keyed message subscribers --- .../Services/Mediator/MareMediator.cs | 42 +++++++++++++------ .../Services/Mediator/MessageBase.cs | 7 ++++ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/MareSynchronos/Services/Mediator/MareMediator.cs b/MareSynchronos/Services/Mediator/MareMediator.cs index eecdb4c..1e97d6f 100644 --- a/MareSynchronos/Services/Mediator/MareMediator.cs +++ b/MareSynchronos/Services/Mediator/MareMediator.cs @@ -16,9 +16,9 @@ public sealed class MareMediator : IHostedService private readonly ConcurrentQueue _messageQueue = new(); private readonly PerformanceCollectorService _performanceCollector; private readonly MareConfigService _mareConfigService; - private readonly ConcurrentDictionary> _subscriberDict = []; + private readonly ConcurrentDictionary<(Type, string?), HashSet> _subscriberDict = []; private bool _processQueue = false; - private readonly ConcurrentDictionary _genericExecuteMethods = new(); + private readonly ConcurrentDictionary<(Type, string?), MethodInfo?> _genericExecuteMethods = new(); public MareMediator(ILogger logger, PerformanceCollectorService performanceCollector, MareConfigService mareConfigService) { _logger = logger; @@ -36,7 +36,10 @@ public sealed class MareMediator : IHostedService sb.Append("=> "); foreach (var item in _subscriberDict.Where(item => item.Value.Any(v => v.Subscriber == subscriber)).ToList()) { - sb.Append(item.Key.Name).Append(", "); + sb.Append(item.Key.Item1.Name); + if (item.Key.Item2 != null) + sb.Append($":{item.Key.Item2!}"); + sb.Append(", "); } if (!string.Equals(sb.ToString(), "=> ", StringComparison.Ordinal)) @@ -99,9 +102,9 @@ public sealed class MareMediator : IHostedService { lock (_addRemoveLock) { - _subscriberDict.TryAdd(typeof(T), []); + _subscriberDict.TryAdd((typeof(T), null), []); - if (!_subscriberDict[typeof(T)].Add(new(subscriber, action))) + if (!_subscriberDict[(typeof(T), null)].Add(new(subscriber, action))) { throw new InvalidOperationException("Already subscribed"); } @@ -110,13 +113,28 @@ public sealed class MareMediator : IHostedService } } + public void SubscribeKeyed(IMediatorSubscriber subscriber, string key, Action action) where T : MessageBase + { + lock (_addRemoveLock) + { + _subscriberDict.TryAdd((typeof(T), key), []); + + if (!_subscriberDict[(typeof(T), key)].Add(new(subscriber, action))) + { + throw new InvalidOperationException("Already subscribed"); + } + + _logger.LogDebug("Subscriber added for message {message}:{key}: {sub}", typeof(T).Name, key, subscriber.GetType().Name); + } + } + public void Unsubscribe(IMediatorSubscriber subscriber) where T : MessageBase { lock (_addRemoveLock) { - if (_subscriberDict.ContainsKey(typeof(T))) + if (_subscriberDict.ContainsKey((typeof(T), null))) { - _subscriberDict[typeof(T)].RemoveWhere(p => p.Subscriber == subscriber); + _subscriberDict[(typeof(T), null)].RemoveWhere(p => p.Subscriber == subscriber); } } } @@ -125,12 +143,12 @@ public sealed class MareMediator : IHostedService { lock (_addRemoveLock) { - foreach (Type kvp in _subscriberDict.Select(k => k.Key)) + foreach (var kvp in _subscriberDict.Select(k => k.Key)) { int unSubbed = _subscriberDict[kvp]?.RemoveWhere(p => p.Subscriber == subscriber) ?? 0; if (unSubbed > 0) { - _logger.LogDebug("{sub} unsubscribed from {msg}", subscriber.GetType().Name, kvp.Name); + _logger.LogDebug("{sub} unsubscribed from {msg}", subscriber.GetType().Name, kvp.Item1.Name); } } } @@ -138,7 +156,7 @@ public sealed class MareMediator : IHostedService private void ExecuteMessage(MessageBase message) { - if (!_subscriberDict.TryGetValue(message.GetType(), out HashSet? subscribers) || subscribers == null || !subscribers.Any()) return; + if (!_subscriberDict.TryGetValue((message.GetType(), message.SubscriberKey), out HashSet? subscribers) || subscribers == null || !subscribers.Any()) return; List subscribersCopy = []; lock (_addRemoveLock) @@ -148,9 +166,9 @@ public sealed class MareMediator : IHostedService #pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields var msgType = message.GetType(); - if (!_genericExecuteMethods.TryGetValue(msgType, out var methodInfo)) + if (!_genericExecuteMethods.TryGetValue((msgType, message.SubscriberKey), out var methodInfo)) { - _genericExecuteMethods[msgType] = methodInfo = GetType() + _genericExecuteMethods[(msgType, message.SubscriberKey)] = methodInfo = GetType() .GetMethod(nameof(ExecuteReflected), System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)? .MakeGenericMethod(msgType); } diff --git a/MareSynchronos/Services/Mediator/MessageBase.cs b/MareSynchronos/Services/Mediator/MessageBase.cs index e29bf8d..40d9de2 100644 --- a/MareSynchronos/Services/Mediator/MessageBase.cs +++ b/MareSynchronos/Services/Mediator/MessageBase.cs @@ -4,10 +4,17 @@ public abstract record MessageBase { public virtual bool KeepThreadContext => false; + public virtual string? SubscriberKey => null; } public record SameThreadMessage : MessageBase { public override bool KeepThreadContext => true; } + +public record KeyedMessage(string MessageKey, bool SameThread = false) : MessageBase +{ + public override string? SubscriberKey => MessageKey; + public override bool KeepThreadContext => SameThread; +} #pragma warning restore MA0048 \ No newline at end of file