From cae907c3b67d7d90cffc2a2d3026b3be359af697 Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Thu, 22 Jan 2026 15:37:19 +0000 Subject: [PATCH 1/2] feat: add automatic DNS rebinding protection for localhost servers Add DNS rebinding protection that is automatically enabled when requests arrive via localhost (127.0.0.1, [::1]). This protects against malicious websites using DNS rebinding to interact with local MCP servers. Key changes: - Add DisableLocalhostProtection option to StreamableHTTPOptions - Add isLocalhostAddr and isLocalhostHost helper functions - Validate Host header at start of ServeHTTP, rejecting non-localhost Host headers with 403 Forbidden when the connection arrives via localhost The protection is enabled by default with no code changes required. Users can opt-out by setting DisableLocalhostProtection: true. This uses http.LocalAddrContextKey to detect the connection's local address, which means protection is enabled for any request arriving via localhost, regardless of whether the server listens on 127.0.0.1 or 0.0.0.0. Closes: relates to MCP spec security best practices See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise --- mcp/streamable.go | 50 +++++++++++++++ mcp/streamable_test.go | 135 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+) diff --git a/mcp/streamable.go b/mcp/streamable.go index f2a28955..ad4afa90 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -18,6 +18,7 @@ import ( "log/slog" "maps" "math" + "net" "math/rand/v2" "net/http" "slices" @@ -160,6 +161,16 @@ type StreamableHTTPOptions struct { // // If SessionTimeout is the zero value, idle sessions are never closed. SessionTimeout time.Duration + + // DisableLocalhostProtection disables automatic DNS rebinding protection. + // By default, requests arriving via a localhost address (127.0.0.1, [::1]) + // that have a non-localhost Host header are rejected with 403 Forbidden. + // This protects against DNS rebinding attacks regardless of whether the + // server is listening on localhost specifically or on 0.0.0.0. + // + // Only disable this if you understand the security implications. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + DisableLocalhostProtection bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -206,7 +217,46 @@ func (h *StreamableHTTPHandler) closeAll() { } } +// isLocalhostAddr checks if a net.Addr is a localhost address. +func isLocalhostAddr(addr net.Addr) bool { + if addr == nil { + return false + } + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + host = addr.String() + } + // Remove brackets for IPv6 + host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[") + return host == "localhost" || host == "127.0.0.1" || host == "::1" +} + +// isLocalhostHost checks if a Host header value is a valid localhost address. +func isLocalhostHost(host string) bool { + if host == "" { + return false + } + hostname, _, err := net.SplitHostPort(host) + if err != nil { + hostname = host + } + // Remove brackets for IPv6 + hostname = strings.TrimPrefix(strings.TrimSuffix(hostname, "]"), "[") + return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" +} + func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // DNS rebinding protection: auto-enabled for localhost servers. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + if !h.opts.DisableLocalhostProtection { + if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { + if isLocalhostAddr(localAddr) && !isLocalhostHost(req.Host) { + http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden) + return + } + } + } + // Allow multiple 'Accept' headers. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 56c0a49b..f2869656 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2280,3 +2280,138 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}} } } } + +// TestStreamableLocalhostProtection verifies that DNS rebinding protection +// is automatically enabled for localhost servers. +func TestStreamableLocalhostProtection(t *testing.T) { + server := NewServer(testImpl, nil) + + tests := []struct { + name string + listenAddr string // Address to listen on + hostHeader string // Host header in request + disableProt bool // DisableLocalhostProtection setting + wantStatus int + }{ + // Auto-enabled for localhost listeners (127.0.0.1) + {"127.0.0.1 accepts 127.0.0.1", "127.0.0.1:0", "127.0.0.1:1234", false, http.StatusOK}, + {"127.0.0.1 accepts localhost", "127.0.0.1:0", "localhost:1234", false, http.StatusOK}, + {"127.0.0.1 rejects evil.com", "127.0.0.1:0", "evil.com", false, http.StatusForbidden}, + {"127.0.0.1 rejects evil.com:80", "127.0.0.1:0", "evil.com:80", false, http.StatusForbidden}, + {"127.0.0.1 rejects localhost.evil.com", "127.0.0.1:0", "localhost.evil.com", false, http.StatusForbidden}, + + // When listening on 0.0.0.0, requests arriving via localhost are still protected + // because LocalAddrContextKey returns the actual connection's local address. + // This is actually more secure - DNS rebinding attacks target localhost regardless + // of the listener configuration. + {"0.0.0.0 via localhost rejects evil.com", "0.0.0.0:0", "evil.com", false, http.StatusForbidden}, + + // Explicit disable + {"disabled accepts evil.com", "127.0.0.1:0", "evil.com", true, http.StatusOK}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &StreamableHTTPOptions{ + Stateless: true, // Simpler for testing + DisableLocalhostProtection: tt.disableProt, + } + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) + + // Create a listener on the specified address to control LocalAddrContextKey + listener, err := net.Listen("tcp", tt.listenAddr) + if err != nil { + t.Fatalf("failed to listen on %s: %v", tt.listenAddr, err) + } + defer listener.Close() + + // Start server in background + srv := &http.Server{Handler: handler} + go srv.Serve(listener) + defer srv.Close() + + // Make request with custom Host header + req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", listener.Addr().String()), strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`)) + if err != nil { + t.Fatal(err) + } + req.Host = tt.hostHeader + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if got := resp.StatusCode; got != tt.wantStatus { + t.Errorf("status code: got %d, want %d", got, tt.wantStatus) + } + }) + } +} + +// TestIsLocalhostAddr tests the isLocalhostAddr helper function. +func TestIsLocalhostAddr(t *testing.T) { + tests := []struct { + addr string + want bool + }{ + {"127.0.0.1:3000", true}, + {"127.0.0.1:0", true}, + {"localhost:3000", true}, + {"[::1]:3000", true}, + {"0.0.0.0:3000", false}, + {"192.168.1.1:3000", false}, + {"example.com:3000", false}, + } + + for _, tt := range tests { + t.Run(tt.addr, func(t *testing.T) { + addr, err := net.ResolveTCPAddr("tcp", tt.addr) + if err != nil { + // For hostname-based addresses, use a mock + if strings.HasPrefix(tt.addr, "localhost") { + addr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 3000} + } else if strings.HasPrefix(tt.addr, "example.com") { + addr = &net.TCPAddr{IP: net.ParseIP("93.184.216.34"), Port: 3000} + } else { + t.Fatalf("failed to resolve %s: %v", tt.addr, err) + } + } + if got := isLocalhostAddr(addr); got != tt.want { + t.Errorf("isLocalhostAddr(%q) = %v, want %v", tt.addr, got, tt.want) + } + }) + } +} + +// TestIsLocalhostHost tests the isLocalhostHost helper function. +func TestIsLocalhostHost(t *testing.T) { + tests := []struct { + host string + want bool + }{ + {"localhost", true}, + {"localhost:3000", true}, + {"127.0.0.1", true}, + {"127.0.0.1:3000", true}, + {"[::1]", true}, + {"[::1]:3000", true}, + {"::1", true}, + {"", false}, + {"evil.com", false}, + {"evil.com:80", false}, + {"localhost.evil.com", false}, + {"127.0.0.1.evil.com", false}, + } + + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := isLocalhostHost(tt.host); got != tt.want { + t.Errorf("isLocalhostHost(%q) = %v, want %v", tt.host, got, tt.want) + } + }) + } +} From 6f3f720a5ff077d62a4fed396fe3ed3345204cd9 Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Thu, 22 Jan 2026 15:48:38 +0000 Subject: [PATCH 2/2] fix: correct import order --- mcp/streamable.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index ad4afa90..bded6058 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -18,8 +18,8 @@ import ( "log/slog" "maps" "math" - "net" "math/rand/v2" + "net" "net/http" "slices" "strconv"