From 9792dd963a6333f2abf3605ade6a74257de92639 Mon Sep 17 00:00:00 2001 From: Alex Bumbacea Date: Fri, 19 Dec 2025 21:27:21 +0200 Subject: [PATCH] mcp: improve http transports error handling and make buffer size configurable --- mcp/event.go | 62 +++++++++++++++++++++++++++-------------------- mcp/event_test.go | 48 ++++++++++++++++++++++++++++++++++++ mcp/sse.go | 28 +++++++++++++++------ mcp/streamable.go | 6 ++++- 4 files changed, 110 insertions(+), 34 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 5c322c4a..7cc7a834 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -67,9 +67,7 @@ func writeEvent(w io.Writer, evt Event) (int, error) { // TODO(rfindley): consider a different API here that makes failure modes more // apparent. func scanEvents(r io.Reader) iter.Seq2[Event, error] { - scanner := bufio.NewScanner(r) - const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size - scanner.Buffer(nil, maxTokenSize) + reader := bufio.NewReader(r) // TODO: investigate proper behavior when events are out of order, or have // non-standard names. @@ -94,21 +92,38 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { evt Event dataBuf *bytes.Buffer // if non-nil, preceding field was also data ) - flushData := func() { + var previousLineType []byte + yieldEvent := func() bool { if dataBuf != nil { evt.Data = dataBuf.Bytes() dataBuf = nil + previousLineType = nil } + if evt.Empty() { + return true + } + if !yield(evt, nil) { + return false + } + evt = Event{} + return true } - for scanner.Scan() { - line := scanner.Bytes() + for { + line, err := reader.ReadBytes('\n') + if err != nil && !errors.Is(err, io.EOF) { + yield(Event{}, fmt.Errorf("error reading event: %v", err)) + return + } + line = bytes.TrimRight(line, "\r\n") + isEOF := errors.Is(err, io.EOF) + if len(line) == 0 { - flushData() - // \n\n is the record delimiter - if !evt.Empty() && !yield(evt, nil) { + if !yieldEvent() { + return + } + if isEOF { return } - evt = Event{} continue } before, after, found := bytes.Cut(line, []byte{':'}) @@ -116,9 +131,6 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) return } - if !bytes.Equal(before, dataKey) { - flushData() - } switch { case bytes.Equal(before, eventKey): evt.Name = strings.TrimSpace(string(after)) @@ -128,27 +140,25 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { evt.Retry = strings.TrimSpace(string(after)) case bytes.Equal(before, dataKey): data := bytes.TrimSpace(after) - if dataBuf != nil { - dataBuf.WriteByte('\n') + previousLineEmptyOrData := previousLineType == nil || bytes.Equal(previousLineType, dataKey) + if dataBuf == nil { + dataBuf = new(bytes.Buffer) dataBuf.Write(data) + } else if !previousLineEmptyOrData { + yield(Event{}, fmt.Errorf("non-continuous data items in the event")) + return } else { - dataBuf = new(bytes.Buffer) + dataBuf.WriteByte('\n') dataBuf.Write(data) } } - } - if err := scanner.Err(); err != nil { - if errors.Is(err, bufio.ErrTooLong) { - err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) - } - if !yield(Event{}, err) { + previousLineType = before + + if isEOF { + yieldEvent() return } } - flushData() - if !evt.Empty() { - yield(evt, nil) - } } } diff --git a/mcp/event_test.go b/mcp/event_test.go index dacf30e8..8cab5507 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -54,6 +54,54 @@ func TestScanEvents(t *testing.T) { input: "invalid line\n\n", wantErr: "malformed line", }, + { + name: "message with 2 data lines and another event", + input: "event: message\ndata: hello\ndata: hello\ndata: hello\n\nevent:keepalive", + want: []Event{ + {Name: "message", Data: []byte("hello\nhello\nhello")}, + {Name: "keepalive"}, + }, + }, + { + name: "event with multiple lines", + input: "event: message\ndata: hello\ndata: hello\ndata: hello\nid:1", + want: []Event{ + {Name: "message", ID: "1", Data: []byte("hello\nhello\nhello")}, + }, + }, + { + name: "multiple events, out of order keys", + input: strings.Join([]string{ + "event:message", + "data: hello0", + "\n", + "data: hello1", + "data: hello1", + "id:1", + "event:message", + "\n", + "event:message", + "data: hello3", + "data: hello3", + "id:3", + "\n", + "data: hello4", + "data: hello4", + "id:4", + "event:message", + }, "\n"), + want: []Event{ + {Name: "message", Data: []byte("hello0")}, + {Name: "message", ID: "1", Data: []byte("hello1\nhello1")}, + {Name: "message", ID: "3", Data: []byte("hello3\nhello3")}, + {Name: "message", ID: "4", Data: []byte("hello4\nhello4")}, + }, + }, + { + name: "non-continuous data items in the event", + input: "event: foo\ndata: 123\nretry: 5\ndata: 456", + wantErr: "non-continuous data items in the event", + }, } for _, tt := range tests { diff --git a/mcp/sse.go b/mcp/sse.go index a668c6d0..005f01e8 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -382,7 +382,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { s := &sseClientConn{ client: httpClient, msgEndpoint: msgEndpoint, - incoming: make(chan []byte, 100), + incoming: make(chan sseMessage, 100), body: resp.Body, done: make(chan struct{}), } @@ -392,10 +392,14 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { for evt, err := range scanEvents(resp.Body) { if err != nil { + select { + case s.incoming <- sseMessage{err: err}: + case <-s.done: + } return } select { - case s.incoming <- evt.Data: + case s.incoming <- sseMessage{data: evt.Data}: case <-s.done: return } @@ -405,15 +409,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return s, nil } +// sseMessage represents a message or error from the SSE stream. +type sseMessage struct { + data []byte + err error +} + // An sseClientConn is a logical jsonrpc2 connection that implements the client // half of the SSE protocol: // - Writes are POSTS to the session endpoint. // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - client *http.Client // HTTP client to use for requests - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan sseMessage // queue of incoming messages or errors mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -438,12 +448,16 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { case <-c.done: return nil, io.EOF - case data := <-c.incoming: + case m := <-c.incoming: + if m.err != nil { + // TODO: bubble up this error + return nil, nil + } // TODO(rfindley): do we really need to check this? We receive from c.done above. if c.isDone() { return nil, io.EOF } - msg, err := jsonrpc2.DecodeMessage(data) + msg, err := jsonrpc2.DecodeMessage(m.data) if err != nil { return nil, err } diff --git a/mcp/streamable.go b/mcp/streamable.go index b4b2fa31..59cb20e6 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1859,7 +1859,11 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary if ctx.Err() != nil { return "", 0, true // don't reconnect: client cancelled } - break + + // Network errors during reading should trigger reconnection, not permanent failure. + // Return from processStream so handleSSE can attempt to reconnect. + c.logger.Debug(fmt.Sprintf("%s: stream read error (will attempt reconnect): %v", requestSummary, err)) + return lastEventID, reconnectDelay, false } if evt.ID != "" {