diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index 32ec98884..df9bb67d5 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -12,6 +12,7 @@ Current package versions: - Add Redis 8.8 stream negative acknowledgements (`XNACK`) ([#3058 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3058)) - Update experimental `GCRA` APIs and wire protocol terminology from "requests" to "tokens", to match server change ([#3051 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3051)) - Add experimental `Aggregate.Count` support for sorted-set combination operations against Redis 8.8 ([#3059 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3059)) +- Prefer RESP3 and avoid opening a separate subscription connection for Azure Managed Redis endpoints ([#3067 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3067)) ## 2.12.14 diff --git a/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs b/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs index 06656b608..7b36b0c03 100644 --- a/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs +++ b/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs @@ -1,7 +1,6 @@ using System; using System.Net; using System.Threading.Tasks; -using StackExchange.Redis.Maintenance; namespace StackExchange.Redis.Configuration { @@ -54,9 +53,15 @@ private bool IsHostInDomains(string hostName, string[] domains) /// public override Task AfterConnectAsync(ConnectionMultiplexer muxer, Action log) - => AzureMaintenanceEvent.AddListenerAsync(muxer, log); + => Task.CompletedTask; /// public override bool GetDefaultSsl(EndPointCollection endPoints) => true; + + /// + public override RedisProtocol? Protocol => RedisProtocol.Resp3; // prefer RESP3 on AMR + + /// + public override string ConfigurationChannel => ""; // disable on AMR } } diff --git a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs index e4fa25891..f560c8ce4 100644 --- a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs +++ b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs @@ -259,6 +259,11 @@ protected virtual string GetDefaultClientName() => /// public virtual bool SetClientLibrary => true; + /// + /// Gets the preferred protocol to use for the connection. + /// + public virtual RedisProtocol? Protocol => null; + /// /// Tries to get the RoleInstance Id if Microsoft.WindowsAzure.ServiceRuntime is loaded. /// In case of any failure, swallows the exception and returns null. diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs index 641fccc95..58a281ddb 100644 --- a/src/StackExchange.Redis/ConfigurationOptions.cs +++ b/src/StackExchange.Redis/ConfigurationOptions.cs @@ -1169,13 +1169,18 @@ private ConfigurationOptions DoParse(string configuration, bool ignoreUnknown) /// /// Specify the redis protocol type. /// - public RedisProtocol? Protocol { get; set; } + public RedisProtocol? Protocol + { + get => field ?? Defaults.Protocol; + set; + } internal bool TryResp3() { + var protocol = Protocol; // note: deliberately leaving the IsAvailable duplicated to use short-circuit - // if (Protocol is null) + // if (protocol is null) // { // // if not specified, lean on the server version and whether HELLO is available // return new RedisFeatures(DefaultVersion).Resp3 && CommandMap.IsAvailable(RedisCommand.HELLO); @@ -1187,7 +1192,7 @@ internal bool TryResp3() // edge case in the library itself, the break is still visible to external callers via Execute[Async]; with an // abundance of caution, we are therefore making RESP3 explicit opt-in only for now; we may revisit this in a major { - return Protocol.GetValueOrDefault() >= RedisProtocol.Resp3 && CommandMap.IsAvailable(RedisCommand.HELLO); + return protocol.GetValueOrDefault() >= RedisProtocol.Resp3 && CommandMap.IsAvailable(RedisCommand.HELLO); } } diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt index ab058de62..983182101 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt @@ -1 +1,4 @@ #nullable enable +override StackExchange.Redis.Configuration.AzureManagedRedisOptionsProvider.ConfigurationChannel.get -> string! +override StackExchange.Redis.Configuration.AzureManagedRedisOptionsProvider.Protocol.get -> StackExchange.Redis.RedisProtocol? +virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.Protocol.get -> StackExchange.Redis.RedisProtocol? diff --git a/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs b/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs index a01e845da..dca2c4c2f 100644 --- a/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs +++ b/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs @@ -86,6 +86,63 @@ public void IsMatchOnAzureManagedRedisDomain(string hostName) Assert.IsType(provider); } + [Fact] + public async Task AzureManagedRedisConnectsViaResp3WithoutSubscriptionConnection() + { + using var serverObj = new InProcessTestServer(Output, new DnsEndPoint("contoso.redis.azure.net", 10000), useSsl: true); + var config = serverObj.GetClientConfig(); + config.Protocol = null; + + await using var conn = await ConnectionMultiplexer.ConnectAsync(config, Writer); + + var server = conn.GetServer(conn.GetEndPoints().Single()); + var interactiveId = ((IInternalConnectionMultiplexer)conn).GetConnectionId(server.EndPoint, ConnectionType.Interactive); + var clients = server.ClientList(); + var namedClients = clients.Where(x => x.Name == conn.ClientName).ToArray(); + + Assert.Equal(RedisProtocol.Resp3, server.Protocol); + Assert.Equal(1, serverObj.ClientCount); + Assert.NotNull(interactiveId); + Assert.Single(namedClients); + var self = Assert.Single(clients, x => x.Id == interactiveId); + Assert.Equal(ClientType.Normal, self.ClientType); + Assert.Equal(0, self.SubscriptionCount); + Assert.Equal(0, self.PatternSubscriptionCount); + Assert.Equal(0, self.ShardedSubscriptionCount); + } + + [Fact] + public async Task VanillaResp2ConnectsWithSeparatePubSubConnection() + { + using var serverObj = new InProcessTestServer(Output, new DnsEndPoint("redis.contoso.com", 10000), useSsl: true); + var config = serverObj.GetClientConfig(); + config.Protocol = RedisProtocol.Resp2; + + await using var conn = await ConnectionMultiplexer.ConnectAsync(config, Writer); + var sub = conn.GetSubscriber(); + await sub.SubscribeAsync(RedisChannel.Literal(nameof(VanillaResp2ConnectsWithSeparatePubSubConnection)), (_, _) => { }); + + var server = conn.GetServer(conn.GetEndPoints().Single()); + var mux = (IInternalConnectionMultiplexer)conn; + var interactiveId = mux.GetConnectionId(server.EndPoint, ConnectionType.Interactive); + var subscriptionId = mux.GetConnectionId(server.EndPoint, ConnectionType.Subscription); + var clients = server.ClientList(); + var namedClients = clients.Where(x => x.Name == conn.ClientName).ToArray(); + + Assert.Equal(RedisProtocol.Resp2, server.Protocol); + Assert.Equal(2, serverObj.ClientCount); + Assert.NotNull(interactiveId); + Assert.NotNull(subscriptionId); + Assert.NotEqual(interactiveId, subscriptionId); + Assert.Equal(2, namedClients.Length); + + var interactive = Assert.Single(clients, x => x.Id == interactiveId); + var subscription = Assert.Single(clients, x => x.Id == subscriptionId); + Assert.Equal(ClientType.Normal, interactive.ClientType); + Assert.Equal(ClientType.PubSub, subscription.ClientType); + Assert.True(subscription.SubscriptionCount > 0); + } + [Fact] public void AllOverridesFromDefaultsProp() { diff --git a/tests/StackExchange.Redis.Tests/InProcessTestServer.cs b/tests/StackExchange.Redis.Tests/InProcessTestServer.cs index af9f1ee44..338b557c1 100644 --- a/tests/StackExchange.Redis.Tests/InProcessTestServer.cs +++ b/tests/StackExchange.Redis.Tests/InProcessTestServer.cs @@ -3,7 +3,11 @@ using System.IO; using System.IO.Pipelines; using System.Net; +using System.Net.Security; using System.Net.Sockets; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -17,11 +21,22 @@ namespace StackExchange.Redis.Tests; public class InProcessTestServer : MemoryCacheRedisServer { private readonly ITestOutputHelper? _log; - public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null) + private readonly X509Certificate2? _serverCertificate; + private readonly string? _serverCertificateThumbprint; + private readonly RemoteCertificateValidationCallback? _certificateValidationCallback; + + public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null, bool useSsl = false) : base(endpoint) { RedisVersion = RedisFeatures.v6_0_0; // for client to expect RESP3 _log = log; + UseSsl = useSsl; + if (useSsl) + { + _serverCertificate = CreateServerCertificate(DefaultEndPoint); + _serverCertificateThumbprint = _serverCertificate.Thumbprint; + _certificateValidationCallback = ValidateServerCertificate; + } // ReSharper disable once VirtualMemberCallInConstructor _log?.WriteLine($"Creating in-process server: {ToString()}"); Tunnel = new InProcTunnel(this); @@ -90,6 +105,11 @@ public ConfigurationOptions GetClientConfig(bool withPubSub = true, bool default // WriteMode = (BufferedStreamWriter.WriteMode)writeMode, }; if (!string.IsNullOrEmpty(Password)) config.Password = Password; + if (UseSsl) + { + config.Ssl = true; + config.CertificateValidation += _certificateValidationCallback; + } /* useful for viewing *outbound* data in the log #if DEBUG @@ -121,6 +141,7 @@ public ConfigurationOptions GetClientConfig(bool withPubSub = true, bool default } public Tunnel Tunnel { get; } + public bool UseSsl { get; } public override void Log(string message) { @@ -200,6 +221,66 @@ protected override void OnSkippedReply(RedisClient client) base.OnSkippedReply(client); } + protected override void Dispose(bool disposing) + { + if (disposing) + { + _serverCertificate?.Dispose(); + } + base.Dispose(disposing); + } + + private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors errors) + { + if (errors == SslPolicyErrors.None) + { + return true; + } + + return certificate is not null + && _serverCertificateThumbprint is not null + && string.Equals(certificate.GetCertHashString(), _serverCertificateThumbprint, StringComparison.OrdinalIgnoreCase); + } + + private static X509Certificate2 CreateServerCertificate(EndPoint endpoint) + { + var now = DateTimeOffset.UtcNow; + var subjectName = GetCertificateSubjectName(endpoint); + + using var rsa = RSA.Create(2048); + var request = new CertificateRequest($"CN={subjectName}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + request.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false)); + request.CertificateExtensions.Add(new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment, false)); + request.CertificateExtensions.Add( + new X509EnhancedKeyUsageExtension( + new OidCollection { new Oid("1.3.6.1.5.5.7.3.1") }, + false)); + + var san = new SubjectAlternativeNameBuilder(); + switch (endpoint) + { + case DnsEndPoint dns: + san.AddDnsName(dns.Host); + break; + case IPEndPoint ip: + san.AddIpAddress(ip.Address); + break; + } + request.CertificateExtensions.Add(san.Build()); + + using var certificate = request.CreateSelfSigned(now.AddMinutes(-5), now.AddDays(7)); +#pragma warning disable SYSLIB0057 + return new X509Certificate2(certificate.Export(X509ContentType.Pfx)); +#pragma warning restore SYSLIB0057 + + static string GetCertificateSubjectName(EndPoint endpoint) => endpoint switch + { + DnsEndPoint dns => dns.Host, + IPEndPoint ip => ip.Address.ToString(), + _ => "localhost", + }; + } + private sealed class InProcTunnel( InProcessTestServer server, PipeOptions? pipeOptions = null) : Tunnel @@ -225,16 +306,38 @@ private sealed class InProcTunnel( if (server.TryGetNode(endpoint, out var node)) { await server.OnAcceptClientAsync(endpoint); + server._log?.WriteLine( + $"[{endpoint}] accepting {connectionType} mapped to {server.ServerType} node {node} via {(server.UseSsl ? "TLS" : "plaintext")}"); var clientToServer = new Pipe(pipeOptions ?? PipeOptions.Default); var serverToClient = new Pipe(pipeOptions ?? PipeOptions.Default); - var serverSide = new Duplex(clientToServer.Reader, serverToClient.Writer); + var serverInput = clientToServer.Reader.AsStream(); + var serverOutput = serverToClient.Writer.AsStream(); + var serverTransport = new DuplexStream(serverInput, serverOutput); - TaskCompletionSource clientTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - Task.Run(async () => await server.RunClientAsync(serverSide, node: node, state: clientTcs), cancellationToken).RedisFireAndForget(); - if (!clientTcs.Task.Wait(1000)) throw new TimeoutException("Client not connected"); - var client = clientTcs.Task.Result; - server._log?.WriteLine( - $"[{client}] connected ({connectionType} mapped to {server.ServerType} node {node})"); + if (server.UseSsl) + { + Task.Run( + async () => + { + using var ssl = new SslStream(serverTransport, leaveInnerStreamOpen: false); + await ssl.AuthenticateAsServerAsync( + server._serverCertificate!, + clientCertificateRequired: false, + enabledSslProtocols: SslProtocols.None, + checkCertificateRevocation: false).ConfigureAwait(false); + var serverSide = new StreamDuplexPipe(ssl); + await server.RunClientAsync(serverSide, node: node, state: null).ConfigureAwait(false); + }, + cancellationToken).RedisFireAndForget(); + } + else + { + var serverSide = new Duplex(clientToServer.Reader, serverToClient.Writer); + TaskCompletionSource clientTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + Task.Run(async () => await server.RunClientAsync(serverSide, node: node, state: clientTcs), cancellationToken).RedisFireAndForget(); + if (!clientTcs.Task.Wait(1000)) throw new TimeoutException("Client not connected"); + _ = clientTcs.Task.Result; + } var readStream = serverToClient.Reader.AsStream(); var writeStream = clientToServer.Writer.AsStream(); @@ -256,6 +359,12 @@ public ValueTask Dispose() return default; } } + + private sealed class StreamDuplexPipe(Stream stream) : IDuplexPipe + { + public PipeReader Input { get; } = PipeReader.Create(stream); + public PipeWriter Output { get; } = PipeWriter.Create(stream); + } } protected virtual ValueTask OnAcceptClientAsync(EndPoint endpoint) => default; diff --git a/tests/StackExchange.Redis.Tests/SSLTests.cs b/tests/StackExchange.Redis.Tests/SSLTests.cs index 96d964b23..f068e0aae 100644 --- a/tests/StackExchange.Redis.Tests/SSLTests.cs +++ b/tests/StackExchange.Redis.Tests/SSLTests.cs @@ -416,6 +416,8 @@ public void Issue883_Exhaustive() Ssl = true, AbortOnConnectFail = false, }; + _ = a.Defaults; + _ = b.Defaults; // ensure the lazily materialized provider matches the parsed shape Log($"computed: {b.ToString(true)}"); Log("Checking endpoints..."); @@ -429,6 +431,14 @@ public void Issue883_Exhaustive() Array.Sort(fields, (x, y) => string.CompareOrdinal(x.Name, y.Name)); foreach (var field in fields) { + if (field.Name == "defaultOptions") + { + var x = field.GetValue(a); + var y = field.GetValue(b); + Log($"{field.Name}: {(x == null ? "(null)" : x.GetType().Name)} vs {(y == null ? "(null)" : y.GetType().Name)}"); + Check(field.Name + ".Type", x?.GetType(), y?.GetType()); + continue; + } Check(field.Name, field.GetValue(a), field.GetValue(b)); } } diff --git a/toys/StackExchange.Redis.Server/RedisServer.cs b/toys/StackExchange.Redis.Server/RedisServer.cs index 54a7fbe04..a3b52dec1 100644 --- a/toys/StackExchange.Redis.Server/RedisServer.cs +++ b/toys/StackExchange.Redis.Server/RedisServer.cs @@ -462,6 +462,32 @@ protected virtual TypedRedisValue ClientReply(RedisClient client, in RedisReques protected virtual TypedRedisValue ClientId(RedisClient client, in RedisRequest request) => TypedRedisValue.Integer(client.Id); + [RedisCommand(2, nameof(RedisCommand.CLIENT), "list", LockFree = true)] + protected virtual TypedRedisValue ClientList(RedisClient client, in RedisRequest request) + { + var sb = new StringBuilder(); + ForAllClients( + sb, + static (other, state) => + { + if (state.Length != 0) state.AppendLine(); + state.Append("id=").Append(other.Id) + .Append(" addr=").Append(other.Node.Host).Append(':').Append(other.Node.Port) + .Append(" age=0 idle=0") + .Append(" db=").Append(other.Database) + .Append(" sub=").Append(other.SubscriptionCount) + .Append(" psub=").Append(other.PatternSubscriptionCount) + .Append(" ssub=").Append(other.ShardedSubscriptionCount) + .Append(" multi=0") + .Append(" cmd=NULL") + .Append(" name=").Append(other.Name ?? "") + .Append(" resp=").Append(other.Protocol is RedisProtocol.Resp3 ? 3 : 2) + .Append(" flags=").Append(other.IsSubscriber ? "P" : "N"); + return 1; + }); + return TypedRedisValue.BulkString(sb.ToString()); + } + [RedisCommand(4, nameof(RedisCommand.CLIENT), "setinfo", LockFree = true)] protected virtual TypedRedisValue ClientSetInfo(RedisClient client, in RedisRequest request) => TypedRedisValue.OK; // only exists to keep logs clean