diff --git a/packages/loro-websocket/src/client/index.test.ts b/packages/loro-websocket/src/client/index.test.ts index 96d0a5e..17c8783 100644 --- a/packages/loro-websocket/src/client/index.test.ts +++ b/packages/loro-websocket/src/client/index.test.ts @@ -1,7 +1,142 @@ -import { describe, it, expect } from "vitest"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { + CrdtType, + JoinErrorCode, + MessageType, + type JoinError, +} from "loro-protocol"; +import * as protocol from "loro-protocol"; +import { LoroWebsocketClient } from "./index"; + +class FakeWebSocket { + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + readyState = FakeWebSocket.CLOSED; + bufferedAmount = 0; + binaryType: any = "arraybuffer"; + url: string; + lastSent: unknown; + private listeners = new Map void>>(); + + constructor(url: string) { + this.url = url; + } + + addEventListener(type: string, listener: (ev: any) => void) { + const set = this.listeners.get(type) ?? new Set(); + set.add(listener); + this.listeners.set(type, set); + } + + removeEventListener(type: string, listener: (ev: any) => void) { + const set = this.listeners.get(type); + set?.delete(listener); + } + + dispatch(type: string, ev: any) { + const set = this.listeners.get(type); + if (!set) return; + for (const l of Array.from(set)) l(ev); + } + + send(data: any) { + if (this.readyState !== FakeWebSocket.OPEN) { + throw new Error("WebSocket is not open"); + } + this.lastSent = data; + } + + close() { + this.readyState = FakeWebSocket.CLOSED; + } +} describe("LoroWebsocketClient", () => { - it("is placeholder", () => { - expect(true).toBe(true); + let originalWebSocket: any; + + beforeEach(() => { + originalWebSocket = (globalThis as any).WebSocket; + (globalThis as any).WebSocket = FakeWebSocket as any; + }); + + afterEach(() => { + (globalThis as any).WebSocket = originalWebSocket; + vi.restoreAllMocks(); + }); + + it("does not throw when retrying join after closed socket and reports via onError", async () => { + const onError = vi.fn(); + const client = new LoroWebsocketClient({ + url: "ws://test", + disablePing: true, + reconnect: { enabled: false }, + onError, + }); + + const adaptor = { + crdtType: CrdtType.Loro, + setCtx: () => { }, + getVersion: () => new Uint8Array([0]), + getAlternativeVersion: () => new Uint8Array([1]), + handleJoinOk: async () => { }, + waitForReachingServerVersion: async () => { }, + destroy: () => { }, + } satisfies any; + + const joinError: JoinError = { + type: MessageType.JoinError, + code: JoinErrorCode.VersionUnknown, + message: "", + crdt: adaptor.crdtType, + roomId: "room", + }; + + const pending = { + room: Promise.resolve({} as any), + resolve: () => { }, + reject: () => { }, + adaptor, + roomId: "room", + } satisfies any; + + // Avoid unhandled rejection when the client is destroyed without ever opening. + (client as any).connectedPromise?.catch(() => { }); + + // Force the current socket to a closed state so send will fail. + (client as any).ws.readyState = FakeWebSocket.CLOSED; + + await expect( + (client as any).handleJoinError(joinError, pending, adaptor.crdtType + "room") + ).resolves.not.toThrow(); + + expect(onError).toHaveBeenCalledTimes(1); + expect(((client as any).queuedJoins ?? []).length).toBeGreaterThan(0); + + }); + + it("forwards decode or handler errors to onError instead of crashing", async () => { + const onError = vi.fn(); + const client = new LoroWebsocketClient({ + url: "ws://test", + disablePing: true, + reconnect: { enabled: false }, + onError, + }); + + (client as any).connectedPromise?.catch(() => { }); + + vi.spyOn(protocol, "tryDecode").mockImplementation(() => { + throw new Error("decode failed"); + }); + + await (client as any).onSocketMessage((client as any).ws, { + data: new ArrayBuffer(0), + } as MessageEvent); + + expect(onError).toHaveBeenCalledTimes(1); + }); }); diff --git a/packages/loro-websocket/src/client/index.ts b/packages/loro-websocket/src/client/index.ts index f528c72..7747a4b 100644 --- a/packages/loro-websocket/src/client/index.ts +++ b/packages/loro-websocket/src/client/index.ts @@ -106,6 +106,8 @@ export interface LoroWebsocketClientOptions { disablePing?: boolean; /** Optional callback for low-level ws close (before status transitions). */ onWsClose?: () => void; + /** Optional callback for any client-level errors (socket error, decode/apply failures, send on closed, etc.). */ + onError?: (error: Error) => void; /** * Reconnect policy (kept minimal). * - enabled: toggle auto-retry (default true) @@ -374,7 +376,15 @@ export class LoroWebsocketClient { this.setStatus(ClientStatus.Connecting); - const ws = new WebSocket(this.ops.url); + let ws: WebSocket; + try { + ws = new WebSocket(this.ops.url); + } catch (err) { + const error = err instanceof Error ? err : new Error(String(err)); + this.rejectConnected?.(error); + this.setStatus(ClientStatus.Disconnected); + throw error; + } this.ws = ws; if (current && current !== ws) { @@ -399,7 +409,9 @@ export class LoroWebsocketClient { this.onSocketClose(ws, event); }; const message = (event: MessageEvent) => { - void this.onSocketMessage(ws, event); + void this.onSocketMessage(ws, event).catch(err => { + this.emitError(err instanceof Error ? err : new Error(String(err))); + }); }; ws.addEventListener("open", open); @@ -439,6 +451,7 @@ export class LoroWebsocketClient { if (ws !== this.ws) { this.detachSocketListeners(ws); } + this.emitError(new Error("WebSocket error")); // Leave further handling to the close event for the active socket } @@ -515,20 +528,24 @@ export class LoroWebsocketClient { if (ws !== this.ws) { return; } - if (typeof event.data === "string") { - if (event.data === "ping") { - ws.send("pong"); - return; - } - if (event.data === "pong") { - this.handlePong(); - return; + try { + if (typeof event.data === "string") { + if (event.data === "ping") { + this.safeSend(ws, "pong", "pong"); + return; + } + if (event.data === "pong") { + this.handlePong(); + return; + } + return; // ignore other texts } - return; // ignore other texts + const dataU8 = new Uint8Array(event.data); + const msg = tryDecode(dataU8); + if (msg != null) await this.handleMessage(msg); + } catch (err) { + this.emitError(err instanceof Error ? err : new Error(String(err))); } - const dataU8 = new Uint8Array(event.data); - const msg = tryDecode(dataU8); - if (msg != null) await this.handleMessage(msg); } private scheduleReconnect(immediate = false) { @@ -633,12 +650,7 @@ export class LoroWebsocketClient { } as JoinRequest); try { - if (this.ws && this.ws.readyState === WebSocket.OPEN) { - this.ws.send(payload); - } else { - this.enqueueJoin(payload); - void this.connect(); - } + this.sendJoinPayload(payload); this.emitRoomStatus(id, RoomJoinStatus.Reconnecting); } catch (e) { console.error("Failed to send rejoin request:", e); @@ -735,17 +747,14 @@ export class LoroWebsocketClient { this.fragmentBatches.delete(batchKey); // Notify server to prompt resend try { - if (this.ws && this.ws.readyState === WebSocket.OPEN) { - this.ws.send( - encode({ - type: MessageType.Ack, - crdt: msg.crdt, - roomId: msg.roomId, - refId: msg.batchId, - status: UpdateStatusCode.FragmentTimeout, - } as Ack) - ); - } + const payload = encode({ + type: MessageType.Ack, + crdt: msg.crdt, + roomId: msg.roomId, + refId: msg.batchId, + status: UpdateStatusCode.FragmentTimeout, + } as Ack); + this.safeSend(this.ws, payload, "fragment-timeout-ack"); } catch { } }, 10000); @@ -859,27 +868,25 @@ export class LoroWebsocketClient { pending.adaptor.getAlternativeVersion?.(currentVersion); if (alternativeVersion) { // Retry with alternative version format - this.ws.send( - encode({ - type: MessageType.JoinRequest, - crdt: pending.adaptor.crdtType, - roomId: pending.roomId, - auth: authValue, - version: alternativeVersion, - } as JoinRequest) - ); + const payload = encode({ + type: MessageType.JoinRequest, + crdt: pending.adaptor.crdtType, + roomId: pending.roomId, + auth: authValue, + version: alternativeVersion, + } as JoinRequest); + this.sendJoinPayload(payload); return; } else { console.warn("Version unknown. Now join with an empty version"); - this.ws.send( - encode({ - type: MessageType.JoinRequest, - crdt: pending.adaptor.crdtType, - roomId: pending.roomId, - auth: authValue, - version: new Uint8Array(), - } as JoinRequest) - ); + const payload = encode({ + type: MessageType.JoinRequest, + crdt: pending.adaptor.crdtType, + roomId: pending.roomId, + auth: authValue, + version: new Uint8Array(), + } as JoinRequest); + this.sendJoinPayload(payload); return; } } @@ -940,15 +947,23 @@ export class LoroWebsocketClient { }, timeoutId, }; - this.pingWaiters.push(waiter); - try { - if (this.awaitingPongSince == null) this.awaitingPongSince = Date.now(); - this.ws.send("ping"); - } catch (e) { - this.pingWaiters.pop(); + + // If there's already a pending ping, just wait for the pong + if (this.awaitingPongSince != null) { + this.pingWaiters.push(waiter); + return; + } + + // Try to send ping; if it fails, reject immediately instead of waiting for timeout + const sent = this.safeSend(this.ws, "ping", "ping"); + if (!sent) { clearTimeout(timeoutId); - reject(e instanceof Error ? e : new Error(String(e))); + reject(new Error("Failed to send ping: WebSocket not open")); + return; } + + this.awaitingPongSince = Date.now(); + this.pingWaiters.push(waiter); }); } @@ -1011,14 +1026,16 @@ export class LoroWebsocketClient { }, onJoinFailed: (reason: string) => { console.error(`Join failed: ${reason}`); - this.ws.send( + this.safeSend( + this.ws, encode({ type: MessageType.JoinError, crdt: crdtAdaptor.crdtType, roomId, code: JoinErrorCode.AppError, message: reason, - } as JoinError) + } as JoinError), + "join-error" ); reject(new Error(`Join failed: ${reason}`)); }, @@ -1068,13 +1085,7 @@ export class LoroWebsocketClient { version: crdtAdaptor.getVersion(), } as JoinRequest); - if (this.ws && this.ws.readyState === WebSocket.OPEN) { - this.ws.send(joinPayload); - } else { - this.enqueueJoin(joinPayload); - // ensure a connection attempt is running - void this.connect(); - } + this.sendJoinPayload(joinPayload); }) .catch(err => { const error = err instanceof Error ? err : new Error(String(err)); @@ -1096,6 +1107,7 @@ export class LoroWebsocketClient { this.clearPingTimer(); this.reconnectAttempts = 0; this.rejectConnected?.(new Error("Disconnected")); + void this.connectedPromise?.catch(() => { }); this.rejectConnected = undefined; this.resolveConnected = undefined; this.rejectAllPingWaiters(new Error("Disconnected")); @@ -1147,14 +1159,16 @@ export class LoroWebsocketClient { if (update.length <= FRAG_LIMIT) { // Send as a single DocUpdate with one update entry - ws.send( + this.safeSend( + ws, encode({ type: MessageType.DocUpdate, crdt, roomId, updates: [update], batchId, - } as DocUpdate) + } as DocUpdate), + "send-update" ); return; } @@ -1170,7 +1184,7 @@ export class LoroWebsocketClient { fragmentCount, totalSizeBytes: update.length, }; - ws.send(encode(header)); + this.safeSend(ws, encode(header), "send-fragment-header"); for (let i = 0; i < fragmentCount; i++) { const start = i * FRAG_LIMIT; @@ -1184,19 +1198,20 @@ export class LoroWebsocketClient { index: i, fragment, }; - ws.send(encode(msg)); + this.safeSend(ws, encode(msg), "send-fragment"); } } /** @internal Send Leave on the current websocket. */ sendLeave(crdt: CrdtType, roomId: string) { - if (!this.ws || this.ws.readyState !== WebSocket.OPEN) return; - this.ws.send( + this.safeSend( + this.ws, encode({ type: MessageType.Leave, crdt, roomId, - } as Leave) + } as Leave), + "leave" ); } @@ -1226,6 +1241,7 @@ export class LoroWebsocketClient { this.clearPingTimer(); this.reconnectAttempts = 0; this.rejectConnected?.(new Error("Destroyed")); + void this.connectedPromise?.catch(() => { }); this.rejectConnected = undefined; this.resolveConnected = undefined; this.rejectAllPingWaiters(new Error("Destroyed")); @@ -1272,8 +1288,14 @@ export class LoroWebsocketClient { return typeof raw === "number" ? raw : undefined; }; + const safeClose = () => { + try { + ws.close(code, reason); + } catch { } + }; + if (readBufferedAmount() == null) { - ws.close(code, reason); + safeClose(); return; } @@ -1284,7 +1306,7 @@ export class LoroWebsocketClient { const state = ws.readyState; if (state === WebSocket.CLOSED || state === WebSocket.CLOSING) { requested = true; - ws.close(code, reason); + safeClose(); return; } @@ -1295,7 +1317,7 @@ export class LoroWebsocketClient { Date.now() - start >= timeoutMs ) { requested = true; - ws.close(code, reason); + safeClose(); return; } @@ -1344,8 +1366,9 @@ export class LoroWebsocketClient { if (this.ws && this.ws.readyState === WebSocket.OPEN) { // Avoid overlapping RTT probes if (this.awaitingPongSince == null) { - this.awaitingPongSince = Date.now(); - this.ws.send("ping"); + if (this.safeSend(this.ws, "ping", "ping")) { + this.awaitingPongSince = Date.now(); + } } else { // Still awaiting a pong; skip sending another ping } @@ -1437,9 +1460,40 @@ export class LoroWebsocketClient { return Math.min(max, Math.max(0, Math.floor(withJitter))); } + private safeSend( + ws: WebSocket | undefined, + data: Parameters[0], + context?: string + ): boolean { + if (!ws || ws !== this.ws) return false; + if (ws.readyState !== WebSocket.OPEN) { + if (context) { + this.emitError(new Error(`WebSocket not open during ${context}`)); + } + return false; + } + try { + ws.send(data); + return true; + } catch (err) { + this.emitError(err instanceof Error ? err : new Error(String(err))); + return false; + } + } + private logCbError(context: string, err: unknown) { // eslint-disable-next-line no-console console.error(`[loro-websocket] ${context} callback threw`, err); + this.emitError(err instanceof Error ? err : new Error(String(err))); + } + + private emitError(err: Error) { + try { + this.ops.onError?.(err); + } catch (cbErr) { + // eslint-disable-next-line no-console + console.error("[loro-websocket] onError callback threw", cbErr); + } } private isFatalClose(code?: number, reason?: string): boolean { @@ -1453,6 +1507,12 @@ export class LoroWebsocketClient { this.queuedJoins.push(payload); } + private sendJoinPayload(payload: Uint8Array) { + if (this.safeSend(this.ws, payload, "join")) return; + this.enqueueJoin(payload); + void this.connect(); + } + private flushQueuedJoins() { if (!this.ws || this.ws.readyState !== WebSocket.OPEN) return; if (!this.queuedJoins.length) return;