diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md
index 32ec98884..df9bb67d5 100644
--- a/docs/ReleaseNotes.md
+++ b/docs/ReleaseNotes.md
@@ -12,6 +12,7 @@ Current package versions:
- Add Redis 8.8 stream negative acknowledgements (`XNACK`) ([#3058 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3058))
- Update experimental `GCRA` APIs and wire protocol terminology from "requests" to "tokens", to match server change ([#3051 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3051))
- Add experimental `Aggregate.Count` support for sorted-set combination operations against Redis 8.8 ([#3059 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3059))
+- Prefer RESP3 and avoid opening a separate subscription connection for Azure Managed Redis endpoints ([#3067 by @mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/3067))
## 2.12.14
diff --git a/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs b/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs
index 06656b608..7b36b0c03 100644
--- a/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs
+++ b/src/StackExchange.Redis/Configuration/AzureManagedRedisOptionsProvider.cs
@@ -1,7 +1,6 @@
using System;
using System.Net;
using System.Threading.Tasks;
-using StackExchange.Redis.Maintenance;
namespace StackExchange.Redis.Configuration
{
@@ -54,9 +53,15 @@ private bool IsHostInDomains(string hostName, string[] domains)
///
public override Task AfterConnectAsync(ConnectionMultiplexer muxer, Action log)
- => AzureMaintenanceEvent.AddListenerAsync(muxer, log);
+ => Task.CompletedTask;
///
public override bool GetDefaultSsl(EndPointCollection endPoints) => true;
+
+ ///
+ public override RedisProtocol? Protocol => RedisProtocol.Resp3; // prefer RESP3 on AMR
+
+ ///
+ public override string ConfigurationChannel => ""; // disable on AMR
}
}
diff --git a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs
index e4fa25891..f560c8ce4 100644
--- a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs
+++ b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs
@@ -259,6 +259,11 @@ protected virtual string GetDefaultClientName() =>
///
public virtual bool SetClientLibrary => true;
+ ///
+ /// Gets the preferred protocol to use for the connection.
+ ///
+ public virtual RedisProtocol? Protocol => null;
+
///
/// Tries to get the RoleInstance Id if Microsoft.WindowsAzure.ServiceRuntime is loaded.
/// In case of any failure, swallows the exception and returns null.
diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs
index 641fccc95..58a281ddb 100644
--- a/src/StackExchange.Redis/ConfigurationOptions.cs
+++ b/src/StackExchange.Redis/ConfigurationOptions.cs
@@ -1169,13 +1169,18 @@ private ConfigurationOptions DoParse(string configuration, bool ignoreUnknown)
///
/// Specify the redis protocol type.
///
- public RedisProtocol? Protocol { get; set; }
+ public RedisProtocol? Protocol
+ {
+ get => field ?? Defaults.Protocol;
+ set;
+ }
internal bool TryResp3()
{
+ var protocol = Protocol;
// note: deliberately leaving the IsAvailable duplicated to use short-circuit
- // if (Protocol is null)
+ // if (protocol is null)
// {
// // if not specified, lean on the server version and whether HELLO is available
// return new RedisFeatures(DefaultVersion).Resp3 && CommandMap.IsAvailable(RedisCommand.HELLO);
@@ -1187,7 +1192,7 @@ internal bool TryResp3()
// edge case in the library itself, the break is still visible to external callers via Execute[Async]; with an
// abundance of caution, we are therefore making RESP3 explicit opt-in only for now; we may revisit this in a major
{
- return Protocol.GetValueOrDefault() >= RedisProtocol.Resp3 && CommandMap.IsAvailable(RedisCommand.HELLO);
+ return protocol.GetValueOrDefault() >= RedisProtocol.Resp3 && CommandMap.IsAvailable(RedisCommand.HELLO);
}
}
diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt
index ab058de62..983182101 100644
--- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt
+++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt
@@ -1 +1,4 @@
#nullable enable
+override StackExchange.Redis.Configuration.AzureManagedRedisOptionsProvider.ConfigurationChannel.get -> string!
+override StackExchange.Redis.Configuration.AzureManagedRedisOptionsProvider.Protocol.get -> StackExchange.Redis.RedisProtocol?
+virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.Protocol.get -> StackExchange.Redis.RedisProtocol?
diff --git a/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs b/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs
index a01e845da..dca2c4c2f 100644
--- a/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs
+++ b/tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs
@@ -86,6 +86,63 @@ public void IsMatchOnAzureManagedRedisDomain(string hostName)
Assert.IsType(provider);
}
+ [Fact]
+ public async Task AzureManagedRedisConnectsViaResp3WithoutSubscriptionConnection()
+ {
+ using var serverObj = new InProcessTestServer(Output, new DnsEndPoint("contoso.redis.azure.net", 10000), useSsl: true);
+ var config = serverObj.GetClientConfig();
+ config.Protocol = null;
+
+ await using var conn = await ConnectionMultiplexer.ConnectAsync(config, Writer);
+
+ var server = conn.GetServer(conn.GetEndPoints().Single());
+ var interactiveId = ((IInternalConnectionMultiplexer)conn).GetConnectionId(server.EndPoint, ConnectionType.Interactive);
+ var clients = server.ClientList();
+ var namedClients = clients.Where(x => x.Name == conn.ClientName).ToArray();
+
+ Assert.Equal(RedisProtocol.Resp3, server.Protocol);
+ Assert.Equal(1, serverObj.ClientCount);
+ Assert.NotNull(interactiveId);
+ Assert.Single(namedClients);
+ var self = Assert.Single(clients, x => x.Id == interactiveId);
+ Assert.Equal(ClientType.Normal, self.ClientType);
+ Assert.Equal(0, self.SubscriptionCount);
+ Assert.Equal(0, self.PatternSubscriptionCount);
+ Assert.Equal(0, self.ShardedSubscriptionCount);
+ }
+
+ [Fact]
+ public async Task VanillaResp2ConnectsWithSeparatePubSubConnection()
+ {
+ using var serverObj = new InProcessTestServer(Output, new DnsEndPoint("redis.contoso.com", 10000), useSsl: true);
+ var config = serverObj.GetClientConfig();
+ config.Protocol = RedisProtocol.Resp2;
+
+ await using var conn = await ConnectionMultiplexer.ConnectAsync(config, Writer);
+ var sub = conn.GetSubscriber();
+ await sub.SubscribeAsync(RedisChannel.Literal(nameof(VanillaResp2ConnectsWithSeparatePubSubConnection)), (_, _) => { });
+
+ var server = conn.GetServer(conn.GetEndPoints().Single());
+ var mux = (IInternalConnectionMultiplexer)conn;
+ var interactiveId = mux.GetConnectionId(server.EndPoint, ConnectionType.Interactive);
+ var subscriptionId = mux.GetConnectionId(server.EndPoint, ConnectionType.Subscription);
+ var clients = server.ClientList();
+ var namedClients = clients.Where(x => x.Name == conn.ClientName).ToArray();
+
+ Assert.Equal(RedisProtocol.Resp2, server.Protocol);
+ Assert.Equal(2, serverObj.ClientCount);
+ Assert.NotNull(interactiveId);
+ Assert.NotNull(subscriptionId);
+ Assert.NotEqual(interactiveId, subscriptionId);
+ Assert.Equal(2, namedClients.Length);
+
+ var interactive = Assert.Single(clients, x => x.Id == interactiveId);
+ var subscription = Assert.Single(clients, x => x.Id == subscriptionId);
+ Assert.Equal(ClientType.Normal, interactive.ClientType);
+ Assert.Equal(ClientType.PubSub, subscription.ClientType);
+ Assert.True(subscription.SubscriptionCount > 0);
+ }
+
[Fact]
public void AllOverridesFromDefaultsProp()
{
diff --git a/tests/StackExchange.Redis.Tests/InProcessTestServer.cs b/tests/StackExchange.Redis.Tests/InProcessTestServer.cs
index af9f1ee44..338b557c1 100644
--- a/tests/StackExchange.Redis.Tests/InProcessTestServer.cs
+++ b/tests/StackExchange.Redis.Tests/InProcessTestServer.cs
@@ -3,7 +3,11 @@
using System.IO;
using System.IO.Pipelines;
using System.Net;
+using System.Net.Security;
using System.Net.Sockets;
+using System.Security.Authentication;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -17,11 +21,22 @@ namespace StackExchange.Redis.Tests;
public class InProcessTestServer : MemoryCacheRedisServer
{
private readonly ITestOutputHelper? _log;
- public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null)
+ private readonly X509Certificate2? _serverCertificate;
+ private readonly string? _serverCertificateThumbprint;
+ private readonly RemoteCertificateValidationCallback? _certificateValidationCallback;
+
+ public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null, bool useSsl = false)
: base(endpoint)
{
RedisVersion = RedisFeatures.v6_0_0; // for client to expect RESP3
_log = log;
+ UseSsl = useSsl;
+ if (useSsl)
+ {
+ _serverCertificate = CreateServerCertificate(DefaultEndPoint);
+ _serverCertificateThumbprint = _serverCertificate.Thumbprint;
+ _certificateValidationCallback = ValidateServerCertificate;
+ }
// ReSharper disable once VirtualMemberCallInConstructor
_log?.WriteLine($"Creating in-process server: {ToString()}");
Tunnel = new InProcTunnel(this);
@@ -90,6 +105,11 @@ public ConfigurationOptions GetClientConfig(bool withPubSub = true, bool default
// WriteMode = (BufferedStreamWriter.WriteMode)writeMode,
};
if (!string.IsNullOrEmpty(Password)) config.Password = Password;
+ if (UseSsl)
+ {
+ config.Ssl = true;
+ config.CertificateValidation += _certificateValidationCallback;
+ }
/* useful for viewing *outbound* data in the log
#if DEBUG
@@ -121,6 +141,7 @@ public ConfigurationOptions GetClientConfig(bool withPubSub = true, bool default
}
public Tunnel Tunnel { get; }
+ public bool UseSsl { get; }
public override void Log(string message)
{
@@ -200,6 +221,66 @@ protected override void OnSkippedReply(RedisClient client)
base.OnSkippedReply(client);
}
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ _serverCertificate?.Dispose();
+ }
+ base.Dispose(disposing);
+ }
+
+ private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors errors)
+ {
+ if (errors == SslPolicyErrors.None)
+ {
+ return true;
+ }
+
+ return certificate is not null
+ && _serverCertificateThumbprint is not null
+ && string.Equals(certificate.GetCertHashString(), _serverCertificateThumbprint, StringComparison.OrdinalIgnoreCase);
+ }
+
+ private static X509Certificate2 CreateServerCertificate(EndPoint endpoint)
+ {
+ var now = DateTimeOffset.UtcNow;
+ var subjectName = GetCertificateSubjectName(endpoint);
+
+ using var rsa = RSA.Create(2048);
+ var request = new CertificateRequest($"CN={subjectName}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
+ request.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));
+ request.CertificateExtensions.Add(new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment, false));
+ request.CertificateExtensions.Add(
+ new X509EnhancedKeyUsageExtension(
+ new OidCollection { new Oid("1.3.6.1.5.5.7.3.1") },
+ false));
+
+ var san = new SubjectAlternativeNameBuilder();
+ switch (endpoint)
+ {
+ case DnsEndPoint dns:
+ san.AddDnsName(dns.Host);
+ break;
+ case IPEndPoint ip:
+ san.AddIpAddress(ip.Address);
+ break;
+ }
+ request.CertificateExtensions.Add(san.Build());
+
+ using var certificate = request.CreateSelfSigned(now.AddMinutes(-5), now.AddDays(7));
+#pragma warning disable SYSLIB0057
+ return new X509Certificate2(certificate.Export(X509ContentType.Pfx));
+#pragma warning restore SYSLIB0057
+
+ static string GetCertificateSubjectName(EndPoint endpoint) => endpoint switch
+ {
+ DnsEndPoint dns => dns.Host,
+ IPEndPoint ip => ip.Address.ToString(),
+ _ => "localhost",
+ };
+ }
+
private sealed class InProcTunnel(
InProcessTestServer server,
PipeOptions? pipeOptions = null) : Tunnel
@@ -225,16 +306,38 @@ private sealed class InProcTunnel(
if (server.TryGetNode(endpoint, out var node))
{
await server.OnAcceptClientAsync(endpoint);
+ server._log?.WriteLine(
+ $"[{endpoint}] accepting {connectionType} mapped to {server.ServerType} node {node} via {(server.UseSsl ? "TLS" : "plaintext")}");
var clientToServer = new Pipe(pipeOptions ?? PipeOptions.Default);
var serverToClient = new Pipe(pipeOptions ?? PipeOptions.Default);
- var serverSide = new Duplex(clientToServer.Reader, serverToClient.Writer);
+ var serverInput = clientToServer.Reader.AsStream();
+ var serverOutput = serverToClient.Writer.AsStream();
+ var serverTransport = new DuplexStream(serverInput, serverOutput);
- TaskCompletionSource clientTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
- Task.Run(async () => await server.RunClientAsync(serverSide, node: node, state: clientTcs), cancellationToken).RedisFireAndForget();
- if (!clientTcs.Task.Wait(1000)) throw new TimeoutException("Client not connected");
- var client = clientTcs.Task.Result;
- server._log?.WriteLine(
- $"[{client}] connected ({connectionType} mapped to {server.ServerType} node {node})");
+ if (server.UseSsl)
+ {
+ Task.Run(
+ async () =>
+ {
+ using var ssl = new SslStream(serverTransport, leaveInnerStreamOpen: false);
+ await ssl.AuthenticateAsServerAsync(
+ server._serverCertificate!,
+ clientCertificateRequired: false,
+ enabledSslProtocols: SslProtocols.None,
+ checkCertificateRevocation: false).ConfigureAwait(false);
+ var serverSide = new StreamDuplexPipe(ssl);
+ await server.RunClientAsync(serverSide, node: node, state: null).ConfigureAwait(false);
+ },
+ cancellationToken).RedisFireAndForget();
+ }
+ else
+ {
+ var serverSide = new Duplex(clientToServer.Reader, serverToClient.Writer);
+ TaskCompletionSource clientTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
+ Task.Run(async () => await server.RunClientAsync(serverSide, node: node, state: clientTcs), cancellationToken).RedisFireAndForget();
+ if (!clientTcs.Task.Wait(1000)) throw new TimeoutException("Client not connected");
+ _ = clientTcs.Task.Result;
+ }
var readStream = serverToClient.Reader.AsStream();
var writeStream = clientToServer.Writer.AsStream();
@@ -256,6 +359,12 @@ public ValueTask Dispose()
return default;
}
}
+
+ private sealed class StreamDuplexPipe(Stream stream) : IDuplexPipe
+ {
+ public PipeReader Input { get; } = PipeReader.Create(stream);
+ public PipeWriter Output { get; } = PipeWriter.Create(stream);
+ }
}
protected virtual ValueTask OnAcceptClientAsync(EndPoint endpoint) => default;
diff --git a/tests/StackExchange.Redis.Tests/SSLTests.cs b/tests/StackExchange.Redis.Tests/SSLTests.cs
index 96d964b23..f068e0aae 100644
--- a/tests/StackExchange.Redis.Tests/SSLTests.cs
+++ b/tests/StackExchange.Redis.Tests/SSLTests.cs
@@ -416,6 +416,8 @@ public void Issue883_Exhaustive()
Ssl = true,
AbortOnConnectFail = false,
};
+ _ = a.Defaults;
+ _ = b.Defaults; // ensure the lazily materialized provider matches the parsed shape
Log($"computed: {b.ToString(true)}");
Log("Checking endpoints...");
@@ -429,6 +431,14 @@ public void Issue883_Exhaustive()
Array.Sort(fields, (x, y) => string.CompareOrdinal(x.Name, y.Name));
foreach (var field in fields)
{
+ if (field.Name == "defaultOptions")
+ {
+ var x = field.GetValue(a);
+ var y = field.GetValue(b);
+ Log($"{field.Name}: {(x == null ? "(null)" : x.GetType().Name)} vs {(y == null ? "(null)" : y.GetType().Name)}");
+ Check(field.Name + ".Type", x?.GetType(), y?.GetType());
+ continue;
+ }
Check(field.Name, field.GetValue(a), field.GetValue(b));
}
}
diff --git a/toys/StackExchange.Redis.Server/RedisServer.cs b/toys/StackExchange.Redis.Server/RedisServer.cs
index 54a7fbe04..a3b52dec1 100644
--- a/toys/StackExchange.Redis.Server/RedisServer.cs
+++ b/toys/StackExchange.Redis.Server/RedisServer.cs
@@ -462,6 +462,32 @@ protected virtual TypedRedisValue ClientReply(RedisClient client, in RedisReques
protected virtual TypedRedisValue ClientId(RedisClient client, in RedisRequest request)
=> TypedRedisValue.Integer(client.Id);
+ [RedisCommand(2, nameof(RedisCommand.CLIENT), "list", LockFree = true)]
+ protected virtual TypedRedisValue ClientList(RedisClient client, in RedisRequest request)
+ {
+ var sb = new StringBuilder();
+ ForAllClients(
+ sb,
+ static (other, state) =>
+ {
+ if (state.Length != 0) state.AppendLine();
+ state.Append("id=").Append(other.Id)
+ .Append(" addr=").Append(other.Node.Host).Append(':').Append(other.Node.Port)
+ .Append(" age=0 idle=0")
+ .Append(" db=").Append(other.Database)
+ .Append(" sub=").Append(other.SubscriptionCount)
+ .Append(" psub=").Append(other.PatternSubscriptionCount)
+ .Append(" ssub=").Append(other.ShardedSubscriptionCount)
+ .Append(" multi=0")
+ .Append(" cmd=NULL")
+ .Append(" name=").Append(other.Name ?? "")
+ .Append(" resp=").Append(other.Protocol is RedisProtocol.Resp3 ? 3 : 2)
+ .Append(" flags=").Append(other.IsSubscriber ? "P" : "N");
+ return 1;
+ });
+ return TypedRedisValue.BulkString(sb.ToString());
+ }
+
[RedisCommand(4, nameof(RedisCommand.CLIENT), "setinfo", LockFree = true)]
protected virtual TypedRedisValue ClientSetInfo(RedisClient client, in RedisRequest request)
=> TypedRedisValue.OK; // only exists to keep logs clean