Skip to content
70 changes: 70 additions & 0 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool)
public int ActiveCommandId { get; private set; }
public int CancellationTimeout { get; private set; }
public int ConnectionId { get; set; }
public string? ServerHostname { get; set; }
public byte[]? AuthPluginData { get; set; }
public long CreatedTimestamp { get; }
public ConnectionPool? Pool { get; }
Expand Down Expand Up @@ -122,6 +123,24 @@ public void DoCancel(ICancellableCommand commandToCancel, MySqlCommand killComma
return;
}

// Verify server identity before executing KILL QUERY to prevent cancelling on the wrong server
var killSession = killCommand.Connection!.Session;
if (!string.IsNullOrEmpty(ServerHostname) && !string.IsNullOrEmpty(killSession.ServerHostname))
{
if (!string.Equals(ServerHostname, killSession.ServerHostname, StringComparison.Ordinal))
{
Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerHostname, killSession.ServerHostname);
return;
}
}
else if (!string.IsNullOrEmpty(ServerHostname) || !string.IsNullOrEmpty(killSession.ServerHostname))
{
// One session has hostname, the other doesn't - this is a potential mismatch
Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerHostname, killSession.ServerHostname);
return;
}
// If both sessions have no hostname, allow the operation for backward compatibility

// NOTE: This command is executed while holding the lock to prevent race conditions during asynchronous cancellation.
// For example, if the lock weren't held, the current command could finish and the other thread could set ActiveCommandId
// to zero, then start executing a new command. By the time this "KILL QUERY" command reached the server, the wrong
Expand Down Expand Up @@ -640,6 +659,9 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
ConnectionId = newConnectionId;
}

// Get server hostname for KILL QUERY verification
await GetServerHostnameAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout;
return redirectionUrl;
}
Expand Down Expand Up @@ -1963,6 +1985,52 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation
}
}

private async Task GetServerHostnameAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
Log.GettingServerHostname(m_logger, Id);
try
{
var payload = SupportsQueryAttributes ? s_selectHostnameWithAttributesPayload : s_selectHostnameNoAttributesPayload;
await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);

// column count: 1
_ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);

// @@hostname column
_ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

if (!SupportsDeprecateEof)
{
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
_ = EofPayload.Create(payload.Span);
}

// first (and only) row
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

var reader = new ByteArrayReader(payload.Span);
var length = reader.ReadLengthEncodedIntegerOrNull();
var hostname = length > 0 ? Encoding.UTF8.GetString(reader.ReadByteString(length)) : null;

ServerHostname = hostname;

Log.RetrievedServerHostname(m_logger, Id, hostname);

// OK/EOF payload
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
if (OkPayload.IsOk(payload.Span, this))
OkPayload.Verify(payload.Span, this);
else
EofPayload.Create(payload.Span);
}
catch (MySqlException ex)
{
Log.FailedToGetServerHostname(m_logger, ex, Id);
// Set fallback value to ensure operation can continue
ServerHostname = null;
}
}

