diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index b2017141..e9dbb82d 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -207,10 +207,22 @@ func (s *Server) Ready() <-chan struct{} { } func (s *Server) handleSession(conn net.Conn) { + defer func() { + if r := recover(); r != nil { + s.log.WithField("panic", r). + WithField("remote_tcp", conn.RemoteAddr()). + Error("Recovered from panic in handleSession, connection will be closed") + if err := conn.Close(); err != nil { + s.log.WithError(err).Warn("Failed to close connection after panic recovery") + } + } + }() + log := s.log.WithField("remote_tcp", conn.RemoteAddr()) dSes, err := makeServerSession(s.m, &s.EntityCommon, conn) if err != nil { + log.WithError(err).Warn("Failed to create server session") if err := conn.Close(); err != nil { log.WithError(err).Warn("On handleSession() failure, close connection resulted in error.") } diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index fd15e52d..699b332e 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -62,6 +62,11 @@ func (ss *ServerSession) Serve() { log.Info("Initiating stream.") go func(sStr *smux.Stream) { + defer func() { + if r := recover(); r != nil { + log.WithField("panic", r).Error("Recovered from panic in serveStream") + } + }() err := ss.serveStream(log, sStr, ss.sm.addr) log.WithError(err).Info("Stopped stream.") }(sStr) @@ -83,6 +88,11 @@ func (ss *ServerSession) Serve() { log.Info("Initiating stream.") go func(yStr *yamux.Stream) { + defer func() { + if r := recover(); r != nil { + log.WithField("panic", r).Error("Recovered from panic in serveStream") + } + }() err := ss.serveStream(log, yStr, ss.sm.addr) log.WithError(err).Info("Stopped stream.") }(yStr) diff --git a/pkg/dmsg/stream_test.go b/pkg/dmsg/stream_test.go index d07a24be..a5a8390a 100644 --- a/pkg/dmsg/stream_test.go +++ b/pkg/dmsg/stream_test.go @@ -307,3 +307,90 @@ func GenKeyPair(t *testing.T, seed string) (cipher.PubKey, cipher.SecKey) { require.NoError(t, err) return pk, sk } + +// TestInvalidPublicKeyNoPanic tests that the server doesn't crash when receiving +// a connection with an invalid public key during the noise handshake. +// This is a regression test for a 2+ year old bug where invalid public keys +// would cause the server to panic and crash. +func TestInvalidPublicKeyNoPanic(t *testing.T) { + // Prepare mock discovery. + dc := disc.NewMock(0) + const maxSessions = 10 + + // Prepare dmsg server. + pkSrv, skSrv := GenKeyPair(t, "server") + srvConf := &ServerConfig{ + MaxSessions: maxSessions, + UpdateInterval: 0, + } + srv := NewServer(pkSrv, skSrv, dc, srvConf, nil) + srv.SetLogger(logging.MustGetLogger("server")) + lisSrv, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Serve dmsg server. + chSrv := make(chan error, 1) + go func() { chSrv <- srv.Serve(lisSrv, "") }() //nolint:errcheck + + // Give server time to start + time.Sleep(500 * time.Millisecond) + + // Attempt to send a handshake with invalid public key data + // This simulates a malicious or buggy client + t.Run("invalid_pubkey_handshake", func(t *testing.T) { + conn, err := net.Dial("tcp", lisSrv.Addr().String()) + require.NoError(t, err) + defer func() { _ = conn.Close() }() //nolint:errcheck + + // Send invalid noise handshake data (contains invalid public key) + // In a real noise handshake, the public key would be embedded in the message + // We send malformed data that will trigger invalid public key error + invalidData := make([]byte, 100) + // Write some invalid data that looks like a handshake but has invalid key + copy(invalidData, []byte{0x00, 0x32}) // frame length prefix (50 bytes) + // Rest is invalid/random data that will fail public key validation + for i := 2; i < len(invalidData); i++ { + invalidData[i] = byte(i) // deterministic but invalid + } + + _, err = conn.Write(invalidData) + // Write may succeed, but the server should handle the invalid data gracefully + if err != nil { + t.Logf("Write failed (expected): %v", err) + } + + // Give server time to process the invalid handshake + time.Sleep(500 * time.Millisecond) + + // Read to see if connection was closed (expected behavior) + buf := make([]byte, 10) + _ = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) //nolint:errcheck + _, _ = conn.Read(buf) //nolint:errcheck + // We expect the connection to be closed or timeout + // The important thing is the server didn't crash + }) + + // Verify server is still running and can accept valid connections + t.Run("valid_connection_after_invalid", func(t *testing.T) { + // Prepare and serve a valid dmsg client + pkA, skA := GenKeyPair(t, "client A") + clientA := NewClient(pkA, skA, dc, DefaultConfig()) + clientA.SetLogger(logging.MustGetLogger("client_A")) + go clientA.Serve(context.Background()) + + // Wait for client to register + time.Sleep(time.Second * 2) + + // Attempt to use the client - if server crashed, this will fail + lis, err := clientA.Listen(8081) + require.NoError(t, err, "Server should still be running and accept valid connections") + + // Clean up + require.NoError(t, lis.Close()) + require.NoError(t, clientA.Close()) + }) + + // Closing logic - server should still be healthy + require.NoError(t, srv.Close()) + require.NoError(t, <-chSrv) +} diff --git a/pkg/noise/dh.go b/pkg/noise/dh.go index 28157d81..e627ae68 100644 --- a/pkg/noise/dh.go +++ b/pkg/noise/dh.go @@ -22,9 +22,24 @@ func (Secp256k1) GenerateKeypair(_ io.Reader) (noise.DHKey, error) { // DH helps to implement `noise.DHFunc`. func (Secp256k1) DH(sk, pk []byte) []byte { - return append( - cipher.MustECDH(cipher.MustNewPubKey(pk), cipher.MustNewSecKey(sk)), - byte(0)) + // Use non-panic versions to handle invalid keys gracefully + pubKey, err := cipher.NewPubKey(pk) + if err != nil { + // Return empty key on error to prevent panic + // The handshake will fail with this invalid key + return make([]byte, 33) + } + secKey, err := cipher.NewSecKey(sk) + if err != nil { + // Return empty key on error to prevent panic + return make([]byte, 33) + } + ecdh, err := cipher.ECDH(pubKey, secKey) + if err != nil { + // Return empty key on error to prevent panic + return make([]byte, 33) + } + return append(ecdh, byte(0)) } // DHLen helps to implement `noise.DHFunc`. diff --git a/pkg/noise/read_writer.go b/pkg/noise/read_writer.go index c728eef6..4f36e830 100644 --- a/pkg/noise/read_writer.go +++ b/pkg/noise/read_writer.go @@ -176,12 +176,17 @@ func (rw *ReadWriter) Write(p []byte) (n int, err error) { func (rw *ReadWriter) Handshake(hsTimeout time.Duration) error { errCh := make(chan error, 1) go func() { + defer func() { + if r := recover(); r != nil { + errCh <- fmt.Errorf("handshake panic: %v", r) + } + close(errCh) + }() if rw.ns.init { errCh <- InitiatorHandshake(rw.ns, rw.rawInput, rw.origin) } else { errCh <- ResponderHandshake(rw.ns, rw.rawInput, rw.origin) } - close(errCh) }() select { case err := <-errCh: