diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 71c366e83..cd39b2613 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -44,7 +44,7 @@ public sealed partial class StreamableHttpServerTransport : ITransport private TaskCompletionSource? _httpResponseTcs; private string? _negotiatedProtocolVersion; private bool _getHttpRequestStarted; - private bool _getHttpResponseCompleted; + private bool _disposed; /// /// Initializes a new instance of the class. @@ -137,33 +137,53 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo throw new InvalidOperationException("GET requests are not supported in stateless mode."); } - using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false)) + try { - if (_getHttpRequestStarted) + using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false)) { - throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); - } + if (_getHttpRequestStarted) + { + throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); + } - _getHttpRequestStarted = true; - _httpSseWriter = new SseEventWriter(sseResponseStream); - _httpResponseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false); - if (_storeSseWriter is not null) - { - var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime(), cancellationToken).ConfigureAwait(false); - await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false); + _getHttpRequestStarted = true; + _httpSseWriter = new SseEventWriter(sseResponseStream); + _httpResponseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false); + if (_storeSseWriter is not null) + { + var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime(), cancellationToken).ConfigureAwait(false); + await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false); + } + else + { + // If there's no priming write, flush the stream to ensure HTTP response headers are + // sent to the client now that the transport is ready to accept messages via SendMessageAsync. + await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false); + } } - else + + // Wait for the response to be written before returning from the handler. + // This keeps the HTTP response open until the final response message is sent. + await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + // Release the SseEventWriter's reference to the response stream promptly when the GET + // request ends, regardless of how it exits. Otherwise the response stream (and the + // underlying Kestrel connection and associated memory pool buffers) remains pinned + // in memory until the session itself is disposed (via explicit DELETE or idle timeout). + // Clients that disconnect without sending DELETE — common with long-lived SSE — would + // otherwise accumulate significant unmanaged memory per session during that interval. + using (await _unsolicitedMessageLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) { - // If there's no priming write, flush the stream to ensure HTTP response headers are - // sent to the client now that the transport is ready to accept messages via SendMessageAsync. - await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false); + if (_httpSseWriter is { } writer) + { + _httpSseWriter = null; + writer.Dispose(); + } } } - - // Wait for the response to be written before returning from the handler. - // This keeps the HTTP response open until the final response message is sent. - await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); } /// @@ -219,23 +239,22 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can return; } - Debug.Assert(_httpSseWriter is not null); Debug.Assert(_httpResponseTcs is not null); var item = SseItem.Message(message); if (_storeSseWriter is not null) { + // Always record the message in the event store (if configured) — even when the GET + // response stream is gone — so a reconnecting client can replay it via Last-Event-ID. item = await _storeSseWriter.WriteEventAsync(item, cancellationToken).ConfigureAwait(false); } - if (!_getHttpResponseCompleted) + if (_httpSseWriter is { } writer) { - // Only write the message to the response if the response has not completed. - try { - await _httpSseWriter!.WriteAsync(item, cancellationToken).ConfigureAwait(false); + await writer.WriteAsync(item, cancellationToken).ConfigureAwait(false); } catch (Exception ex) when (!cancellationToken.IsCancellationRequested) { @@ -249,12 +268,12 @@ public async ValueTask DisposeAsync() { using var _ = await _unsolicitedMessageLock.LockAsync().ConfigureAwait(false); - if (_getHttpResponseCompleted) + if (_disposed) { return; } - _getHttpResponseCompleted = true; + _disposed = true; try { @@ -266,7 +285,11 @@ public async ValueTask DisposeAsync() try { _httpResponseTcs?.TrySetResult(true); - _httpSseWriter?.Dispose(); + if (_httpSseWriter is { } writer) + { + _httpSseWriter = null; + writer.Dispose(); + } if (_storeSseWriter is not null) { diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamableHttpServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StreamableHttpServerTransportTests.cs new file mode 100644 index 000000000..ce2147e27 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/StreamableHttpServerTransportTests.cs @@ -0,0 +1,89 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Transport; + +public class StreamableHttpServerTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + [Fact] + public async Task SendMessageAsync_AfterGetRequestEnds_DoesNotWriteToResponseStream() + { + // Regression test for the SSE response stream being retained after the GET request + // handler returns. Without releasing the stream reference, the Kestrel connection + // and its associated memory pool buffers (~20MiB per SSE session) stay pinned in + // unmanaged memory until the session is eventually disposed (via explicit DELETE or + // idle timeout), causing steady memory growth for servers whose clients disconnect + // without sending DELETE. After the GET handler returns, SendMessageAsync must not + // attempt to write to the (now released) response stream. + + await using var transport = new StreamableHttpServerTransport() + { + SessionId = "test-session", + }; + + var responseStream = new RecordingStream(); + + using var cts = new CancellationTokenSource(); + var getTask = transport.HandleGetRequestAsync(responseStream, cts.Token); + + // Wait until the GET handler has finished initialization (signaled by the initial + // flush that sends HTTP response headers) so we know _httpSseWriter is set. + await responseStream.FirstActivity.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); + + var writeCountBeforeCancel = responseStream.WriteCount; + + cts.Cancel(); + await Assert.ThrowsAnyAsync(() => getTask); + + await transport.SendMessageAsync( + new JsonRpcNotification { Method = "test" }, + TestContext.Current.CancellationToken); + + Assert.Equal(writeCountBeforeCancel, responseStream.WriteCount); + } + + private sealed class RecordingStream : Stream + { + private readonly TaskCompletionSource _firstActivity = new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _writeCount; + + public Task FirstActivity => _firstActivity.Task; + public int WriteCount => Volatile.Read(ref _writeCount); + + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() => _firstActivity.TrySetResult(true); + + public override Task FlushAsync(CancellationToken cancellationToken) + { + _firstActivity.TrySetResult(true); + return Task.CompletedTask; + } + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) + { + Interlocked.Increment(ref _writeCount); + _firstActivity.TrySetResult(true); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Interlocked.Increment(ref _writeCount); + _firstActivity.TrySetResult(true); + return Task.CompletedTask; + } + } +}