Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public sealed partial class StreamableHttpServerTransport : ITransport
private TaskCompletionSource<bool>? _httpResponseTcs;
private string? _negotiatedProtocolVersion;
private bool _getHttpRequestStarted;
private bool _getHttpResponseCompleted;
private bool _disposed;

/// <summary>
/// Initializes a new instance of the <see cref="StreamableHttpServerTransport"/> class.
Expand Down Expand Up @@ -137,33 +137,53 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo
throw new InvalidOperationException("GET requests are not supported in stateless mode.");
}

using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false))
try
{
if (_getHttpRequestStarted)
using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false))
{
throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
}
if (_getHttpRequestStarted)
{
throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
}

_getHttpRequestStarted = true;
_httpSseWriter = new SseEventWriter(sseResponseStream);
_httpResponseTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false);
if (_storeSseWriter is not null)
{
var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime<JsonRpcMessage>(), cancellationToken).ConfigureAwait(false);
await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false);
_getHttpRequestStarted = true;
_httpSseWriter = new SseEventWriter(sseResponseStream);
_httpResponseTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false);
if (_storeSseWriter is not null)
{
var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime<JsonRpcMessage>(), cancellationToken).ConfigureAwait(false);
await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false);
}
else
{
// If there's no priming write, flush the stream to ensure HTTP response headers are
// sent to the client now that the transport is ready to accept messages via SendMessageAsync.
await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
}
else

// Wait for the response to be written before returning from the handler.
// This keeps the HTTP response open until the final response message is sent.
await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
}
finally
{
// Release the SseEventWriter's reference to the response stream promptly when the GET
// request ends, regardless of how it exits. Otherwise the response stream (and the
// underlying Kestrel connection and associated memory pool buffers) remains pinned
// in memory until the session itself is disposed (via explicit DELETE or idle timeout).
// Clients that disconnect without sending DELETE — common with long-lived SSE — would
// otherwise accumulate significant unmanaged memory per session during that interval.
using (await _unsolicitedMessageLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
{
// If there's no priming write, flush the stream to ensure HTTP response headers are
// sent to the client now that the transport is ready to accept messages via SendMessageAsync.
await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
if (_httpSseWriter is { } writer)
{
_httpSseWriter = null;
writer.Dispose();
}
}
}

// Wait for the response to be written before returning from the handler.
// This keeps the HTTP response open until the final response message is sent.
await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down Expand Up @@ -219,23 +239,22 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
return;
}

Debug.Assert(_httpSseWriter is not null);
Debug.Assert(_httpResponseTcs is not null);

var item = SseItem.Message(message);

if (_storeSseWriter is not null)
{
// Always record the message in the event store (if configured) — even when the GET
// response stream is gone — so a reconnecting client can replay it via Last-Event-ID.
item = await _storeSseWriter.WriteEventAsync(item, cancellationToken).ConfigureAwait(false);
}

if (!_getHttpResponseCompleted)
if (_httpSseWriter is { } writer)
{
// Only write the message to the response if the response has not completed.

try
{
await _httpSseWriter!.WriteAsync(item, cancellationToken).ConfigureAwait(false);
await writer.WriteAsync(item, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!cancellationToken.IsCancellationRequested)
{
Expand All @@ -249,12 +268,12 @@ public async ValueTask DisposeAsync()
{
using var _ = await _unsolicitedMessageLock.LockAsync().ConfigureAwait(false);

if (_getHttpResponseCompleted)
if (_disposed)
{
return;
}

_getHttpResponseCompleted = true;
_disposed = true;

try
{
Expand All @@ -266,7 +285,11 @@ public async ValueTask DisposeAsync()
try
{
_httpResponseTcs?.TrySetResult(true);
_httpSseWriter?.Dispose();
if (_httpSseWriter is { } writer)
{
_httpSseWriter = null;
writer.Dispose();
}

if (_storeSseWriter is not null)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;

namespace ModelContextProtocol.Tests.Transport;

public class StreamableHttpServerTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
{
[Fact]
public async Task SendMessageAsync_AfterGetRequestEnds_DoesNotWriteToResponseStream()
{
// Regression test for the SSE response stream being retained after the GET request
// handler returns. Without releasing the stream reference, the Kestrel connection
// and its associated memory pool buffers (~20MiB per SSE session) stay pinned in
// unmanaged memory until the session is eventually disposed (via explicit DELETE or
// idle timeout), causing steady memory growth for servers whose clients disconnect
// without sending DELETE. After the GET handler returns, SendMessageAsync must not
// attempt to write to the (now released) response stream.

await using var transport = new StreamableHttpServerTransport()
{
SessionId = "test-session",
};

var responseStream = new RecordingStream();

using var cts = new CancellationTokenSource();
var getTask = transport.HandleGetRequestAsync(responseStream, cts.Token);

// Wait until the GET handler has finished initialization (signaled by the initial
// flush that sends HTTP response headers) so we know _httpSseWriter is set.
await responseStream.FirstActivity.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken);

var writeCountBeforeCancel = responseStream.WriteCount;

cts.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => getTask);

await transport.SendMessageAsync(
new JsonRpcNotification { Method = "test" },
TestContext.Current.CancellationToken);

Assert.Equal(writeCountBeforeCancel, responseStream.WriteCount);
}

private sealed class RecordingStream : Stream
{
private readonly TaskCompletionSource<bool> _firstActivity = new(TaskCreationOptions.RunContinuationsAsynchronously);
private int _writeCount;

public Task FirstActivity => _firstActivity.Task;
public int WriteCount => Volatile.Read(ref _writeCount);

public override bool CanRead => false;
public override bool CanSeek => false;
public override bool CanWrite => true;
public override long Length => throw new NotSupportedException();
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}

public override void Flush() => _firstActivity.TrySetResult(true);

public override Task FlushAsync(CancellationToken cancellationToken)
{
_firstActivity.TrySetResult(true);
return Task.CompletedTask;
}

public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();

public override void Write(byte[] buffer, int offset, int count)
{
Interlocked.Increment(ref _writeCount);
_firstActivity.TrySetResult(true);
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
Interlocked.Increment(ref _writeCount);
_firstActivity.TrySetResult(true);
return Task.CompletedTask;
}
}
}
Loading