diff --git a/src/StackExchange.Redis/Enums/ExpirationFlags.cs b/src/StackExchange.Redis/Enums/ExpirationFlags.cs new file mode 100644 index 000000000..9db24df41 --- /dev/null +++ b/src/StackExchange.Redis/Enums/ExpirationFlags.cs @@ -0,0 +1,21 @@ +using System; + +namespace StackExchange.Redis +{ + /// + /// Additional options for expiration-bearing commands. + /// + [Flags] + public enum ExpirationFlags + { + /// + /// No options specified. + /// + None = 0, + + /// + /// Apply the expiration only if no expiration already exists. + /// + ExpireIfNotExists = 1 << 0, + } +} diff --git a/src/StackExchange.Redis/Enums/RedisCommand.cs b/src/StackExchange.Redis/Enums/RedisCommand.cs index 4f294f46b..138c7095b 100644 --- a/src/StackExchange.Redis/Enums/RedisCommand.cs +++ b/src/StackExchange.Redis/Enums/RedisCommand.cs @@ -101,6 +101,7 @@ internal enum RedisCommand INCR, INCRBY, INCRBYFLOAT, + INCREX, INFO, KEYS, @@ -347,6 +348,7 @@ internal static bool IsPrimaryOnly(this RedisCommand command) case RedisCommand.INCR: case RedisCommand.INCRBY: case RedisCommand.INCRBYFLOAT: + case RedisCommand.INCREX: case RedisCommand.LINSERT: case RedisCommand.LMOVE: case RedisCommand.LMPOP: diff --git a/src/StackExchange.Redis/Expiration.cs b/src/StackExchange.Redis/Expiration.cs index e04094358..c1a874331 100644 --- a/src/StackExchange.Redis/Expiration.cs +++ b/src/StackExchange.Redis/Expiration.cs @@ -16,9 +16,11 @@ public readonly struct Expiration - PX {ms} - relative expiry in milliseconds - EXAT {s} - absolute expiry in seconds - PXAT {ms} - absolute expiry in milliseconds + - ENX - only apply the expiration if no expiration currently exists - We need to distinguish between these 6 scenarios, which we can logically do with 3 bits (8 options). - So; we'll use a ulong for the value, reserving the top 3 bits for the mode. + Historically this packed the mode and value into a single ulong. We now keep the raw long + separate from explicit flags so we can extend expiration behavior without stealing more bits + from the numeric payload. */ /// @@ -39,22 +41,29 @@ public readonly struct Expiration /// /// Expire at the specified absolute time. /// - public Expiration(DateTime when) + public Expiration(DateTime when) : this(when, ExpirationFlags.None) { } + + /// + /// Expire at the specified absolute time. + /// + public Expiration(DateTime when, ExpirationFlags flags) { if (when == DateTime.MaxValue) { - _valueAndMode = s_Default._valueAndMode; + _value = s_Default._value; + _flags = s_Default._flags; return; } long millis = GetUnixTimeMilliseconds(when); + var extraFlags = ToStateFlags(flags); if ((millis % 1000) == 0) { - Init(ExpirationMode.AbsoluteSeconds, millis / 1000, out _valueAndMode); + Init(ExpirationState.HasExpiration | ExpirationState.IsAbsolute | extraFlags, millis / 1000, out _value, out _flags); } else { - Init(ExpirationMode.AbsoluteMilliseconds, millis, out _valueAndMode); + Init(ExpirationState.HasExpiration | ExpirationState.IsAbsolute | ExpirationState.IsMillis | extraFlags, millis, out _value, out _flags); } } @@ -71,70 +80,88 @@ public Expiration(DateTime when) /// /// Expire at the specified relative time. /// - public Expiration(TimeSpan ttl) + public Expiration(TimeSpan ttl) : this(ttl, ExpirationFlags.None) { } + + /// + /// Expire at the specified relative time. + /// + public Expiration(TimeSpan ttl, ExpirationFlags flags) { if (ttl == TimeSpan.MaxValue) { - _valueAndMode = s_Default._valueAndMode; + _value = s_Default._value; + _flags = s_Default._flags; return; } var millis = ttl.Ticks / TimeSpan.TicksPerMillisecond; + var extraFlags = ToStateFlags(flags); if ((millis % 1000) == 0) { - Init(ExpirationMode.RelativeSeconds, millis / 1000, out _valueAndMode); + Init(ExpirationState.HasExpiration | extraFlags, millis / 1000, out _value, out _flags); } else { - Init(ExpirationMode.RelativeMilliseconds, millis, out _valueAndMode); + Init(ExpirationState.HasExpiration | ExpirationState.IsMillis | extraFlags, millis, out _value, out _flags); } } - private readonly ulong _valueAndMode; + private readonly long _value; + private readonly ExpirationState _flags; - private static void Init(ExpirationMode mode, long value, out ulong valueAndMode) + [Flags] + private enum ExpirationState : byte { - // check the caller isn't using the top 3 bits that we have reserved; this includes checking for -ve values - ulong uValue = (ulong)value; - if ((uValue & ~ValueMask) != 0) Throw(); - valueAndMode = (uValue & ValueMask) | ((ulong)mode << 61); - static void Throw() => throw new ArgumentOutOfRangeException(nameof(value)); + None = 0, + ExpireIfNotExists = (byte)ExpirationFlags.ExpireIfNotExists, + HasExpiration = 1 << 1, + IsMillis = 1 << 2, + IsAbsolute = 1 << 3, + KeepTtl = 1 << 4, + Persist = 1 << 5, } - private Expiration(ExpirationMode mode, long value) => Init(mode, value, out _valueAndMode); + private static ExpirationState ToStateFlags(ExpirationFlags flags) + { + const ExpirationFlags validFlags = ExpirationFlags.ExpireIfNotExists; + if ((flags & ~validFlags) != 0) Throw(); + return (ExpirationState)flags; - private enum ExpirationMode : byte + static void Throw() => throw new ArgumentOutOfRangeException(nameof(flags)); + } + + private static void Init(ExpirationState flags, long value, out long rawValue, out ExpirationState rawFlags) { - Default = 0, - RelativeSeconds = 1, - RelativeMilliseconds = 2, - AbsoluteSeconds = 3, - AbsoluteMilliseconds = 4, - KeepTtl = 5, - Persist = 6, - NotUsed = 7, // just to ensure all 8 possible values are covered + if (value < 0) Throw(); + rawValue = value; + rawFlags = flags; + static void Throw() => throw new ArgumentOutOfRangeException(nameof(value)); } - private const ulong ValueMask = (~0UL) >> 3; - internal long Value => unchecked((long)(_valueAndMode & ValueMask)); - private ExpirationMode Mode => (ExpirationMode)(_valueAndMode >> 61); // note unsigned, no need to mask + private Expiration(ExpirationState flags, long value) + { + _value = value; + _flags = flags; + } - internal bool IsKeepTtl => Mode is ExpirationMode.KeepTtl; - internal bool IsPersist => Mode is ExpirationMode.Persist; - internal bool IsNone => Mode is ExpirationMode.Default; - internal bool IsNoneOrKeepTtl => Mode is ExpirationMode.Default or ExpirationMode.KeepTtl; - internal bool IsAbsolute => Mode is ExpirationMode.AbsoluteSeconds or ExpirationMode.AbsoluteMilliseconds; - internal bool IsRelative => Mode is ExpirationMode.RelativeSeconds or ExpirationMode.RelativeMilliseconds; + internal long Value => _value; - internal bool IsMilliseconds => - Mode is ExpirationMode.RelativeMilliseconds or ExpirationMode.AbsoluteMilliseconds; + internal bool IsKeepTtl => (_flags & ExpirationState.KeepTtl) != 0; + internal bool IsPersist => (_flags & ExpirationState.Persist) != 0; + internal bool IsExpireIfNotExists => (_flags & ExpirationState.ExpireIfNotExists) != 0; + internal bool IsNone => _flags == ExpirationState.None; + internal bool IsNoneOrKeepTtl => IsNone || IsKeepTtl; + internal bool IsAbsolute => (_flags & ExpirationState.IsAbsolute) != 0; + internal bool IsRelative => (_flags & ExpirationState.HasExpiration) != 0 && !IsAbsolute; - internal bool IsSeconds => Mode is ExpirationMode.RelativeSeconds or ExpirationMode.AbsoluteSeconds; + internal bool IsMilliseconds => (_flags & ExpirationState.IsMillis) != 0; - private static readonly Expiration s_Default = new(ExpirationMode.Default, 0); + internal bool IsSeconds => (_flags & (ExpirationState.HasExpiration | ExpirationState.IsMillis)) == ExpirationState.HasExpiration; - private static readonly Expiration s_KeepTtl = new(ExpirationMode.KeepTtl, 0), - s_Persist = new(ExpirationMode.Persist, 0); + private static readonly Expiration s_Default = new(ExpirationState.None, 0); + + private static readonly Expiration s_KeepTtl = new(ExpirationState.KeepTtl, 0), + s_Persist = new(ExpirationState.Persist, 0); private static void ThrowExpiryAndKeepTtl() => // ReSharper disable once NotResolvedInText @@ -206,68 +233,78 @@ internal static Expiration CreateOrKeepTtl(in DateTime? ttl, bool keepTtl) internal RedisValue GetOperand(out long value) { value = Value; - var mode = Mode; - return mode switch + if (IsKeepTtl) return RedisLiterals.KEEPTTL; + if (IsPersist) return RedisLiterals.PERSIST; + if ((_flags & ExpirationState.HasExpiration) == 0) return RedisValue.Null; + + return (IsAbsolute, IsMilliseconds) switch { - ExpirationMode.KeepTtl => RedisLiterals.KEEPTTL, - ExpirationMode.Persist => RedisLiterals.PERSIST, - ExpirationMode.RelativeSeconds => RedisLiterals.EX, - ExpirationMode.RelativeMilliseconds => RedisLiterals.PX, - ExpirationMode.AbsoluteSeconds => RedisLiterals.EXAT, - ExpirationMode.AbsoluteMilliseconds => RedisLiterals.PXAT, - _ => RedisValue.Null, + (false, false) => RedisLiterals.EX, + (false, true) => RedisLiterals.PX, + (true, false) => RedisLiterals.EXAT, + (true, true) => RedisLiterals.PXAT, }; } - private static void ThrowMode(ExpirationMode mode) => - throw new InvalidOperationException("Unknown mode: " + mode); - /// - public override string ToString() => Mode switch + public override string ToString() { - ExpirationMode.Default or ExpirationMode.NotUsed => "", - ExpirationMode.KeepTtl => "KEEPTTL", - ExpirationMode.Persist => "PERSIST", - _ => $"{Operand} {Value}", - }; + if (IsNone) return ""; + if (IsKeepTtl) return "KEEPTTL"; + if (IsPersist) return "PERSIST"; + return IsExpireIfNotExists ? $"{Operand} {Value} {RedisLiterals.ENX}" : $"{Operand} {Value}"; + } /// - public override int GetHashCode() => _valueAndMode.GetHashCode(); + public override int GetHashCode() + { + unchecked + { + return (_value.GetHashCode() * 397) ^ (int)_flags; + } + } /// - public override bool Equals(object? obj) => obj is Expiration other && _valueAndMode == other._valueAndMode; + public override bool Equals(object? obj) => obj is Expiration other && _value == other._value && _flags == other._flags; - internal int TokenCount => Mode switch + internal int GetTokenCount(bool allowEnx) { - ExpirationMode.Default or ExpirationMode.NotUsed => 0, - ExpirationMode.KeepTtl or ExpirationMode.Persist => 1, - _ => 2, - }; + if (!allowEnx && IsExpireIfNotExists) return ThrowEnxNotSupported(); + return IsNone ? 0 : (IsKeepTtl || IsPersist ? 1 : (IsExpireIfNotExists ? 3 : 2)); + + static int ThrowEnxNotSupported() => throw new NotSupportedException("ENX is not supported for this command."); + } internal void WriteTo(PhysicalConnection physical) { - var mode = Mode; - switch (Mode) + if (IsNone) + { + return; + } + + if (IsKeepTtl) + { + physical.WriteBulkString("KEEPTTL"u8); + return; + } + + if (IsPersist) + { + physical.WriteBulkString("PERSIST"u8); + return; + } + + physical.WriteBulkString((IsAbsolute, IsMilliseconds) switch + { + (false, false) => "EX"u8, + (false, true) => "PX"u8, + (true, false) => "EXAT"u8, + (true, true) => "PXAT"u8, + }); + physical.WriteBulkString(Value); + if (IsExpireIfNotExists) { - case ExpirationMode.Default or ExpirationMode.NotUsed: - break; - case ExpirationMode.KeepTtl: - physical.WriteBulkString("KEEPTTL"u8); - break; - case ExpirationMode.Persist: - physical.WriteBulkString("PERSIST"u8); - break; - default: - physical.WriteBulkString(mode switch - { - ExpirationMode.RelativeSeconds => "EX"u8, - ExpirationMode.RelativeMilliseconds => "PX"u8, - ExpirationMode.AbsoluteSeconds => "EXAT"u8, - ExpirationMode.AbsoluteMilliseconds => "PXAT"u8, - _ => default, - }); - physical.WriteBulkString(Value); - break; + physical.WriteBulkString("ENX"u8); } } } diff --git a/src/StackExchange.Redis/Increx.IncrexMessage.cs b/src/StackExchange.Redis/Increx.IncrexMessage.cs new file mode 100644 index 000000000..cc35361ff --- /dev/null +++ b/src/StackExchange.Redis/Increx.IncrexMessage.cs @@ -0,0 +1,99 @@ +namespace StackExchange.Redis; + +internal partial class RedisDatabase +{ + internal abstract class IncrexMessageBase( + int database, + CommandFlags flags, + RedisKey key, + Expiration expiry) : Message(database, flags, RedisCommand.INCREX) + { + protected RedisKey Key => key; + protected Expiration Expiry => expiry; + + public override int ArgCount + { + get + { + return 3 + BoundsArgCount + Expiry.GetTokenCount(allowEnx: true); // key, BYINT/BYFLOAT, value, bounds, expiry + } + } + + protected abstract int BoundsArgCount { get; } + protected abstract void WriteIncrementKindAndValue(PhysicalConnection physical); + protected abstract void WriteBounds(PhysicalConnection physical); + + protected override void WriteImpl(PhysicalConnection physical) + { + physical.WriteHeader(Command, ArgCount); + physical.WriteBulkString(Key); + WriteIncrementKindAndValue(physical); + WriteBounds(physical); + Expiry.WriteTo(physical); + } + } + + internal sealed class IncrexInt64Message( + int database, + CommandFlags flags, + RedisKey key, + long value, + long? lowerBound, + long? upperBound, + Expiration expiry) : IncrexMessageBase(database, flags, key, expiry) + { + protected override int BoundsArgCount => (lowerBound.HasValue ? 2 : 0) + (upperBound.HasValue ? 2 : 0); + + protected override void WriteIncrementKindAndValue(PhysicalConnection physical) + { + physical.WriteBulkString("BYINT"u8); + physical.WriteBulkString(value); + } + + protected override void WriteBounds(PhysicalConnection physical) + { + if (lowerBound.HasValue) + { + physical.WriteBulkString("LBOUND"u8); + physical.WriteBulkString(lowerBound.GetValueOrDefault()); + } + if (upperBound.HasValue) + { + physical.WriteBulkString("UBOUND"u8); + physical.WriteBulkString(upperBound.GetValueOrDefault()); + } + } + } + + internal sealed class IncrexDoubleMessage( + int database, + CommandFlags flags, + RedisKey key, + double value, + double? lowerBound, + double? upperBound, + Expiration expiry) : IncrexMessageBase(database, flags, key, expiry) + { + protected override int BoundsArgCount => (lowerBound.HasValue ? 2 : 0) + (upperBound.HasValue ? 2 : 0); + + protected override void WriteIncrementKindAndValue(PhysicalConnection physical) + { + physical.WriteBulkString("BYFLOAT"u8); + physical.WriteBulkString(value); + } + + protected override void WriteBounds(PhysicalConnection physical) + { + if (lowerBound.HasValue) + { + physical.WriteBulkString("LBOUND"u8); + physical.WriteBulkString(lowerBound.GetValueOrDefault()); + } + if (upperBound.HasValue) + { + physical.WriteBulkString("UBOUND"u8); + physical.WriteBulkString(upperBound.GetValueOrDefault()); + } + } + } +} diff --git a/src/StackExchange.Redis/Increx.ResultProcessor.cs b/src/StackExchange.Redis/Increx.ResultProcessor.cs new file mode 100644 index 000000000..dca1395ba --- /dev/null +++ b/src/StackExchange.Redis/Increx.ResultProcessor.cs @@ -0,0 +1,41 @@ +namespace StackExchange.Redis; + +internal static class IncrexResultProcessor +{ + internal static readonly ResultProcessor> Int64 = new Int64ResultProcessor(); + internal static readonly ResultProcessor> Double = new DoubleResultProcessor(); + + private sealed class Int64ResultProcessor : ResultProcessor> + { + protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + { + if (result.Resp2TypeArray == ResultType.Array && result.ItemsCount >= 2) + { + var items = result.GetItems(); + if (items[0].TryGetInt64(out long value) && items[1].TryGetInt64(out long appliedIncrement)) + { + SetResult(message, new StringIncrementResult(value, appliedIncrement)); + return true; + } + } + return false; + } + } + + private sealed class DoubleResultProcessor : ResultProcessor> + { + protected override bool SetResultCore(PhysicalConnection connection, Message message, in RawResult result) + { + if (result.Resp2TypeArray == ResultType.Array && result.ItemsCount >= 2) + { + var items = result.GetItems(); + if (items[0].TryGetDouble(out double value) && items[1].TryGetDouble(out double appliedIncrement)) + { + SetResult(message, new StringIncrementResult(value, appliedIncrement)); + return true; + } + } + return false; + } + } +} diff --git a/src/StackExchange.Redis/Increx.StringIncrementResult.cs b/src/StackExchange.Redis/Increx.StringIncrementResult.cs new file mode 100644 index 000000000..bba2d9017 --- /dev/null +++ b/src/StackExchange.Redis/Increx.StringIncrementResult.cs @@ -0,0 +1,22 @@ +using System.Diagnostics.CodeAnalysis; +using RESPite; + +namespace StackExchange.Redis; + +/// +/// Represents the result of an increment operation including the resulting value and the increment actually applied. +/// +/// The numeric type represented by the result. +[Experimental(Experiments.Server_8_8, UrlFormat = Experiments.UrlFormat)] +public readonly struct StringIncrementResult(T value, T appliedIncrement) +{ + /// + /// The resulting value after the increment operation. + /// + public T Value { get; } = value; + + /// + /// The increment that was actually applied. + /// + public T AppliedIncrement { get; } = appliedIncrement; +} diff --git a/src/StackExchange.Redis/Interfaces/IDatabase.cs b/src/StackExchange.Redis/Interfaces/IDatabase.cs index 149cd3797..4bd14f4d6 100644 --- a/src/StackExchange.Redis/Interfaces/IDatabase.cs +++ b/src/StackExchange.Redis/Interfaces/IDatabase.cs @@ -3433,6 +3433,34 @@ IEnumerable SortedSetScan( /// double StringIncrement(RedisKey key, double value, CommandFlags flags = CommandFlags.None); + /// + /// Atomically increments the integer value stored at key, optionally constraining the result and applying expiration semantics. + /// + /// The key of the string. + /// The amount to increment by. + /// The expiration to apply. Use to clear the existing TTL. + /// The optional lower bound for the resulting value. + /// The optional upper bound for the resulting value. + /// The flags to use for this operation. + /// The resulting value and the increment actually applied. +#pragma warning disable RS0026 // Public API with optional parameter(s) should have the most parameters amongst its public overloads + [Experimental(Experiments.Server_8_8, UrlFormat = Experiments.UrlFormat)] + StringIncrementResult StringIncrement(RedisKey key, long value, Expiration expiry, long? lowerBound = null, long? upperBound = null, CommandFlags flags = CommandFlags.None); + + /// + /// Atomically increments the floating point value stored at key, optionally constraining the result and applying expiration semantics. + /// + /// The key of the string. + /// The amount to increment by. + /// The expiration to apply. Use to clear the existing TTL. + /// The optional lower bound for the resulting value. + /// The optional upper bound for the resulting value. + /// The flags to use for this operation. + /// The resulting value and the increment actually applied. + [Experimental(Experiments.Server_8_8, UrlFormat = Experiments.UrlFormat)] + StringIncrementResult StringIncrement(RedisKey key, double value, Expiration expiry, double? lowerBound = null, double? upperBound = null, CommandFlags flags = CommandFlags.None); +#pragma warning restore RS0026 + /// /// Returns the length of the string value stored at key. /// diff --git a/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs b/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs index af131135f..5305f4584 100644 --- a/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs +++ b/src/StackExchange.Redis/Interfaces/IDatabaseAsync.cs @@ -843,6 +843,16 @@ IAsyncEnumerable SortedSetScanAsync( /// Task StringIncrementAsync(RedisKey key, double value, CommandFlags flags = CommandFlags.None); + /// +#pragma warning disable RS0026 // Public API with optional parameter(s) should have the most parameters amongst its public overloads + [Experimental(Experiments.Server_8_8, UrlFormat = Experiments.UrlFormat)] + Task> StringIncrementAsync(RedisKey key, long value, Expiration expiry, long? lowerBound = null, long? upperBound = null, CommandFlags flags = CommandFlags.None); + + /// + [Experimental(Experiments.Server_8_8, UrlFormat = Experiments.UrlFormat)] + Task> StringIncrementAsync(RedisKey key, double value, Expiration expiry, double? lowerBound = null, double? upperBound = null, CommandFlags flags = CommandFlags.None); +#pragma warning restore RS0026 + /// Task StringLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None); diff --git a/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.cs b/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.cs index cd8171f5a..9d14724c9 100644 --- a/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.cs +++ b/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixed.cs @@ -795,6 +795,12 @@ public Task StringIncrementAsync(RedisKey key, double value, CommandFlag public Task StringIncrementAsync(RedisKey key, long value = 1, CommandFlags flags = CommandFlags.None) => Inner.StringIncrementAsync(ToInner(key), value, flags); + public Task> StringIncrementAsync(RedisKey key, double value, Expiration expiry, double? lowerBound = null, double? upperBound = null, CommandFlags flags = CommandFlags.None) => + Inner.StringIncrementAsync(ToInner(key), value, expiry, lowerBound, upperBound, flags); + + public Task> StringIncrementAsync(RedisKey key, long value, Expiration expiry, long? lowerBound = null, long? upperBound = null, CommandFlags flags = CommandFlags.None) => + Inner.StringIncrementAsync(ToInner(key), value, expiry, lowerBound, upperBound, flags); + public Task StringLengthAsync(RedisKey key, CommandFlags flags = CommandFlags.None) => Inner.StringLengthAsync(ToInner(key), flags); diff --git a/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixedDatabase.cs b/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixedDatabase.cs index 78e3959d6..8cc4e1e6a 100644 --- a/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixedDatabase.cs +++ b/src/StackExchange.Redis/KeyspaceIsolation/KeyPrefixedDatabase.cs @@ -777,6 +777,12 @@ public double StringIncrement(RedisKey key, double value, CommandFlags flags = C public long StringIncrement(RedisKey key, long value = 1, CommandFlags flags = CommandFlags.None) => Inner.StringIncrement(ToInner(key), value, flags); + public StringIncrementResult StringIncrement(RedisKey key, double value, Expiration expiry, double? lowerBound = null, double? upperBound = null, CommandFlags flags = CommandFlags.None) => + Inner.StringIncrement(ToInner(key), value, expiry, lowerBound, upperBound, flags); + + public StringIncrementResult StringIncrement(RedisKey key, long value, Expiration expiry, long? lowerBound = null, long? upperBound = null, CommandFlags flags = CommandFlags.None) => + Inner.StringIncrement(ToInner(key), value, expiry, lowerBound, upperBound, flags); + public long StringLength(RedisKey key, CommandFlags flags = CommandFlags.None) => Inner.StringLength(ToInner(key), flags); diff --git a/src/StackExchange.Redis/Message.ValueCondition.cs b/src/StackExchange.Redis/Message.ValueCondition.cs index a9672d945..931553913 100644 --- a/src/StackExchange.Redis/Message.ValueCondition.cs +++ b/src/StackExchange.Redis/Message.ValueCondition.cs @@ -42,7 +42,7 @@ private sealed class KeyValueExpiryConditionMessage( private readonly ValueCondition _when = when; private readonly Expiration _expiry = expiry; - public override int ArgCount => 2 + _expiry.TokenCount + _when.TokenCount; + public override int ArgCount => 2 + _expiry.GetTokenCount(allowEnx: false) + _when.TokenCount; protected override void WriteImpl(PhysicalConnection physical) { diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index 05dfc07c3..5a2f5026b 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -1720,7 +1720,7 @@ public override int GetHashSlot(ServerSelectionStrategy serverSelectionStrategy) // - MSETNX {key1} {value1} [{key2} {value2}...] // - MSETEX {count} {key1} {value1} [{key2} {value2}...] [standard-expiry-tokens] public override int ArgCount => Command == RedisCommand.MSETEX - ? (1 + (2 * values.Length) + expiry.TokenCount + (when is When.Exists or When.NotExists ? 1 : 0)) + ? (1 + (2 * values.Length) + expiry.GetTokenCount(allowEnx: false) + (when is When.Exists or When.NotExists ? 1 : 0)) : (2 * values.Length); // MSET/MSETNX only support simple syntax protected override void WriteImpl(PhysicalConnection physical) diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt index 2d3c191fa..056535b75 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt @@ -491,6 +491,11 @@ StackExchange.Redis.GeoUnit.Miles = 2 -> StackExchange.Redis.GeoUnit [SER006]StackExchange.Redis.GcraRateLimitResult.Limited.get -> bool [SER006]StackExchange.Redis.GcraRateLimitResult.MaxTokens.get -> int [SER006]StackExchange.Redis.GcraRateLimitResult.RetryAfterSeconds.get -> int +[SER006]StackExchange.Redis.StringIncrementResult +[SER006]StackExchange.Redis.StringIncrementResult.AppliedIncrement.get -> T +[SER006]StackExchange.Redis.StringIncrementResult.StringIncrementResult() -> void +[SER006]StackExchange.Redis.StringIncrementResult.StringIncrementResult(T value, T appliedIncrement) -> void +[SER006]StackExchange.Redis.StringIncrementResult.Value.get -> T StackExchange.Redis.HashEntry StackExchange.Redis.HashEntry.Equals(StackExchange.Redis.HashEntry other) -> bool StackExchange.Redis.HashEntry.HashEntry() -> void @@ -788,7 +793,9 @@ StackExchange.Redis.IDatabase.StringGetSetExpiry(StackExchange.Redis.RedisKey ke StackExchange.Redis.IDatabase.StringGetSetExpiry(StackExchange.Redis.RedisKey key, System.DateTime expiry, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> StackExchange.Redis.RedisValue StackExchange.Redis.IDatabase.StringGetWithExpiry(StackExchange.Redis.RedisKey key, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> StackExchange.Redis.RedisValueWithExpiry [SER006]StackExchange.Redis.IDatabase.StringGcraRateLimit(StackExchange.Redis.RedisKey key, int maxBurst, int tokensPerPeriod, double periodSeconds = 1, int count = 1, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> StackExchange.Redis.GcraRateLimitResult +[SER006]StackExchange.Redis.IDatabase.StringIncrement(StackExchange.Redis.RedisKey key, double value, StackExchange.Redis.Expiration expiry, double? lowerBound = null, double? upperBound = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> StackExchange.Redis.StringIncrementResult StackExchange.Redis.IDatabase.StringIncrement(StackExchange.Redis.RedisKey key, double value, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> double +[SER006]StackExchange.Redis.IDatabase.StringIncrement(StackExchange.Redis.RedisKey key, long value, StackExchange.Redis.Expiration expiry, long? lowerBound = null, long? upperBound = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> StackExchange.Redis.StringIncrementResult StackExchange.Redis.IDatabase.StringIncrement(StackExchange.Redis.RedisKey key, long value = 1, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> long StackExchange.Redis.IDatabase.StringLength(StackExchange.Redis.RedisKey key, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> long StackExchange.Redis.IDatabase.StringLongestCommonSubsequence(StackExchange.Redis.RedisKey first, StackExchange.Redis.RedisKey second, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> string? @@ -1033,7 +1040,9 @@ StackExchange.Redis.IDatabaseAsync.StringGetSetExpiryAsync(StackExchange.Redis.R StackExchange.Redis.IDatabaseAsync.StringGetSetExpiryAsync(StackExchange.Redis.RedisKey key, System.DateTime expiry, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! StackExchange.Redis.IDatabaseAsync.StringGetWithExpiryAsync(StackExchange.Redis.RedisKey key, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! [SER006]StackExchange.Redis.IDatabaseAsync.StringGcraRateLimitAsync(StackExchange.Redis.RedisKey key, int maxBurst, int tokensPerPeriod, double periodSeconds = 1, int count = 1, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! +[SER006]StackExchange.Redis.IDatabaseAsync.StringIncrementAsync(StackExchange.Redis.RedisKey key, double value, StackExchange.Redis.Expiration expiry, double? lowerBound = null, double? upperBound = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task>! StackExchange.Redis.IDatabaseAsync.StringIncrementAsync(StackExchange.Redis.RedisKey key, double value, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! +[SER006]StackExchange.Redis.IDatabaseAsync.StringIncrementAsync(StackExchange.Redis.RedisKey key, long value, StackExchange.Redis.Expiration expiry, long? lowerBound = null, long? upperBound = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task>! StackExchange.Redis.IDatabaseAsync.StringIncrementAsync(StackExchange.Redis.RedisKey key, long value = 1, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! StackExchange.Redis.IDatabaseAsync.StringLengthAsync(StackExchange.Redis.RedisKey key, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! StackExchange.Redis.IDatabaseAsync.StringLongestCommonSubsequenceAsync(StackExchange.Redis.RedisKey first, StackExchange.Redis.RedisKey second, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! @@ -2082,7 +2091,9 @@ StackExchange.Redis.IDatabaseAsync.StringSetAsync(System.Collections.Generic.Key StackExchange.Redis.Expiration StackExchange.Redis.Expiration.Expiration() -> void StackExchange.Redis.Expiration.Expiration(System.DateTime when) -> void +StackExchange.Redis.Expiration.Expiration(System.DateTime when, StackExchange.Redis.ExpirationFlags flags) -> void StackExchange.Redis.Expiration.Expiration(System.TimeSpan ttl) -> void +StackExchange.Redis.Expiration.Expiration(System.TimeSpan ttl, StackExchange.Redis.ExpirationFlags flags) -> void override StackExchange.Redis.Expiration.Equals(object? obj) -> bool override StackExchange.Redis.Expiration.GetHashCode() -> int override StackExchange.Redis.Expiration.ToString() -> string! @@ -2091,6 +2102,9 @@ static StackExchange.Redis.Expiration.KeepTtl.get -> StackExchange.Redis.Expirat static StackExchange.Redis.Expiration.Persist.get -> StackExchange.Redis.Expiration static StackExchange.Redis.Expiration.implicit operator StackExchange.Redis.Expiration(System.DateTime when) -> StackExchange.Redis.Expiration static StackExchange.Redis.Expiration.implicit operator StackExchange.Redis.Expiration(System.TimeSpan ttl) -> StackExchange.Redis.Expiration +StackExchange.Redis.ExpirationFlags +StackExchange.Redis.ExpirationFlags.ExpireIfNotExists = 1 -> StackExchange.Redis.ExpirationFlags +StackExchange.Redis.ExpirationFlags.None = 0 -> StackExchange.Redis.ExpirationFlags override StackExchange.Redis.ValueCondition.Equals(object? obj) -> bool override StackExchange.Redis.ValueCondition.GetHashCode() -> int override StackExchange.Redis.ValueCondition.ToString() -> string! diff --git a/src/StackExchange.Redis/RedisDatabase.Strings.cs b/src/StackExchange.Redis/RedisDatabase.Strings.cs index 4c5e6e1a6..11b26b09f 100644 --- a/src/StackExchange.Redis/RedisDatabase.Strings.cs +++ b/src/StackExchange.Redis/RedisDatabase.Strings.cs @@ -1,4 +1,5 @@ -using System.Runtime.CompilerServices; +using System; +using System.Runtime.CompilerServices; using System.Threading.Tasks; namespace StackExchange.Redis; @@ -59,6 +60,45 @@ public Task StringGcraRateLimitAsync(RedisKey key, int maxB return ExecuteAsync(msg, ResultProcessor.GcraRateLimit); } + public StringIncrementResult StringIncrement(RedisKey key, long value, Expiration expiry, long? lowerBound = null, long? upperBound = null, CommandFlags flags = CommandFlags.None) + { + ValidateStringIncrementExpiry(expiry); + var msg = new IncrexInt64Message(Database, flags, key, value, lowerBound, upperBound, expiry); + return ExecuteSync(msg, IncrexResultProcessor.Int64); + } + + public Task> StringIncrementAsync(RedisKey key, long value, Expiration expiry, long? lowerBound = null, long? upperBound = null, CommandFlags flags = CommandFlags.None) + { + ValidateStringIncrementExpiry(expiry); + var msg = new IncrexInt64Message(Database, flags, key, value, lowerBound, upperBound, expiry); + return ExecuteAsync(msg, IncrexResultProcessor.Int64); + } + + public StringIncrementResult StringIncrement(RedisKey key, double value, Expiration expiry, double? lowerBound = null, double? upperBound = null, CommandFlags flags = CommandFlags.None) + { + ValidateStringIncrementExpiry(expiry); + var msg = new IncrexDoubleMessage(Database, flags, key, value, lowerBound, upperBound, expiry); + return ExecuteSync(msg, IncrexResultProcessor.Double); + } + + public Task> StringIncrementAsync(RedisKey key, double value, Expiration expiry, double? lowerBound = null, double? upperBound = null, CommandFlags flags = CommandFlags.None) + { + ValidateStringIncrementExpiry(expiry); + var msg = new IncrexDoubleMessage(Database, flags, key, value, lowerBound, upperBound, expiry); + return ExecuteAsync(msg, IncrexResultProcessor.Double); + } + + private static void ValidateStringIncrementExpiry(Expiration expiry) + { + if (expiry.IsKeepTtl) ThrowKeepTtl(); + if (expiry.IsPersist) ThrowPersist(); + if (expiry.IsExpireIfNotExists && !(expiry.IsAbsolute || expiry.IsRelative)) ThrowEnxWithoutExpiry(); + + static void ThrowKeepTtl() => throw new ArgumentException("KEEPTTL is not supported by this operation.", nameof(expiry)); + static void ThrowPersist() => throw new ArgumentException("PERSIST is not supported by this operation; use Expiration.Default to clear the existing TTL.", nameof(expiry)); + static void ThrowEnxWithoutExpiry() => throw new ArgumentException("ENX requires EX, PX, EXAT, or PXAT.", nameof(expiry)); + } + public Task StringSetAsync(RedisKey key, RedisValue value, Expiration expiry, ValueCondition when, CommandFlags flags = CommandFlags.None) { var msg = GetStringSetMessage(key, value, expiry, when, flags); diff --git a/src/StackExchange.Redis/RedisDatabase.cs b/src/StackExchange.Redis/RedisDatabase.cs index cdf4fc9af..9e4abbcd0 100644 --- a/src/StackExchange.Redis/RedisDatabase.cs +++ b/src/StackExchange.Redis/RedisDatabase.cs @@ -489,7 +489,7 @@ public Task HashFieldGetAndDeleteAsync(RedisKey key, RedisValue[] } private Message HashFieldGetAndSetExpiryMessage(in RedisKey key, in RedisValue hashField, Expiration expiry, CommandFlags flags) => - expiry.TokenCount switch + expiry.GetTokenCount(allowEnx: false) switch { // expiry, for example EX 10 2 => Message.Create(Database, flags, RedisCommand.HGETEX, key, expiry.Operand, expiry.Value, RedisLiterals.FIELDS, 1, hashField), @@ -508,12 +508,13 @@ private Message HashFieldGetAndSetExpiryMessage(in RedisKey key, RedisValue[] ha } // precision, time, FIELDS, hashFields.Length, {N x fields} - int extraTokens = expiry.TokenCount + 2; + int expiryTokenCount = expiry.GetTokenCount(allowEnx: false); + int extraTokens = expiryTokenCount + 2; RedisValue[] values = new RedisValue[extraTokens + hashFields.Length]; int index = 0; // add PERSIST or expiry values - switch (expiry.TokenCount) + switch (expiryTokenCount) { case 2: values[index++] = expiry.Operand; @@ -617,9 +618,10 @@ public Task HashFieldGetAndSetExpiryAsync(RedisKey key, RedisValue private Message HashFieldSetAndSetExpiryMessage(in RedisKey key, in RedisValue field, in RedisValue value, Expiration expiry, When when, CommandFlags flags) { + int expiryTokenCount = expiry.GetTokenCount(allowEnx: false); if (when == When.Always) { - return expiry.TokenCount switch + return expiryTokenCount switch { 2 => Message.Create(Database, flags, RedisCommand.HSETEX, key, expiry.Operand, expiry.Value, RedisLiterals.FIELDS, 1, field, value), 1 => Message.Create(Database, flags, RedisCommand.HSETEX, key, expiry.Operand, RedisLiterals.FIELDS, 1, field, value), @@ -636,7 +638,7 @@ private Message HashFieldSetAndSetExpiryMessage(in RedisKey key, in RedisValue f _ => throw new ArgumentOutOfRangeException(nameof(when)), }; - return expiry.TokenCount switch + return expiryTokenCount switch { 2 => Message.Create(Database, flags, RedisCommand.HSETEX, key, existance, expiry.Operand, expiry.Value, RedisLiterals.FIELDS, 1, field, value), 1 => Message.Create(Database, flags, RedisCommand.HSETEX, key, existance, expiry.Operand, RedisLiterals.FIELDS, 1, field, value), @@ -653,7 +655,8 @@ private Message HashFieldSetAndSetExpiryMessage(in RedisKey key, HashEntry[] has return HashFieldSetAndSetExpiryMessage(key, field.Name, field.Value, expiry, when, flags); } // Determine the base array size - var extraTokens = expiry.TokenCount + (when == When.Always ? 2 : 3); // [FXX|FNX] {expiry} FIELDS {length} + int expiryTokenCount = expiry.GetTokenCount(allowEnx: false); + var extraTokens = expiryTokenCount + (when == When.Always ? 2 : 3); // [FXX|FNX] {expiry} FIELDS {length} RedisValue[] values = new RedisValue[(hashFields.Length * 2) + extraTokens]; int index = 0; @@ -670,7 +673,7 @@ private Message HashFieldSetAndSetExpiryMessage(in RedisKey key, HashEntry[] has default: throw new ArgumentOutOfRangeException(nameof(when)); } - switch (expiry.TokenCount) + switch (expiryTokenCount) { case 2: values[index++] = expiry.Operand; @@ -5221,7 +5224,7 @@ private Message GetStringBitOperationMessage(Bitwise operation, RedisKey destina private Message GetStringGetExMessage(in RedisKey key, Expiration expiry, CommandFlags flags = CommandFlags.None) { - return expiry.TokenCount switch + return expiry.GetTokenCount(allowEnx: false) switch { 0 => Message.Create(Database, flags, RedisCommand.GETEX, key), 1 => Message.Create(Database, flags, RedisCommand.GETEX, key, expiry.Operand), @@ -5304,6 +5307,8 @@ private Message GetStringSetMessage( }; } + expiry.GetTokenCount(allowEnx: false); + if (when is When.Always & expiry.IsRelative) { // special case to SETEX/PSETEX diff --git a/src/StackExchange.Redis/RedisLiterals.cs b/src/StackExchange.Redis/RedisLiterals.cs index 9a8f54971..0f1fc02bd 100644 --- a/src/StackExchange.Redis/RedisLiterals.cs +++ b/src/StackExchange.Redis/RedisLiterals.cs @@ -55,6 +55,8 @@ public static readonly RedisValue ASC = "ASC", BEFORE = "BEFORE", BIT = "BIT", + BYFLOAT = "BYFLOAT", + BYINT = "BYINT", BY = "BY", BYLEX = "BYLEX", BYSCORE = "BYSCORE", @@ -68,6 +70,7 @@ public static readonly RedisValue DIFF = "DIFF", DIFF1 = "DIFF1", DOCTOR = "DOCTOR", + ENX = "ENX", ENCODING = "ENCODING", EX = "EX", EXAT = "EXAT", @@ -101,6 +104,7 @@ public static readonly RedisValue LIMIT = "LIMIT", LIST = "LIST", LT = "LT", + LBOUND = "LBOUND", MATCH = "MATCH", MALLOC_STATS = "MALLOC-STATS", MAX = "MAX", @@ -144,6 +148,7 @@ public static readonly RedisValue STOP = "STOP", STORE = "STORE", TYPE = "TYPE", + UBOUND = "UBOUND", USERNAME = "USERNAME", WEIGHTS = "WEIGHTS", WITHMATCHLEN = "WITHMATCHLEN", diff --git a/tests/StackExchange.Redis.Tests/ExpiryTokenTests.cs b/tests/StackExchange.Redis.Tests/ExpirationUnitTests.cs similarity index 67% rename from tests/StackExchange.Redis.Tests/ExpiryTokenTests.cs rename to tests/StackExchange.Redis.Tests/ExpirationUnitTests.cs index 6012422ed..0d376e2ad 100644 --- a/tests/StackExchange.Redis.Tests/ExpiryTokenTests.cs +++ b/tests/StackExchange.Redis.Tests/ExpirationUnitTests.cs @@ -3,14 +3,33 @@ using static StackExchange.Redis.Expiration; namespace StackExchange.Redis.Tests; -public class ExpirationTests // pure tests, no DB +public class ExpirationUnitTests // pure tests, no DB { + [Fact] + public void ExpireIfNotExists_TimeSpan_Seconds() + { + var ex = new Expiration(TimeSpan.FromSeconds(5), ExpirationFlags.ExpireIfNotExists); + Assert.True(ex.IsExpireIfNotExists); + Assert.Equal(3, ex.GetTokenCount(allowEnx: true)); + Assert.Equal("EX 5 ENX", ex.ToString()); + } + + [Fact] + public void ExpireIfNotExists_DateTime_Milliseconds() + { + var when = new DateTime(2025, 7, 23, 10, 4, 14, DateTimeKind.Utc).AddMilliseconds(14); + var ex = new Expiration(when, ExpirationFlags.ExpireIfNotExists); + Assert.True(ex.IsExpireIfNotExists); + Assert.Equal(3, ex.GetTokenCount(allowEnx: true)); + Assert.Equal("PXAT 1753265054014 ENX", ex.ToString()); + } + [Fact] public void Persist_Seconds() { TimeSpan? time = TimeSpan.FromMilliseconds(5000); var ex = CreateOrPersist(time, false); - Assert.Equal(2, ex.TokenCount); + Assert.Equal(2, ex.GetTokenCount(allowEnx: false)); Assert.Equal("EX 5", ex.ToString()); } @@ -19,7 +38,7 @@ public void Persist_Milliseconds() { TimeSpan? time = TimeSpan.FromMilliseconds(5001); var ex = CreateOrPersist(time, false); - Assert.Equal(2, ex.TokenCount); + Assert.Equal(2, ex.GetTokenCount(allowEnx: false)); Assert.Equal("PX 5001", ex.ToString()); } @@ -28,7 +47,7 @@ public void Persist_None_False() { TimeSpan? time = null; var ex = CreateOrPersist(time, false); - Assert.Equal(0, ex.TokenCount); + Assert.Equal(0, ex.GetTokenCount(allowEnx: false)); Assert.Equal("", ex.ToString()); } @@ -37,7 +56,7 @@ public void Persist_None_True() { TimeSpan? time = null; var ex = CreateOrPersist(time, true); - Assert.Equal(1, ex.TokenCount); + Assert.Equal(1, ex.GetTokenCount(allowEnx: false)); Assert.Equal("PERSIST", ex.ToString()); } @@ -55,7 +74,7 @@ public void KeepTtl_Seconds() { TimeSpan? time = TimeSpan.FromMilliseconds(5000); var ex = CreateOrKeepTtl(time, false); - Assert.Equal(2, ex.TokenCount); + Assert.Equal(2, ex.GetTokenCount(allowEnx: false)); Assert.Equal("EX 5", ex.ToString()); } @@ -64,7 +83,7 @@ public void KeepTtl_Milliseconds() { TimeSpan? time = TimeSpan.FromMilliseconds(5001); var ex = CreateOrKeepTtl(time, false); - Assert.Equal(2, ex.TokenCount); + Assert.Equal(2, ex.GetTokenCount(allowEnx: false)); Assert.Equal("PX 5001", ex.ToString()); } @@ -73,7 +92,7 @@ public void KeepTtl_None_False() { TimeSpan? time = null; var ex = CreateOrKeepTtl(time, false); - Assert.Equal(0, ex.TokenCount); + Assert.Equal(0, ex.GetTokenCount(allowEnx: false)); Assert.Equal("", ex.ToString()); } @@ -82,7 +101,7 @@ public void KeepTtl_None_True() { TimeSpan? time = null; var ex = CreateOrKeepTtl(time, true); - Assert.Equal(1, ex.TokenCount); + Assert.Equal(1, ex.GetTokenCount(allowEnx: false)); Assert.Equal("KEEPTTL", ex.ToString()); } @@ -100,7 +119,7 @@ public void DateTime_Seconds() { var when = new DateTime(2025, 7, 23, 10, 4, 14, DateTimeKind.Utc); var ex = new Expiration(when); - Assert.Equal(2, ex.TokenCount); + Assert.Equal(2, ex.GetTokenCount(allowEnx: false)); Assert.Equal("EXAT 1753265054", ex.ToString()); } @@ -110,7 +129,7 @@ public void DateTime_Milliseconds() var when = new DateTime(2025, 7, 23, 10, 4, 14, DateTimeKind.Utc); when = when.AddMilliseconds(14); var ex = new Expiration(when); - Assert.Equal(2, ex.TokenCount); + Assert.Equal(2, ex.GetTokenCount(allowEnx: false)); Assert.Equal("PXAT 1753265054014", ex.ToString()); } } diff --git a/tests/StackExchange.Redis.Tests/IncrexIntegrationTests.cs b/tests/StackExchange.Redis.Tests/IncrexIntegrationTests.cs new file mode 100644 index 000000000..543335363 --- /dev/null +++ b/tests/StackExchange.Redis.Tests/IncrexIntegrationTests.cs @@ -0,0 +1,96 @@ +using System; +using System.Threading.Tasks; +using Xunit; + +namespace StackExchange.Redis.Tests; + +[RunPerProtocol] +public class IncrexIntegrationTests(ITestOutputHelper output, SharedConnectionFixture fixture) : TestBase(output, fixture) +{ + [Fact(Timeout = 5000)] + public async Task StringIncrementIncrex_Int64_WithBoundsAndExpiry() + { + await using var conn = Create(require: RedisFeatures.v8_8_0); + var db = conn.GetDatabase(); + var key = Me(); + db.KeyDelete(key, CommandFlags.FireAndForget); + + db.StringSet(key, 10); + + var result = await db.StringIncrementAsync(key, 2L, TimeSpan.FromSeconds(5), lowerBound: 0, upperBound: 20); + + Assert.Equal(12, result.Value); + Assert.Equal(2, result.AppliedIncrement); + Assert.Equal(12, (long)db.StringGet(key)); + Assert.True((await db.KeyTimeToLiveAsync(key)) > TimeSpan.Zero); + } + + [Fact(Timeout = 5000)] + public async Task StringIncrementIncrex_Double_WithAbsoluteExpiryAndEnx() + { + await using var conn = Create(require: RedisFeatures.v8_8_0); + var db = conn.GetDatabase(); + var key = Me(); + var when = DateTime.UtcNow.AddMinutes(30).AddMilliseconds(14); + db.KeyDelete(key, CommandFlags.FireAndForget); + db.StringSet(key, 3.25, TimeSpan.FromMinutes(10)); + var beforeTtl = await db.KeyTimeToLiveAsync(key); + + var result = await db.StringIncrementAsync(key, 1.25, new Expiration(when, ExpirationFlags.ExpireIfNotExists), lowerBound: -1.5, upperBound: 9.5); + + Assert.Equal(4.5, result.Value); + Assert.Equal(1.25, result.AppliedIncrement); + Assert.Equal(4.5, (double)db.StringGet(key)); + var afterTtl = await db.KeyTimeToLiveAsync(key); + Assert.NotNull(beforeTtl); + Assert.NotNull(afterTtl); + Assert.True(afterTtl <= beforeTtl); + Assert.True(afterTtl > TimeSpan.FromMinutes(8)); + } + + [Fact(Timeout = 5000)] + public async Task StringIncrementIncrex_SyncVersion_ParsesResult() + { + await using var conn = Create(require: RedisFeatures.v8_8_0); + var db = conn.GetDatabase(); + var key = Me(); + db.KeyDelete(key, CommandFlags.FireAndForget); + + var result = db.StringIncrement(key, 3L, Expiration.Default); + + Assert.Equal(3, result.Value); + Assert.Equal(3, result.AppliedIncrement); + } + + [Fact(Timeout = 5000)] + public async Task StringIncrementIncrex_SkipStillAppliesExpiry() + { + await using var conn = Create(require: RedisFeatures.v8_8_0); + var db = conn.GetDatabase(); + var key = Me(); + db.KeyDelete(key, CommandFlags.FireAndForget); + db.StringSet(key, 5); + + var result = await db.StringIncrementAsync(key, 1L, TimeSpan.FromSeconds(5), lowerBound: 10); + + Assert.Equal(5, result.Value); + Assert.Equal(0, result.AppliedIncrement); + Assert.True((await db.KeyTimeToLiveAsync(key)) > TimeSpan.Zero); + } + + [Fact(Timeout = 5000)] + public async Task StringIncrementIncrex_DefaultClearsExistingTtl() + { + await using var conn = Create(require: RedisFeatures.v8_8_0); + var db = conn.GetDatabase(); + var key = Me(); + db.KeyDelete(key, CommandFlags.FireAndForget); + db.StringSet(key, 5, TimeSpan.FromMinutes(5)); + + var result = await db.StringIncrementAsync(key, 2L, Expiration.Default); + + Assert.Equal(7, result.Value); + Assert.Equal(2, result.AppliedIncrement); + Assert.Null(await db.KeyTimeToLiveAsync(key)); + } +} diff --git a/tests/StackExchange.Redis.Tests/IncrexTestServer.cs b/tests/StackExchange.Redis.Tests/IncrexTestServer.cs new file mode 100644 index 000000000..1e8d443ba --- /dev/null +++ b/tests/StackExchange.Redis.Tests/IncrexTestServer.cs @@ -0,0 +1,174 @@ +extern alias respite; +using System; +using System.Globalization; +using respite::RESPite.Messages; +using StackExchange.Redis.Server; +using Xunit; + +namespace StackExchange.Redis.Tests; + +public class IncrexTestServer(ITestOutputHelper? log = null) : InProcessTestServer(log) +{ + public sealed class IncrexRequestSnapshot + { + public RedisKey Key { get; set; } + public bool IsFloat { get; set; } + public string Increment { get; set; } = ""; + public string? LowerBound { get; set; } + public string? UpperBound { get; set; } + public string? ExpiryMode { get; set; } + public string? ExpiryValue { get; set; } + public bool Enx { get; set; } + } + + public IncrexRequestSnapshot? LastRequest { get; private set; } + + [RedisCommand(-4, "INCREX")] + protected virtual TypedRedisValue Increx(RedisClient client, in RedisRequest request) + { + var snapshot = ParseRequest(in request); + LastRequest = snapshot; + + return snapshot.IsFloat + ? ExecuteDouble(client.Database, snapshot) + : ExecuteInt64(client.Database, snapshot); + } + + private IncrexRequestSnapshot ParseRequest(in RedisRequest request) + { + var snapshot = new IncrexRequestSnapshot { Key = request.GetKey(1) }; + int index = 2; + while (index < request.Count) + { + switch (request.GetString(index++)) + { + case "BYINT": + snapshot.IsFloat = false; + snapshot.Increment = request.GetString(index++); + break; + case "BYFLOAT": + snapshot.IsFloat = true; + snapshot.Increment = request.GetString(index++); + break; + case "LBOUND": + snapshot.LowerBound = request.GetString(index++); + break; + case "UBOUND": + snapshot.UpperBound = request.GetString(index++); + break; + case "EX": + case "PX": + case "EXAT": + case "PXAT": + snapshot.ExpiryMode = request.GetString(index - 1); + snapshot.ExpiryValue = request.GetString(index++); + break; + case "ENX": + snapshot.Enx = true; + break; + } + } + return snapshot; + } + + private TypedRedisValue ExecuteInt64(int database, IncrexRequestSnapshot snapshot) + { + var raw = Get(database, snapshot.Key); + bool existed = !raw.IsNull; + long current = raw.IsNull ? 0 : (long)raw; + long delta = long.Parse(snapshot.Increment, CultureInfo.InvariantCulture); + long? lowerBound = snapshot.LowerBound is null ? null : long.Parse(snapshot.LowerBound, CultureInfo.InvariantCulture); + long? upperBound = snapshot.UpperBound is null ? null : long.Parse(snapshot.UpperBound, CultureInfo.InvariantCulture); + + long next = current; + long applied = 0; + + try + { + long candidate = checked(current + delta); + if ((!lowerBound.HasValue || candidate >= lowerBound.GetValueOrDefault()) + && (!upperBound.HasValue || candidate <= upperBound.GetValueOrDefault())) + { + next = candidate; + applied = delta; + } + } + catch (OverflowException) { } + + ApplyValueAndExpiry(database, snapshot, existed, next); + return MakeResult(next, applied); + } + + private TypedRedisValue ExecuteDouble(int database, IncrexRequestSnapshot snapshot) + { + var raw = Get(database, snapshot.Key); + bool existed = !raw.IsNull; + double current = raw.IsNull ? 0D : (double)raw; + double delta = double.Parse(snapshot.Increment, CultureInfo.InvariantCulture); + double? lowerBound = snapshot.LowerBound is null ? null : double.Parse(snapshot.LowerBound, CultureInfo.InvariantCulture); + double? upperBound = snapshot.UpperBound is null ? null : double.Parse(snapshot.UpperBound, CultureInfo.InvariantCulture); + + double next = current; + double applied = 0; + + double candidate = current + delta; + if ((!lowerBound.HasValue || candidate >= lowerBound.GetValueOrDefault()) + && (!upperBound.HasValue || candidate <= upperBound.GetValueOrDefault())) + { + next = candidate; + applied = delta; + } + + ApplyValueAndExpiry(database, snapshot, existed, next); + return MakeResult(next, applied); + } + + private void ApplyValueAndExpiry(int database, IncrexRequestSnapshot snapshot, bool existed, RedisValue value) + { + var priorTtl = existed ? Ttl(database, snapshot.Key) : null; + Set(database, snapshot.Key, value); + + if (snapshot.ExpiryMode is null) + { + return; + } + + if (snapshot.Enx && priorTtl.HasValue && priorTtl.Value != TimeSpan.MaxValue) + { + _ = Expire(database, snapshot.Key, priorTtl.Value); + return; + } + + var ttl = snapshot.ExpiryMode switch + { + "EX" => TimeSpan.FromSeconds(long.Parse(snapshot.ExpiryValue!, CultureInfo.InvariantCulture)), + "PX" => TimeSpan.FromMilliseconds(long.Parse(snapshot.ExpiryValue!, CultureInfo.InvariantCulture)), + "EXAT" => DateTimeOffset.FromUnixTimeSeconds(long.Parse(snapshot.ExpiryValue!, CultureInfo.InvariantCulture)).UtcDateTime - Time(), + "PXAT" => DateTimeOffset.FromUnixTimeMilliseconds(long.Parse(snapshot.ExpiryValue!, CultureInfo.InvariantCulture)).UtcDateTime - Time(), + _ => throw new InvalidOperationException("Unknown expiry mode: " + snapshot.ExpiryMode), + }; + _ = Expire(database, snapshot.Key, ttl); + } + + private static TypedRedisValue MakeResult(long value, long appliedIncrement) + { + var result = TypedRedisValue.Rent(2, out var span, RespPrefix.Array); + span[0] = TypedRedisValue.BulkString((RedisValue)value); + span[1] = TypedRedisValue.BulkString((RedisValue)appliedIncrement); + return result; + } + + private static TypedRedisValue MakeResult(double value, double appliedIncrement) + { + var result = TypedRedisValue.Rent(2, out var span, RespPrefix.Array); + span[0] = TypedRedisValue.BulkString((RedisValue)value); + span[1] = TypedRedisValue.BulkString((RedisValue)appliedIncrement); + return result; + } + + public override void ResetCounters() + { + LastRequest = null; + base.ResetCounters(); + } +} diff --git a/tests/StackExchange.Redis.Tests/IncrexUnitTests.cs b/tests/StackExchange.Redis.Tests/IncrexUnitTests.cs new file mode 100644 index 000000000..68a106e98 --- /dev/null +++ b/tests/StackExchange.Redis.Tests/IncrexUnitTests.cs @@ -0,0 +1,128 @@ +using System; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Xunit; + +namespace StackExchange.Redis.Tests; + +public class IncrexUnitTests(ITestOutputHelper log) +{ + private RedisKey Me([CallerMemberName] string callerName = "") => callerName; + + [Fact] + public async Task StringIncrementIncrex_Int64_WithBoundsAndExpiry() + { + using var server = new IncrexTestServer(log); + await using var muxer = await server.ConnectAsync(); + var db = muxer.GetDatabase(); + var key = Me(); + + db.StringSet(key, 10); + + var result = await db.StringIncrementAsync(key, 2L, TimeSpan.FromSeconds(5), lowerBound: 0, upperBound: 20); + + Assert.Equal(12, result.Value); + Assert.Equal(2, result.AppliedIncrement); + Assert.Equal(12, (long)db.StringGet(key)); + Assert.True((await db.KeyTimeToLiveAsync(key)) > TimeSpan.Zero); + + var request = server.LastRequest!; + Assert.Equal(key, request.Key); + Assert.False(request.IsFloat); + Assert.Equal("2", request.Increment); + Assert.Equal("0", request.LowerBound); + Assert.Equal("20", request.UpperBound); + Assert.Equal("EX", request.ExpiryMode); + Assert.Equal("5", request.ExpiryValue); + Assert.False(request.Enx); + } + + [Fact] + public async Task StringIncrementIncrex_Double_WithAbsoluteExpiryAndEnx() + { + using var server = new IncrexTestServer(log); + await using var muxer = await server.ConnectAsync(); + var db = muxer.GetDatabase(); + var key = Me(); + var when = new DateTime(2025, 7, 23, 10, 4, 14, DateTimeKind.Utc).AddMilliseconds(14); + db.StringSet(key, 3.25, TimeSpan.FromMinutes(10)); + var beforeTtl = await db.KeyTimeToLiveAsync(key); + + var result = await db.StringIncrementAsync(key, 1.25, new Expiration(when, ExpirationFlags.ExpireIfNotExists), lowerBound: -1.5, upperBound: 9.5); + + Assert.Equal(4.5, result.Value); + Assert.Equal(1.25, result.AppliedIncrement); + Assert.Equal(4.5, (double)db.StringGet(key)); + var afterTtl = await db.KeyTimeToLiveAsync(key); + Assert.NotNull(beforeTtl); + Assert.NotNull(afterTtl); + Assert.True(afterTtl <= beforeTtl); + Assert.True(afterTtl > TimeSpan.FromMinutes(8)); + + var request = server.LastRequest!; + Assert.Equal(key, request.Key); + Assert.True(request.IsFloat); + Assert.Equal("1.25", request.Increment); + Assert.Equal("-1.5", request.LowerBound); + Assert.Equal("9.5", request.UpperBound); + Assert.Equal("PXAT", request.ExpiryMode); + Assert.Equal("1753265054014", request.ExpiryValue); + Assert.True(request.Enx); + } + + [Fact] + public async Task StringIncrementIncrex_SyncVersion_ParsesResult() + { + using var server = new IncrexTestServer(log); + await using var muxer = await server.ConnectAsync(); + var db = muxer.GetDatabase(); + + var result = db.StringIncrement(Me(), 3L, Expiration.Default); + + Assert.Equal(3, result.Value); + Assert.Equal(3, result.AppliedIncrement); + } + + [Fact] + public async Task StringIncrementIncrex_SkipStillAppliesExpiry() + { + using var server = new IncrexTestServer(log); + await using var muxer = await server.ConnectAsync(); + var db = muxer.GetDatabase(); + var key = Me(); + db.StringSet(key, 5); + + var result = await db.StringIncrementAsync(key, 1L, TimeSpan.FromSeconds(5), lowerBound: 10); + + Assert.Equal(5, result.Value); + Assert.Equal(0, result.AppliedIncrement); + Assert.True((await db.KeyTimeToLiveAsync(key)) > TimeSpan.Zero); + } + + [Fact] + public async Task StringIncrementIncrex_DefaultClearsExistingTtl() + { + using var server = new IncrexTestServer(log); + await using var muxer = await server.ConnectAsync(); + var db = muxer.GetDatabase(); + var key = Me(); + db.StringSet(key, 5, TimeSpan.FromMinutes(5)); + + var result = await db.StringIncrementAsync(key, 2L, Expiration.Default); + + Assert.Equal(7, result.Value); + Assert.Equal(2, result.AppliedIncrement); + Assert.Null(await db.KeyTimeToLiveAsync(key)); + } + + [Fact] + public async Task StringIncrementIncrex_RejectsKeepTtl() + { + using var server = new IncrexTestServer(log); + await using var muxer = await server.ConnectAsync(); + var db = muxer.GetDatabase(); + + var ex = Assert.Throws(() => db.StringIncrement(Me(), 1L, Expiration.KeepTtl)); + Assert.Equal("expiry", ex.ParamName); + } +} diff --git a/tests/StackExchange.Redis.Tests/KeyPrefixedDatabaseTests.cs b/tests/StackExchange.Redis.Tests/KeyPrefixedDatabaseTests.cs index f117f8c5f..767c4f479 100644 --- a/tests/StackExchange.Redis.Tests/KeyPrefixedDatabaseTests.cs +++ b/tests/StackExchange.Redis.Tests/KeyPrefixedDatabaseTests.cs @@ -1393,6 +1393,20 @@ public void StringIncrement_2() mock.Received().StringIncrement("prefix:key", 1.23, CommandFlags.None); } + [Fact] + public void StringIncrement_3() + { + prefixed.StringIncrement("key", 123L, TimeSpan.FromSeconds(5), lowerBound: 10, upperBound: 200, flags: CommandFlags.None); + mock.Received().StringIncrement("prefix:key", 123L, TimeSpan.FromSeconds(5), 10, 200, CommandFlags.None); + } + + [Fact] + public void StringIncrement_4() + { + prefixed.StringIncrement("key", 1.23, TimeSpan.FromSeconds(5), lowerBound: -1.0, upperBound: 2.0, flags: CommandFlags.None); + mock.Received().StringIncrement("prefix:key", 1.23, TimeSpan.FromSeconds(5), -1.0, 2.0, CommandFlags.None); + } + [Fact] public void StringLength() { diff --git a/tests/StackExchange.Redis.Tests/KeyPrefixedTests.cs b/tests/StackExchange.Redis.Tests/KeyPrefixedTests.cs index 625eb022d..f0f0c2d1d 100644 --- a/tests/StackExchange.Redis.Tests/KeyPrefixedTests.cs +++ b/tests/StackExchange.Redis.Tests/KeyPrefixedTests.cs @@ -103,6 +103,20 @@ public async Task HashIncrementAsync_2() await mock.Received().HashIncrementAsync("prefix:key", "hashField", 1.23, CommandFlags.None); } + [Fact] + public async Task StringIncrementAsync_3() + { + await prefixed.StringIncrementAsync("key", 123L, TimeSpan.FromSeconds(5), lowerBound: 10, upperBound: 200, flags: CommandFlags.None); + await mock.Received().StringIncrementAsync("prefix:key", 123L, TimeSpan.FromSeconds(5), 10, 200, CommandFlags.None); + } + + [Fact] + public async Task StringIncrementAsync_4() + { + await prefixed.StringIncrementAsync("key", 1.23, TimeSpan.FromSeconds(5), lowerBound: -1.0, upperBound: 2.0, flags: CommandFlags.None); + await mock.Received().StringIncrementAsync("prefix:key", 1.23, TimeSpan.FromSeconds(5), -1.0, 2.0, CommandFlags.None); + } + [Fact] public async Task HashKeysAsync() {