Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pkg/dmsg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,22 @@ func (s *Server) Ready() <-chan struct{} {
}

func (s *Server) handleSession(conn net.Conn) {
defer func() {
if r := recover(); r != nil {
s.log.WithField("panic", r).
WithField("remote_tcp", conn.RemoteAddr()).
Error("Recovered from panic in handleSession, connection will be closed")
if err := conn.Close(); err != nil {
s.log.WithError(err).Warn("Failed to close connection after panic recovery")
}
}
}()

log := s.log.WithField("remote_tcp", conn.RemoteAddr())

dSes, err := makeServerSession(s.m, &s.EntityCommon, conn)
if err != nil {
log.WithError(err).Warn("Failed to create server session")
if err := conn.Close(); err != nil {
log.WithError(err).Warn("On handleSession() failure, close connection resulted in error.")
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/dmsg/server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ func (ss *ServerSession) Serve() {
log.Info("Initiating stream.")

go func(sStr *smux.Stream) {
defer func() {
if r := recover(); r != nil {
log.WithField("panic", r).Error("Recovered from panic in serveStream")
}
}()
err := ss.serveStream(log, sStr, ss.sm.addr)
log.WithError(err).Info("Stopped stream.")
}(sStr)
Expand All @@ -83,6 +88,11 @@ func (ss *ServerSession) Serve() {
log.Info("Initiating stream.")

go func(yStr *yamux.Stream) {
defer func() {
if r := recover(); r != nil {
log.WithField("panic", r).Error("Recovered from panic in serveStream")
}
}()
err := ss.serveStream(log, yStr, ss.sm.addr)
log.WithError(err).Info("Stopped stream.")
}(yStr)
Expand Down
87 changes: 87 additions & 0 deletions pkg/dmsg/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,90 @@ func GenKeyPair(t *testing.T, seed string) (cipher.PubKey, cipher.SecKey) {
require.NoError(t, err)
return pk, sk
}

// TestInvalidPublicKeyNoPanic tests that the server doesn't crash when receiving
// a connection with an invalid public key during the noise handshake.
// This is a regression test for a 2+ year old bug where invalid public keys
// would cause the server to panic and crash.
func TestInvalidPublicKeyNoPanic(t *testing.T) {
// Prepare mock discovery.
dc := disc.NewMock(0)
const maxSessions = 10

// Prepare dmsg server.
pkSrv, skSrv := GenKeyPair(t, "server")
srvConf := &ServerConfig{
MaxSessions: maxSessions,
UpdateInterval: 0,
}
srv := NewServer(pkSrv, skSrv, dc, srvConf, nil)
srv.SetLogger(logging.MustGetLogger("server"))
lisSrv, err := net.Listen("tcp", "")
require.NoError(t, err)

// Serve dmsg server.
chSrv := make(chan error, 1)
go func() { chSrv <- srv.Serve(lisSrv, "") }() //nolint:errcheck

// Give server time to start
time.Sleep(500 * time.Millisecond)

// Attempt to send a handshake with invalid public key data
// This simulates a malicious or buggy client
t.Run("invalid_pubkey_handshake", func(t *testing.T) {
conn, err := net.Dial("tcp", lisSrv.Addr().String())
require.NoError(t, err)
defer func() { _ = conn.Close() }() //nolint:errcheck

// Send invalid noise handshake data (contains invalid public key)
// In a real noise handshake, the public key would be embedded in the message
// We send malformed data that will trigger invalid public key error
invalidData := make([]byte, 100)
// Write some invalid data that looks like a handshake but has invalid key
copy(invalidData, []byte{0x00, 0x32}) // frame length prefix (50 bytes)
// Rest is invalid/random data that will fail public key validation
for i := 2; i < len(invalidData); i++ {
invalidData[i] = byte(i) // deterministic but invalid
}

_, err = conn.Write(invalidData)
// Write may succeed, but the server should handle the invalid data gracefully
if err != nil {
t.Logf("Write failed (expected): %v", err)
}

// Give server time to process the invalid handshake
time.Sleep(500 * time.Millisecond)

// Read to see if connection was closed (expected behavior)
buf := make([]byte, 10)
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) //nolint:errcheck
_, _ = conn.Read(buf) //nolint:errcheck
// We expect the connection to be closed or timeout
// The important thing is the server didn't crash
})

// Verify server is still running and can accept valid connections
t.Run("valid_connection_after_invalid", func(t *testing.T) {
// Prepare and serve a valid dmsg client
pkA, skA := GenKeyPair(t, "client A")
clientA := NewClient(pkA, skA, dc, DefaultConfig())
clientA.SetLogger(logging.MustGetLogger("client_A"))
go clientA.Serve(context.Background())

// Wait for client to register
time.Sleep(time.Second * 2)

// Attempt to use the client - if server crashed, this will fail
lis, err := clientA.Listen(8081)
require.NoError(t, err, "Server should still be running and accept valid connections")

// Clean up
require.NoError(t, lis.Close())
require.NoError(t, clientA.Close())
})

// Closing logic - server should still be healthy
require.NoError(t, srv.Close())
require.NoError(t, <-chSrv)
}
21 changes: 18 additions & 3 deletions pkg/noise/dh.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,24 @@ func (Secp256k1) GenerateKeypair(_ io.Reader) (noise.DHKey, error) {

// DH helps to implement `noise.DHFunc`.
func (Secp256k1) DH(sk, pk []byte) []byte {
return append(
cipher.MustECDH(cipher.MustNewPubKey(pk), cipher.MustNewSecKey(sk)),
byte(0))
// Use non-panic versions to handle invalid keys gracefully
pubKey, err := cipher.NewPubKey(pk)
if err != nil {
// Return empty key on error to prevent panic
// The handshake will fail with this invalid key
return make([]byte, 33)
}
secKey, err := cipher.NewSecKey(sk)
if err != nil {
// Return empty key on error to prevent panic
return make([]byte, 33)
}
ecdh, err := cipher.ECDH(pubKey, secKey)
if err != nil {
// Return empty key on error to prevent panic
return make([]byte, 33)
}
return append(ecdh, byte(0))
}

// DHLen helps to implement `noise.DHFunc`.
Expand Down
7 changes: 6 additions & 1 deletion pkg/noise/read_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,17 @@ func (rw *ReadWriter) Write(p []byte) (n int, err error) {
func (rw *ReadWriter) Handshake(hsTimeout time.Duration) error {
errCh := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
errCh <- fmt.Errorf("handshake panic: %v", r)
}
close(errCh)
}()
if rw.ns.init {
errCh <- InitiatorHandshake(rw.ns, rw.rawInput, rw.origin)
} else {
errCh <- ResponderHandshake(rw.ns, rw.rawInput, rw.origin)
}
close(errCh)
}()
select {
case err := <-errCh:
Expand Down
Loading