diff --git a/mcp/streamable.go b/mcp/streamable.go index f2a28955..bded6058 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -19,6 +19,7 @@ import ( "maps" "math" "math/rand/v2" + "net" "net/http" "slices" "strconv" @@ -160,6 +161,16 @@ type StreamableHTTPOptions struct { // // If SessionTimeout is the zero value, idle sessions are never closed. SessionTimeout time.Duration + + // DisableLocalhostProtection disables automatic DNS rebinding protection. + // By default, requests arriving via a localhost address (127.0.0.1, [::1]) + // that have a non-localhost Host header are rejected with 403 Forbidden. + // This protects against DNS rebinding attacks regardless of whether the + // server is listening on localhost specifically or on 0.0.0.0. + // + // Only disable this if you understand the security implications. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + DisableLocalhostProtection bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -206,7 +217,46 @@ func (h *StreamableHTTPHandler) closeAll() { } } +// isLocalhostAddr checks if a net.Addr is a localhost address. +func isLocalhostAddr(addr net.Addr) bool { + if addr == nil { + return false + } + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + host = addr.String() + } + // Remove brackets for IPv6 + host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[") + return host == "localhost" || host == "127.0.0.1" || host == "::1" +} + +// isLocalhostHost checks if a Host header value is a valid localhost address. +func isLocalhostHost(host string) bool { + if host == "" { + return false + } + hostname, _, err := net.SplitHostPort(host) + if err != nil { + hostname = host + } + // Remove brackets for IPv6 + hostname = strings.TrimPrefix(strings.TrimSuffix(hostname, "]"), "[") + return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" +} + func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // DNS rebinding protection: auto-enabled for localhost servers. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + if !h.opts.DisableLocalhostProtection { + if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { + if isLocalhostAddr(localAddr) && !isLocalhostHost(req.Host) { + http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden) + return + } + } + } + // Allow multiple 'Accept' headers. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 56c0a49b..f2869656 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2280,3 +2280,138 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}} } } } + +// TestStreamableLocalhostProtection verifies that DNS rebinding protection +// is automatically enabled for localhost servers. +func TestStreamableLocalhostProtection(t *testing.T) { + server := NewServer(testImpl, nil) + + tests := []struct { + name string + listenAddr string // Address to listen on + hostHeader string // Host header in request + disableProt bool // DisableLocalhostProtection setting + wantStatus int + }{ + // Auto-enabled for localhost listeners (127.0.0.1) + {"127.0.0.1 accepts 127.0.0.1", "127.0.0.1:0", "127.0.0.1:1234", false, http.StatusOK}, + {"127.0.0.1 accepts localhost", "127.0.0.1:0", "localhost:1234", false, http.StatusOK}, + {"127.0.0.1 rejects evil.com", "127.0.0.1:0", "evil.com", false, http.StatusForbidden}, + {"127.0.0.1 rejects evil.com:80", "127.0.0.1:0", "evil.com:80", false, http.StatusForbidden}, + {"127.0.0.1 rejects localhost.evil.com", "127.0.0.1:0", "localhost.evil.com", false, http.StatusForbidden}, + + // When listening on 0.0.0.0, requests arriving via localhost are still protected + // because LocalAddrContextKey returns the actual connection's local address. + // This is actually more secure - DNS rebinding attacks target localhost regardless + // of the listener configuration. + {"0.0.0.0 via localhost rejects evil.com", "0.0.0.0:0", "evil.com", false, http.StatusForbidden}, + + // Explicit disable + {"disabled accepts evil.com", "127.0.0.1:0", "evil.com", true, http.StatusOK}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &StreamableHTTPOptions{ + Stateless: true, // Simpler for testing + DisableLocalhostProtection: tt.disableProt, + } + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) + + // Create a listener on the specified address to control LocalAddrContextKey + listener, err := net.Listen("tcp", tt.listenAddr) + if err != nil { + t.Fatalf("failed to listen on %s: %v", tt.listenAddr, err) + } + defer listener.Close() + + // Start server in background + srv := &http.Server{Handler: handler} + go srv.Serve(listener) + defer srv.Close() + + // Make request with custom Host header + req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", listener.Addr().String()), strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`)) + if err != nil { + t.Fatal(err) + } + req.Host = tt.hostHeader + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if got := resp.StatusCode; got != tt.wantStatus { + t.Errorf("status code: got %d, want %d", got, tt.wantStatus) + } + }) + } +} + +// TestIsLocalhostAddr tests the isLocalhostAddr helper function. +func TestIsLocalhostAddr(t *testing.T) { + tests := []struct { + addr string + want bool + }{ + {"127.0.0.1:3000", true}, + {"127.0.0.1:0", true}, + {"localhost:3000", true}, + {"[::1]:3000", true}, + {"0.0.0.0:3000", false}, + {"192.168.1.1:3000", false}, + {"example.com:3000", false}, + } + + for _, tt := range tests { + t.Run(tt.addr, func(t *testing.T) { + addr, err := net.ResolveTCPAddr("tcp", tt.addr) + if err != nil { + // For hostname-based addresses, use a mock + if strings.HasPrefix(tt.addr, "localhost") { + addr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 3000} + } else if strings.HasPrefix(tt.addr, "example.com") { + addr = &net.TCPAddr{IP: net.ParseIP("93.184.216.34"), Port: 3000} + } else { + t.Fatalf("failed to resolve %s: %v", tt.addr, err) + } + } + if got := isLocalhostAddr(addr); got != tt.want { + t.Errorf("isLocalhostAddr(%q) = %v, want %v", tt.addr, got, tt.want) + } + }) + } +} + +// TestIsLocalhostHost tests the isLocalhostHost helper function. +func TestIsLocalhostHost(t *testing.T) { + tests := []struct { + host string + want bool + }{ + {"localhost", true}, + {"localhost:3000", true}, + {"127.0.0.1", true}, + {"127.0.0.1:3000", true}, + {"[::1]", true}, + {"[::1]:3000", true}, + {"::1", true}, + {"", false}, + {"evil.com", false}, + {"evil.com:80", false}, + {"localhost.evil.com", false}, + {"127.0.0.1.evil.com", false}, + } + + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := isLocalhostHost(tt.host); got != tt.want { + t.Errorf("isLocalhostHost(%q) = %v, want %v", tt.host, got, tt.want) + } + }) + } +}