Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"maps"
"math"
"math/rand/v2"
"net"
"net/http"
"slices"
"strconv"
Expand Down Expand Up @@ -160,6 +161,16 @@ type StreamableHTTPOptions struct {
//
// If SessionTimeout is the zero value, idle sessions are never closed.
SessionTimeout time.Duration

// DisableLocalhostProtection disables automatic DNS rebinding protection.
// By default, requests arriving via a localhost address (127.0.0.1, [::1])
// that have a non-localhost Host header are rejected with 403 Forbidden.
// This protects against DNS rebinding attacks regardless of whether the
// server is listening on localhost specifically or on 0.0.0.0.
//
// Only disable this if you understand the security implications.
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
DisableLocalhostProtection bool
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand Down Expand Up @@ -206,7 +217,46 @@ func (h *StreamableHTTPHandler) closeAll() {
}
}

// isLocalhostAddr checks if a net.Addr is a localhost address.
func isLocalhostAddr(addr net.Addr) bool {
if addr == nil {
return false
}
host, _, err := net.SplitHostPort(addr.String())
if err != nil {
host = addr.String()
}
// Remove brackets for IPv6
host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
return host == "localhost" || host == "127.0.0.1" || host == "::1"
}

// isLocalhostHost checks if a Host header value is a valid localhost address.
func isLocalhostHost(host string) bool {
if host == "" {
return false
}
hostname, _, err := net.SplitHostPort(host)
if err != nil {
hostname = host
}
// Remove brackets for IPv6
hostname = strings.TrimPrefix(strings.TrimSuffix(hostname, "]"), "[")
return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1"
}

func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// DNS rebinding protection: auto-enabled for localhost servers.
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
if !h.opts.DisableLocalhostProtection {
if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
if isLocalhostAddr(localAddr) && !isLocalhostHost(req.Host) {
http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden)
return
}
}
}

// Allow multiple 'Accept' headers.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax
accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",")
Expand Down
135 changes: 135 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2280,3 +2280,138 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}}
}
}
}

// TestStreamableLocalhostProtection verifies that DNS rebinding protection
// is automatically enabled for localhost servers.
func TestStreamableLocalhostProtection(t *testing.T) {
server := NewServer(testImpl, nil)

tests := []struct {
name string
listenAddr string // Address to listen on
hostHeader string // Host header in request
disableProt bool // DisableLocalhostProtection setting
wantStatus int
}{
// Auto-enabled for localhost listeners (127.0.0.1)
{"127.0.0.1 accepts 127.0.0.1", "127.0.0.1:0", "127.0.0.1:1234", false, http.StatusOK},
{"127.0.0.1 accepts localhost", "127.0.0.1:0", "localhost:1234", false, http.StatusOK},
{"127.0.0.1 rejects evil.com", "127.0.0.1:0", "evil.com", false, http.StatusForbidden},
{"127.0.0.1 rejects evil.com:80", "127.0.0.1:0", "evil.com:80", false, http.StatusForbidden},
{"127.0.0.1 rejects localhost.evil.com", "127.0.0.1:0", "localhost.evil.com", false, http.StatusForbidden},

// When listening on 0.0.0.0, requests arriving via localhost are still protected
// because LocalAddrContextKey returns the actual connection's local address.
// This is actually more secure - DNS rebinding attacks target localhost regardless
// of the listener configuration.
{"0.0.0.0 via localhost rejects evil.com", "0.0.0.0:0", "evil.com", false, http.StatusForbidden},

// Explicit disable
{"disabled accepts evil.com", "127.0.0.1:0", "evil.com", true, http.StatusOK},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := &StreamableHTTPOptions{
Stateless: true, // Simpler for testing
DisableLocalhostProtection: tt.disableProt,
}
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts)

// Create a listener on the specified address to control LocalAddrContextKey
listener, err := net.Listen("tcp", tt.listenAddr)
if err != nil {
t.Fatalf("failed to listen on %s: %v", tt.listenAddr, err)
}
defer listener.Close()

// Start server in background
srv := &http.Server{Handler: handler}
go srv.Serve(listener)
defer srv.Close()

// Make request with custom Host header
req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", listener.Addr().String()), strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`))
if err != nil {
t.Fatal(err)
}
req.Host = tt.hostHeader
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

if got := resp.StatusCode; got != tt.wantStatus {
t.Errorf("status code: got %d, want %d", got, tt.wantStatus)
}
})
}
}

// TestIsLocalhostAddr tests the isLocalhostAddr helper function.
func TestIsLocalhostAddr(t *testing.T) {
tests := []struct {
addr string
want bool
}{
{"127.0.0.1:3000", true},
{"127.0.0.1:0", true},
{"localhost:3000", true},
{"[::1]:3000", true},
{"0.0.0.0:3000", false},
{"192.168.1.1:3000", false},
{"example.com:3000", false},
}

for _, tt := range tests {
t.Run(tt.addr, func(t *testing.T) {
addr, err := net.ResolveTCPAddr("tcp", tt.addr)
if err != nil {
// For hostname-based addresses, use a mock
if strings.HasPrefix(tt.addr, "localhost") {
addr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 3000}
} else if strings.HasPrefix(tt.addr, "example.com") {
addr = &net.TCPAddr{IP: net.ParseIP("93.184.216.34"), Port: 3000}
} else {
t.Fatalf("failed to resolve %s: %v", tt.addr, err)
}
}
if got := isLocalhostAddr(addr); got != tt.want {
t.Errorf("isLocalhostAddr(%q) = %v, want %v", tt.addr, got, tt.want)
}
})
}
}

// TestIsLocalhostHost tests the isLocalhostHost helper function.
func TestIsLocalhostHost(t *testing.T) {
tests := []struct {
host string
want bool
}{
{"localhost", true},
{"localhost:3000", true},
{"127.0.0.1", true},
{"127.0.0.1:3000", true},
{"[::1]", true},
{"[::1]:3000", true},
{"::1", true},
{"", false},
{"evil.com", false},
{"evil.com:80", false},
{"localhost.evil.com", false},
{"127.0.0.1.evil.com", false},
}

for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := isLocalhostHost(tt.host); got != tt.want {
t.Errorf("isLocalhostHost(%q) = %v, want %v", tt.host, got, tt.want)
}
})
}
}