Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ namespace MQTTnet.Server.Internal
{
public interface ISubscriptionChangedNotification
{
void OnSubscriptionsAdded(MqttSession clientSession, List<string> subscriptionsTopics);
void OnSubscriptionsAdded(MqttSession clientSession, List<MqttSubscription> subscriptionsTopics);

void OnSubscriptionsRemoved(MqttSession clientSession, List<string> subscriptionTopics);
}
Expand Down
91 changes: 74 additions & 17 deletions Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public sealed class MqttClientSessionsManager : ISubscriptionChangedNotification
// The _sessions dictionary contains all session, the _subscriberSessions hash set contains subscriber sessions only.
// See the MqttSubscription object for a detailed explanation.
readonly MqttSessionsStorage _sessionsStorage = new();
readonly HashSet<MqttSession> _subscriberSessions = [];
readonly HashSet<MqttSession> _subscriberSessionsWithWildcards = [];
readonly Dictionary<string, HashSet<MqttSession>> _simpleTopicToSessions = [];

public MqttClientSessionsManager(MqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, MqttServerEventContainer eventContainer, IMqttNetLogger logger)
{
Expand Down Expand Up @@ -77,7 +78,7 @@ public async Task DeleteSessionAsync(string clientId)
{
if (_sessionsStorage.TryRemoveSession(clientId, out session))
{
_subscriberSessions.Remove(session);
CleanupClientSessionUnsafe(session);
}
}
finally
Expand Down Expand Up @@ -161,11 +162,30 @@ public async Task<DispatchApplicationMessageResult> DispatchApplicationMessage(
await _retainedMessagesManager.UpdateMessage(senderId, applicationMessage).ConfigureAwait(false);
}

List<MqttSession> subscriberSessions;
HashSet<MqttSession> subscriberSessions;
_sessionsManagementLock.EnterReadLock();
try
{
subscriberSessions = _subscriberSessions.ToList();
if (_simpleTopicToSessions.TryGetValue(applicationMessage.Topic, out var matchedSimpleTopicSessions))
{
// Create the initial subscriberSessions from whichever set is larger to take advantage
// of the internal ConstructFrom other HashSet optimizations
if (matchedSimpleTopicSessions.Count > _subscriberSessionsWithWildcards.Count)
{
subscriberSessions = new HashSet<MqttSession>(matchedSimpleTopicSessions);
subscriberSessions.UnionWith(_subscriberSessionsWithWildcards);
}
else
{
subscriberSessions = new HashSet<MqttSession>(_subscriberSessionsWithWildcards);
subscriberSessions.UnionWith(matchedSimpleTopicSessions);
}
}
else
{
// Always include the sessions with wildcards. They need to be properly matched against the topic filter.
subscriberSessions = new HashSet<MqttSession>(_subscriberSessionsWithWildcards);
}
}
finally
{
Expand Down Expand Up @@ -446,20 +466,32 @@ public async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter
}
}

public void OnSubscriptionsAdded(MqttSession clientSession, List<string> topics)
public void OnSubscriptionsAdded(MqttSession clientSession, List<MqttSubscription> subscriptions)
{
_sessionsManagementLock.EnterWriteLock();
try
{
if (!clientSession.HasSubscribedTopics)
foreach (var subscription in subscriptions)
{
// first subscribed topic
_subscriberSessions.Add(clientSession);
}

foreach (var topic in topics)
{
clientSession.AddSubscribedTopic(topic);
if (subscription.TopicHasWildcard)
{
if (!clientSession.HasSubscribedWildcardTopics)
{
_subscriberSessionsWithWildcards.Add(clientSession);
}
}
else
{
if (_simpleTopicToSessions.TryGetValue(subscription.Topic, out var simpleTopicSessions))
{
simpleTopicSessions.Add(clientSession);
}
else
{
_simpleTopicToSessions[subscription.Topic] = [clientSession];
}
}
clientSession.AddSubscribedTopic(subscription.Topic, subscription.TopicHasWildcard);
}
}
finally
Expand All @@ -475,13 +507,21 @@ public void OnSubscriptionsRemoved(MqttSession clientSession, List<string> subsc
{
foreach (var subscriptionTopic in subscriptionTopics)
{
if (_simpleTopicToSessions.TryGetValue(subscriptionTopic, out var simpleTopicSessions))
{
simpleTopicSessions.Remove(clientSession);
if (simpleTopicSessions.Count == 0)
{
_simpleTopicToSessions.Remove(subscriptionTopic);
}
}
clientSession.RemoveSubscribedTopic(subscriptionTopic);
}

if (!clientSession.HasSubscribedTopics)
if (!clientSession.HasSubscribedWildcardTopics)
{
// last subscription removed
_subscriberSessions.Remove(clientSession);
// Last wildcard subscription removed
_subscriberSessionsWithWildcards.Remove(clientSession);
}
}
finally
Expand Down Expand Up @@ -564,7 +604,7 @@ async Task<MqttConnectedClient> CreateClientConnection(
if (connectPacket.CleanSession)
{
_logger.Verbose("Deleting existing session of client '{0}' due to clean start", connectPacket.ClientId);
_subscriberSessions.Remove(oldSession);
CleanupClientSessionUnsafe(oldSession);
session = CreateSession(connectPacket, validatingConnectionEventArgs);
}
else
Expand Down Expand Up @@ -669,6 +709,23 @@ MqttSession GetClientSession(string clientId)
}
}

//* Must be called with the _sessionsManagementLock held.
void CleanupClientSessionUnsafe(MqttSession session)
{
_subscriberSessionsWithWildcards.Remove(session);
foreach (var simpleTopic in session.SubscribedSimpleTopics)
{
if (_simpleTopicToSessions.TryGetValue(simpleTopic, out var simpleTopicSessions))
{
simpleTopicSessions.Remove(session);
if (simpleTopicSessions.Count == 0)
{
_simpleTopicToSessions.Remove(simpleTopic);
}
}
}
}

async Task<MqttConnectPacket> ReceiveConnectPacket(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public async Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket
var retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false);
var result = new SubscribeResult(subscribePacket.TopicFilters.Count);

var addedSubscriptions = new List<string>();
var addedSubscriptions = new List<MqttSubscription>();
var finalTopicFilters = new List<MqttTopicFilter>();

// The topic filters are order by its QoS so that the higher QoS will win over a
Expand Down Expand Up @@ -195,7 +195,7 @@ public async Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket

var createSubscriptionResult = CreateSubscription(topicFilter, subscribePacket.SubscriptionIdentifier, interceptorEventArgs.Response.ReasonCode);

addedSubscriptions.Add(topicFilter.Topic);
addedSubscriptions.Add(createSubscriptionResult.Subscription);
finalTopicFilters.Add(topicFilter);

FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result);
Expand Down
23 changes: 14 additions & 9 deletions Source/MQTTnet.Server/Internal/MqttSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ public sealed class MqttSession : IDisposable
// Do not use a dictionary in order to keep the ordering of the messages.
readonly List<MqttPublishPacket> _unacknowledgedPublishPackets = new();

// Bookkeeping to know if this is a subscribing client; lazy initialize later.
HashSet<string> _subscribedTopics;
readonly HashSet<string> _subscribedSimpleTopics = [];
readonly HashSet<string> _subscribedWildcardTopics = [];

public MqttSession(
MqttConnectPacket connectPacket,
Expand All @@ -50,7 +50,9 @@ public MqttSession(

public uint ExpiryInterval => _connectPacket.SessionExpiryInterval;

public bool HasSubscribedTopics => _subscribedTopics != null && _subscribedTopics.Count > 0;
public bool HasSubscribedWildcardTopics => _subscribedWildcardTopics.Count > 0;

public HashSet<string> SubscribedSimpleTopics => _subscribedSimpleTopics;

public string Id => _connectPacket.ClientId;

Expand Down Expand Up @@ -79,14 +81,16 @@ public MqttPublishPacket AcknowledgePublishPacket(ushort packetIdentifier)
return publishPacket;
}

public void AddSubscribedTopic(string topic)
public void AddSubscribedTopic(string topic, bool isWildcardTopic)
{
if (_subscribedTopics == null)
if (isWildcardTopic)
{
_subscribedTopics = new HashSet<string>();
_subscribedWildcardTopics.Add(topic);
}
else
{
_subscribedSimpleTopics.Add(topic);
}

_subscribedTopics.Add(topic);
}

public Task DeleteAsync()
Expand Down Expand Up @@ -208,7 +212,8 @@ public void Recover()

public void RemoveSubscribedTopic(string topic)
{
_subscribedTopics?.Remove(topic);
_subscribedSimpleTopics.Remove(topic);
_subscribedWildcardTopics.Remove(topic);
}

public Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
Expand Down
1 change: 0 additions & 1 deletion Source/MQTTnet.Tests/TopicFilterComparer_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Server;
using MQTTnet.Server.Internal;

namespace MQTTnet.Tests
Expand Down