private void ShutdownSocket()
{
Log.ClosingStreamSocket(m_logger, Id);
Expand Down Expand Up @@ -2194,6 +2262,8 @@ protected override void OnStatementBegin(int index)
private static readonly PayloadData s_sleepWithAttributesPayload = QueryPayload.Create(true, "SELECT SLEEP(0) INTO @__MySqlConnector__Sleep;"u8);
private static readonly PayloadData s_selectConnectionIdVersionNoAttributesPayload = QueryPayload.Create(false, "SELECT CONNECTION_ID(), VERSION();"u8);
private static readonly PayloadData s_selectConnectionIdVersionWithAttributesPayload = QueryPayload.Create(true, "SELECT CONNECTION_ID(), VERSION();"u8);
private static readonly PayloadData s_selectHostnameNoAttributesPayload = QueryPayload.Create(false, "SELECT @@hostname;"u8);
private static readonly PayloadData s_selectHostnameWithAttributesPayload = QueryPayload.Create(true, "SELECT @@hostname;"u8);

private readonly ILogger m_logger;
#if NET9_0_OR_GREATER
Expand Down
4 changes: 4 additions & 0 deletions src/MySqlConnector/Logging/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ internal static class EventIds
public const int CertificateErrorUnixSocket = 2158;
public const int CertificateErrorNoPassword = 2159;
public const int CertificateErrorValidThumbprint = 2160;
public const int GettingServerHostname = 2161;
public const int RetrievedServerHostname = 2162;
public const int FailedToGetServerHostname = 2163;

// Command execution events, 2200-2299
public const int CannotExecuteNewCommandInState = 2200;
Expand All @@ -108,6 +111,7 @@ internal static class EventIds
public const int IgnoringCancellationForInactiveCommand = 2306;
public const int CancelingCommand = 2307;
public const int SendingSleepToClearPendingCancellation = 2308;
public const int IgnoringCancellationForDifferentServer = 2309;

// Cached procedure events, 2400-2499
public const int GettingCachedProcedure = 2400;
Expand Down
12 changes: 12 additions & 0 deletions src/MySqlConnector/Logging/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ internal static partial class Log
[LoggerMessage(EventIds.FailedToGetConnectionId, LogLevel.Information, "Session {SessionId} failed to get CONNECTION_ID(), VERSION()")]
public static partial void FailedToGetConnectionId(ILogger logger, Exception exception, string sessionId);

[LoggerMessage(EventIds.GettingServerHostname, LogLevel.Debug, "Session {SessionId} getting server hostname")]
public static partial void GettingServerHostname(ILogger logger, string sessionId);

[LoggerMessage(EventIds.RetrievedServerHostname, LogLevel.Debug, "Session {SessionId} retrieved server hostname: {ServerHostname}")]
public static partial void RetrievedServerHostname(ILogger logger, string sessionId, string? serverHostname);

[LoggerMessage(EventIds.FailedToGetServerHostname, LogLevel.Information, "Session {SessionId} failed to get server hostname")]
public static partial void FailedToGetServerHostname(ILogger logger, Exception exception, string sessionId);

[LoggerMessage(EventIds.IgnoringCancellationForDifferentServer, LogLevel.Warning, "Session {SessionId} ignoring cancellation from session {KillSessionId}: server hostname mismatch (this hostname={ServerHostname}, kill hostname={KillServerHostname})")]
public static partial void IgnoringCancellationForDifferentServer(ILogger logger, string sessionId, string killSessionId, string? serverHostname, string? killServerHostname);

[LoggerMessage(EventIds.ClosingStreamSocket, LogLevel.Debug, "Session {SessionId} closing stream/socket")]
public static partial void ClosingStreamSocket(ILogger logger, string sessionId);

Expand Down
53 changes: 53 additions & 0 deletions tests/IntegrationTests/ServerIdentificationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System.Diagnostics;

namespace IntegrationTests;

public class ServerIdentificationTests : IClassFixture<DatabaseFixture>, IDisposable
{
public ServerIdentificationTests(DatabaseFixture database)
{
m_database = database;
}

public void Dispose()
{
}

[SkippableFact(ServerFeatures.Timeout)]
public void CancelCommand_WithServerVerification()
{
// This test verifies that cancellation still works with server verification
using var connection = new MySqlConnection(AppConfig.ConnectionString);
connection.Open();

using var cmd = new MySqlCommand("SELECT SLEEP(5)", connection);
var task = Task.Run(async () =>
{
await Task.Delay(TimeSpan.FromSeconds(0.5));
cmd.Cancel();
});

var stopwatch = Stopwatch.StartNew();
TestUtilities.AssertExecuteScalarReturnsOneOrIsCanceled(cmd);
Assert.InRange(stopwatch.ElapsedMilliseconds, 250, 2500);

#pragma warning disable xUnit1031 // Do not use blocking task operations in test method
task.Wait(); // shouldn't throw
#pragma warning restore xUnit1031 // Do not use blocking task operations in test method
}

[SkippableFact(ServerFeatures.KnownCertificateAuthority)]
public void ServerHasServerHostname()
{
using var connection = new MySqlConnection(AppConfig.ConnectionString);
connection.Open();

// Test that we can query server hostname
using var cmd = new MySqlCommand("SELECT @@hostname", connection);
var hostname = cmd.ExecuteScalar();

// Hostname might be null on some server configurations, but the query should succeed
}

private readonly DatabaseFixture m_database;
}