diff --git a/.vscode/settings.json b/.vscode/settings.json index 85f4c06cd6..e0209de61b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -61,5 +61,8 @@ }, "directoryFilters": ["-tsunami/frontend/scaffold", "-dist", "-make"] }, - "tailwindCSS.lint.suggestCanonicalClasses": "ignore" + "tailwindCSS.lint.suggestCanonicalClasses": "ignore", + "go.coverageDecorator": { + "type": "gutter" + } } diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go index a59661f0ba..5eb247c75c 100644 --- a/cmd/server/main-server.go +++ b/cmd/server/main-server.go @@ -20,6 +20,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/blocklogger" "github.com/wavetermdev/waveterm/pkg/filebackup" "github.com/wavetermdev/waveterm/pkg/filestore" + "github.com/wavetermdev/waveterm/pkg/jobcontroller" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" @@ -391,7 +392,7 @@ func createMainWshClient() { wshfs.RpcClient = rpc wshutil.DefaultRouter.RegisterTrustedLeaf(rpc, wshutil.DefaultRoute) wps.Broker.SetClient(wshutil.DefaultRouter) - localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}, "conn:local") + localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, wshremote.MakeRemoteRpcServerImpl(nil, wshutil.DefaultRouter, wshclient.GetBareRpcClient(), true), "conn:local") go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName) wshutil.DefaultRouter.RegisterTrustedLeaf(localConnWsh, wshutil.MakeConnectionRouteId(wshrpc.LocalConnName)) } @@ -572,6 +573,7 @@ func main() { go backupCleanupLoop() go startupActivityUpdate(firstLaunch) // must be after startConfigWatcher() blocklogger.InitBlockLogger() + jobcontroller.InitJobController() go func() { defer func() { panichandler.PanicHandler("GetSystemSummary", recover()) diff --git a/cmd/wsh/cmd/wshcmd-connserver.go 
b/cmd/wsh/cmd/wshcmd-connserver.go index 3fb6a10fcc..6ec0d5e4d7 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -38,11 +38,14 @@ var serverCmd = &cobra.Command{ } var connServerRouter bool +var connServerRouterDomainSocket bool var connServerConnName string var connServerDev bool +var ConnServerWshRouter *wshutil.WshRouter func init() { - serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode") + serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode (stdio upstream)") + serverCmd.Flags().BoolVar(&connServerRouterDomainSocket, "router-domainsocket", false, "run in local router mode (domain socket upstream)") serverCmd.Flags().StringVar(&connServerConnName, "conn", "", "connection name") serverCmd.Flags().BoolVar(&connServerDev, "dev", false, "enable dev mode with file logging and PID in logs") rootCmd.AddCommand(serverCmd) @@ -123,7 +126,12 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh RouteId: routeId, Conn: connServerConnName, } - connServerClient := wshutil.MakeWshRpc(rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}, routeId) + + bareRouteId := wshutil.MakeRandomProcRouteId() + bareClient := wshutil.MakeWshRpc(wshrpc.RpcContext{}, &wshclient.WshServer{}, bareRouteId) + router.RegisterTrustedLeaf(bareClient, bareRouteId) + + connServerClient := wshutil.MakeWshRpc(rpcCtx, wshremote.MakeRemoteRpcServerImpl(os.Stdout, router, bareClient, false), routeId) router.RegisterTrustedLeaf(connServerClient, routeId) return connServerClient, nil } @@ -131,6 +139,7 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh func serverRunRouter() error { log.Printf("starting connserver router") router := wshutil.NewWshRouter() + ConnServerWshRouter = router termProxy := wshutil.MakeRpcProxy("connserver-term") rawCh := make(chan []byte, wshutil.DefaultOutputChSize) go func() { @@ -209,8 +218,112 @@ func 
serverRunRouter() error { select {} } +func serverRunRouterDomainSocket(jwtToken string) error { + log.Printf("starting connserver router (domain socket upstream)") + + // extract socket name from JWT token (unverified - we're on the client side) + sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken) + if err != nil { + return fmt.Errorf("error extracting socket name from JWT: %v", err) + } + + // connect to the forwarded domain socket + sockName = wavebase.ExpandHomeDirSafe(sockName) + conn, err := net.Dial("unix", sockName) + if err != nil { + return fmt.Errorf("error connecting to domain socket %s: %v", sockName, err) + } + + // create router + router := wshutil.NewWshRouter() + ConnServerWshRouter = router + + // create proxy for the domain socket connection + upstreamProxy := wshutil.MakeRpcProxy("connserver-upstream") + + // goroutine to write to the domain socket + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouterDomainSocket:WriteLoop", recover()) + }() + writeErr := wshutil.AdaptOutputChToStream(upstreamProxy.ToRemoteCh, conn) + if writeErr != nil { + log.Printf("error writing to upstream domain socket: %v\n", writeErr) + } + }() + + // goroutine to read from the domain socket + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouterDomainSocket:ReadLoop", recover()) + }() + defer func() { + log.Printf("upstream domain socket closed, shutting down") + wshutil.DoShutdown("", 0, true) + }() + wshutil.AdaptStreamToMsgCh(conn, upstreamProxy.FromRemoteCh) + }() + + // register the domain socket connection as upstream + router.RegisterUpstream(upstreamProxy) + + // setup the connserver rpc client (leaf) + client, err := setupConnServerRpcClientWithRouter(router) + if err != nil { + return fmt.Errorf("error setting up connserver rpc client: %v", err) + } + wshfs.RpcClient = client + + // authenticate with the upstream router using the JWT + _, err = wshclient.AuthenticateCommand(client, jwtToken, 
&wshrpc.RpcOpts{Route: wshutil.ControlRoute}) + if err != nil { + return fmt.Errorf("error authenticating with upstream: %v", err) + } + log.Printf("authenticated with upstream router") + + // fetch and set JWT public key + log.Printf("trying to get JWT public key") + jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(client, nil) + if err != nil { + return fmt.Errorf("error getting jwt public key: %v", err) + } + jwtPublicKeyBytes, err := base64.StdEncoding.DecodeString(jwtPublicKeyB64) + if err != nil { + return fmt.Errorf("error decoding jwt public key: %v", err) + } + err = wavejwt.SetPublicKey(jwtPublicKeyBytes) + if err != nil { + return fmt.Errorf("error setting jwt public key: %v", err) + } + log.Printf("got JWT public key") + + // set up the local domain socket listener for local wsh commands + unixListener, err := MakeRemoteUnixListener() + if err != nil { + return fmt.Errorf("cannot create unix listener: %v", err) + } + log.Printf("unix listener started") + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouterDomainSocket:runListener", recover()) + }() + runListener(unixListener, router) + }() + + // run the sysinfo loop + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouterDomainSocket:RunSysInfoLoop", recover()) + }() + wshremote.RunSysInfoLoop(client, connServerConnName) + }() + + log.Printf("running server (router-domainsocket mode), successfully started") + select {} +} + func serverRunNormal(jwtToken string) error { - err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout}, jwtToken) + err := setupRpcClient(wshremote.MakeRemoteRpcServerImpl(os.Stdout, nil, nil, false), jwtToken) if err != nil { return err } @@ -283,6 +396,20 @@ func serverRun(cmd *cobra.Command, args []string) error { } return err } + if connServerRouterDomainSocket { + jwtToken, err := askForJwtToken() + if err != nil { + if logFile != nil { + fmt.Fprintf(logFile, "askForJwtToken error: %v\n", err) + } + return err + } + err 
= serverRunRouterDomainSocket(jwtToken) + if err != nil && logFile != nil { + fmt.Fprintf(logFile, "serverRunRouterDomainSocket error: %v\n", err) + } + return err + } jwtToken, err := askForJwtToken() if err != nil { if logFile != nil { diff --git a/cmd/wsh/cmd/wshcmd-jobdebug.go b/cmd/wsh/cmd/wshcmd-jobdebug.go new file mode 100644 index 0000000000..5ae68b7051 --- /dev/null +++ b/cmd/wsh/cmd/wshcmd-jobdebug.go @@ -0,0 +1,382 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "encoding/base64" + "encoding/json" + "fmt" + + "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" +) + +var jobDebugCmd = &cobra.Command{ + Use: "jobdebug", + Short: "debugging commands for the job system", + Hidden: true, + PersistentPreRunE: preRunSetupRpcClient, +} + +var jobDebugListCmd = &cobra.Command{ + Use: "list", + Short: "list all jobs with debug information", + RunE: jobDebugListRun, +} + +var jobDebugDeleteCmd = &cobra.Command{ + Use: "delete", + Short: "delete a job entry by jobid", + RunE: jobDebugDeleteRun, +} + +var jobDebugDeleteAllCmd = &cobra.Command{ + Use: "deleteall", + Short: "delete all jobs", + RunE: jobDebugDeleteAllRun, +} + +var jobDebugPruneCmd = &cobra.Command{ + Use: "prune", + Short: "remove jobs where the job manager is no longer running", + RunE: jobDebugPruneRun, +} + +var jobDebugExitCmd = &cobra.Command{ + Use: "exit", + Short: "exit a job manager", + RunE: jobDebugExitRun, +} + +var jobDebugDisconnectCmd = &cobra.Command{ + Use: "disconnect", + Short: "disconnect from a job manager", + RunE: jobDebugDisconnectRun, +} + +var jobDebugReconnectCmd = &cobra.Command{ + Use: "reconnect", + Short: "reconnect to a job manager", + RunE: jobDebugReconnectRun, +} + +var jobDebugReconnectConnCmd = &cobra.Command{ + Use: "reconnectconn", + Short: "reconnect all jobs for a connection", + RunE: jobDebugReconnectConnRun, +} + +var 
jobDebugGetOutputCmd = &cobra.Command{ + Use: "getoutput", + Short: "get the terminal output for a job", + RunE: jobDebugGetOutputRun, +} + +var jobDebugStartCmd = &cobra.Command{ + Use: "start", + Short: "start a new job", + Args: cobra.MinimumNArgs(1), + RunE: jobDebugStartRun, +} + +var jobDebugAttachJobCmd = &cobra.Command{ + Use: "attachjob", + Short: "attach a job to a block", + RunE: jobDebugAttachJobRun, +} + +var jobDebugDetachJobCmd = &cobra.Command{ + Use: "detachjob", + Short: "detach a job from its block", + RunE: jobDebugDetachJobRun, +} + +var jobIdFlag string +var jobDebugJsonFlag bool +var jobConnFlag string +var exitJobIdFlag string +var disconnectJobIdFlag string +var reconnectJobIdFlag string +var reconnectConnNameFlag string +var attachJobIdFlag string +var attachBlockIdFlag string +var detachJobIdFlag string + +func init() { + rootCmd.AddCommand(jobDebugCmd) + jobDebugCmd.AddCommand(jobDebugListCmd) + jobDebugCmd.AddCommand(jobDebugDeleteCmd) + jobDebugCmd.AddCommand(jobDebugDeleteAllCmd) + jobDebugCmd.AddCommand(jobDebugPruneCmd) + jobDebugCmd.AddCommand(jobDebugExitCmd) + jobDebugCmd.AddCommand(jobDebugDisconnectCmd) + jobDebugCmd.AddCommand(jobDebugReconnectCmd) + jobDebugCmd.AddCommand(jobDebugReconnectConnCmd) + jobDebugCmd.AddCommand(jobDebugGetOutputCmd) + jobDebugCmd.AddCommand(jobDebugStartCmd) + jobDebugCmd.AddCommand(jobDebugAttachJobCmd) + jobDebugCmd.AddCommand(jobDebugDetachJobCmd) + + jobDebugListCmd.Flags().BoolVar(&jobDebugJsonFlag, "json", false, "output as JSON") + + jobDebugDeleteCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to delete (required)") + jobDebugDeleteCmd.MarkFlagRequired("jobid") + + jobDebugExitCmd.Flags().StringVar(&exitJobIdFlag, "jobid", "", "job id to exit (required)") + jobDebugExitCmd.MarkFlagRequired("jobid") + + jobDebugDisconnectCmd.Flags().StringVar(&disconnectJobIdFlag, "jobid", "", "job id to disconnect (required)") + jobDebugDisconnectCmd.MarkFlagRequired("jobid") + + 
jobDebugReconnectCmd.Flags().StringVar(&reconnectJobIdFlag, "jobid", "", "job id to reconnect (required)") + jobDebugReconnectCmd.MarkFlagRequired("jobid") + + jobDebugReconnectConnCmd.Flags().StringVar(&reconnectConnNameFlag, "conn", "", "connection name (required)") + jobDebugReconnectConnCmd.MarkFlagRequired("conn") + + jobDebugGetOutputCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to get output for (required)") + jobDebugGetOutputCmd.MarkFlagRequired("jobid") + + jobDebugStartCmd.Flags().StringVar(&jobConnFlag, "conn", "", "connection name (required)") + jobDebugStartCmd.MarkFlagRequired("conn") + + jobDebugAttachJobCmd.Flags().StringVar(&attachJobIdFlag, "jobid", "", "job id to attach (required)") + jobDebugAttachJobCmd.MarkFlagRequired("jobid") + jobDebugAttachJobCmd.Flags().StringVar(&attachBlockIdFlag, "blockid", "", "block id to attach to (required)") + jobDebugAttachJobCmd.MarkFlagRequired("blockid") + + jobDebugDetachJobCmd.Flags().StringVar(&detachJobIdFlag, "jobid", "", "job id to detach (required)") + jobDebugDetachJobCmd.MarkFlagRequired("jobid") +} + +func jobDebugListRun(cmd *cobra.Command, args []string) error { + rtnData, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("getting job debug list: %w", err) + } + + connectedJobIds, err := wshclient.JobControllerConnectedJobsCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("getting connected job ids: %w", err) + } + + connectedMap := make(map[string]bool) + for _, jobId := range connectedJobIds { + connectedMap[jobId] = true + } + + if jobDebugJsonFlag { + jsonData, err := json.MarshalIndent(rtnData, "", " ") + if err != nil { + return fmt.Errorf("marshaling json: %w", err) + } + fmt.Printf("%s\n", string(jsonData)) + return nil + } + + fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n", "OID", "Connection", "Connected", "Manager", "Cmd", "ExitCode", "Stream") + for _, job := 
range rtnData { + connectedStatus := "no" + if connectedMap[job.OID] { + connectedStatus = "yes" + } + + streamStatus := "-" + if job.StreamDone { + if job.StreamError == "" { + streamStatus = "EOF" + } else { + streamStatus = fmt.Sprintf("%q", job.StreamError) + } + } + + exitCode := "-" + if job.CmdExitTs > 0 { + if job.CmdExitCode != nil { + exitCode = fmt.Sprintf("%d", *job.CmdExitCode) + } else if job.CmdExitSignal != "" { + exitCode = job.CmdExitSignal + } else { + exitCode = "?" + } + } + + fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n", + job.OID, job.Connection, connectedStatus, job.JobManagerStatus, job.Cmd, exitCode, streamStatus) + } + return nil +} + +func jobDebugDeleteRun(cmd *cobra.Command, args []string) error { + err := wshclient.JobControllerDeleteJobCommand(RpcClient, jobIdFlag, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("deleting job: %w", err) + } + + fmt.Printf("Job %s deleted successfully\n", jobIdFlag) + return nil +} + +func jobDebugDeleteAllRun(cmd *cobra.Command, args []string) error { + rtnData, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("getting job debug list: %w", err) + } + + if len(rtnData) == 0 { + fmt.Printf("No jobs to delete\n") + return nil + } + + deletedCount := 0 + for _, job := range rtnData { + err := wshclient.JobControllerDeleteJobCommand(RpcClient, job.OID, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + fmt.Printf("Error deleting job %s: %v\n", job.OID, err) + } else { + deletedCount++ + } + } + + fmt.Printf("Deleted %d of %d job(s)\n", deletedCount, len(rtnData)) + return nil +} + +func jobDebugPruneRun(cmd *cobra.Command, args []string) error { + rtnData, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("getting job debug list: %w", err) + } + + if len(rtnData) == 0 { + fmt.Printf("No jobs to prune\n") + return nil + } + + 
deletedCount := 0 + for _, job := range rtnData { + if job.JobManagerStatus != "running" { + err := wshclient.JobControllerDeleteJobCommand(RpcClient, job.OID, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + fmt.Printf("Error deleting job %s: %v\n", job.OID, err) + } else { + deletedCount++ + } + } + } + + if deletedCount == 0 { + fmt.Printf("No jobs with stopped job managers to prune\n") + } else { + fmt.Printf("Pruned %d job(s) with stopped job managers\n", deletedCount) + } + return nil +} + +func jobDebugExitRun(cmd *cobra.Command, args []string) error { + err := wshclient.JobControllerExitJobCommand(RpcClient, exitJobIdFlag, nil) + if err != nil { + return fmt.Errorf("exiting job manager: %w", err) + } + + fmt.Printf("Job manager for %s exited successfully\n", exitJobIdFlag) + return nil +} + +func jobDebugDisconnectRun(cmd *cobra.Command, args []string) error { + err := wshclient.JobControllerDisconnectJobCommand(RpcClient, disconnectJobIdFlag, nil) + if err != nil { + return fmt.Errorf("disconnecting from job manager: %w", err) + } + + fmt.Printf("Disconnected from job manager for %s successfully\n", disconnectJobIdFlag) + return nil +} + +func jobDebugReconnectRun(cmd *cobra.Command, args []string) error { + err := wshclient.JobControllerReconnectJobCommand(RpcClient, reconnectJobIdFlag, nil) + if err != nil { + return fmt.Errorf("reconnecting to job manager: %w", err) + } + + fmt.Printf("Reconnected to job manager for %s successfully\n", reconnectJobIdFlag) + return nil +} + +func jobDebugReconnectConnRun(cmd *cobra.Command, args []string) error { + err := wshclient.JobControllerReconnectJobsForConnCommand(RpcClient, reconnectConnNameFlag, nil) + if err != nil { + return fmt.Errorf("reconnecting jobs for connection: %w", err) + } + + fmt.Printf("Reconnected all jobs for connection %s successfully\n", reconnectConnNameFlag) + return nil +} + +func jobDebugGetOutputRun(cmd *cobra.Command, args []string) error { + fileData, err := 
wshclient.FileReadCommand(RpcClient, wshrpc.FileData{ + Info: &wshrpc.FileInfo{ + Path: fmt.Sprintf("wavefile://%s/term", jobIdFlag), + }, + }, &wshrpc.RpcOpts{Timeout: 10000}) + if err != nil { + return fmt.Errorf("reading job output: %w", err) + } + + if fileData.Data64 != "" { + decoded, err := base64.StdEncoding.DecodeString(fileData.Data64) + if err != nil { + return fmt.Errorf("decoding output data: %w", err) + } + fmt.Printf("%s", string(decoded)) + } + return nil +} + +func jobDebugStartRun(cmd *cobra.Command, args []string) error { + cmdToRun := args[0] + cmdArgs := args[1:] + + data := wshrpc.CommandJobControllerStartJobData{ + ConnName: jobConnFlag, + Cmd: cmdToRun, + Args: cmdArgs, + Env: make(map[string]string), + TermSize: nil, + } + + jobId, err := wshclient.JobControllerStartJobCommand(RpcClient, data, &wshrpc.RpcOpts{Timeout: 10000}) + if err != nil { + return fmt.Errorf("starting job: %w", err) + } + + fmt.Printf("Job started successfully with ID: %s\n", jobId) + return nil +} + +func jobDebugAttachJobRun(cmd *cobra.Command, args []string) error { + data := wshrpc.CommandJobControllerAttachJobData{ + JobId: attachJobIdFlag, + BlockId: attachBlockIdFlag, + } + + err := wshclient.JobControllerAttachJobCommand(RpcClient, data, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("attaching job: %w", err) + } + + fmt.Printf("Job %s attached to block %s successfully\n", attachJobIdFlag, attachBlockIdFlag) + return nil +} + +func jobDebugDetachJobRun(cmd *cobra.Command, args []string) error { + err := wshclient.JobControllerDetachJobCommand(RpcClient, detachJobIdFlag, &wshrpc.RpcOpts{Timeout: 5000}) + if err != nil { + return fmt.Errorf("detaching job: %w", err) + } + + fmt.Printf("Job %s detached successfully\n", detachJobIdFlag) + return nil +} diff --git a/cmd/wsh/cmd/wshcmd-jobmanager.go b/cmd/wsh/cmd/wshcmd-jobmanager.go new file mode 100644 index 0000000000..bf5562c3a7 --- /dev/null +++ b/cmd/wsh/cmd/wshcmd-jobmanager.go @@ -0,0 
+1,119 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "os" + "strings" + "time" + + "github.com/google/uuid" + "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/jobmanager" +) + +var jobManagerCmd = &cobra.Command{ + Use: "jobmanager", + Hidden: true, + Short: "job manager for wave terminal", + Args: cobra.NoArgs, + RunE: jobManagerRun, +} + +var jobManagerJobId string +var jobManagerClientId string + +func init() { + jobManagerCmd.Flags().StringVar(&jobManagerJobId, "jobid", "", "job ID (UUID, required)") + jobManagerCmd.Flags().StringVar(&jobManagerClientId, "clientid", "", "client ID (UUID, required)") + jobManagerCmd.MarkFlagRequired("jobid") + jobManagerCmd.MarkFlagRequired("clientid") + rootCmd.AddCommand(jobManagerCmd) +} + +func jobManagerRun(cmd *cobra.Command, args []string) error { + _, err := uuid.Parse(jobManagerJobId) + if err != nil { + return fmt.Errorf("invalid jobid: must be a valid UUID") + } + + _, err = uuid.Parse(jobManagerClientId) + if err != nil { + return fmt.Errorf("invalid clientid: must be a valid UUID") + } + + publicKeyB64 := os.Getenv("WAVETERM_PUBLICKEY") + if publicKeyB64 == "" { + return fmt.Errorf("WAVETERM_PUBLICKEY environment variable is not set") + } + + publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64) + if err != nil { + return fmt.Errorf("failed to decode WAVETERM_PUBLICKEY: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + jobAuthToken, err := readJobAuthToken(ctx) + if err != nil { + return fmt.Errorf("failed to read job auth token: %v", err) + } + + readyFile := os.NewFile(3, "ready-pipe") + _, err = readyFile.Stat() + if err != nil { + return fmt.Errorf("ready pipe (fd 3) not available: %v", err) + } + + err = jobmanager.SetupJobManager(jobManagerClientId, jobManagerJobId, publicKeyBytes, jobAuthToken, readyFile) + 
if err != nil { + return fmt.Errorf("error setting up job manager: %v", err) + } + + select {} +} + +func readJobAuthToken(ctx context.Context) (string, error) { + resultCh := make(chan string, 1) + errorCh := make(chan error, 1) + + go func() { + reader := bufio.NewReader(os.Stdin) + line, err := reader.ReadString('\n') + if err != nil { + errorCh <- fmt.Errorf("error reading from stdin: %v", err) + return + } + + line = strings.TrimSpace(line) + prefix := jobmanager.JobAccessTokenLabel + ":" + if !strings.HasPrefix(line, prefix) { + errorCh <- fmt.Errorf("invalid token format: expected '%s'", prefix) + return + } + + token := strings.TrimPrefix(line, prefix) + token = strings.TrimSpace(token) + if token == "" { + errorCh <- fmt.Errorf("empty job auth token") + return + } + + resultCh <- token + }() + + select { + case token := <-resultCh: + return token, nil + case err := <-errorCh: + return "", err + case <-ctx.Done(): + return "", ctx.Err() + } +} diff --git a/db/migrations-wstore/000011_job.down.sql b/db/migrations-wstore/000011_job.down.sql new file mode 100644 index 0000000000..34620c17aa --- /dev/null +++ b/db/migrations-wstore/000011_job.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS db_job; diff --git a/db/migrations-wstore/000011_job.up.sql b/db/migrations-wstore/000011_job.up.sql new file mode 100644 index 0000000000..3b032507bb --- /dev/null +++ b/db/migrations-wstore/000011_job.up.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS db_job ( + oid varchar(36) PRIMARY KEY, + version int NOT NULL, + data json NOT NULL +); diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index fffcf3e899..3caeb0f201 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -22,6 +22,21 @@ class RpcApiType { return client.wshRpcCall("authenticate", data, opts); } + // command "authenticatejobmanager" [call] + AuthenticateJobManagerCommand(client: WshClient, data: CommandAuthenticateJobManagerData, opts?: 
RpcOpts): Promise { + return client.wshRpcCall("authenticatejobmanager", data, opts); + } + + // command "authenticatejobmanagerverify" [call] + AuthenticateJobManagerVerifyCommand(client: WshClient, data: CommandAuthenticateJobManagerData, opts?: RpcOpts): Promise { + return client.wshRpcCall("authenticatejobmanagerverify", data, opts); + } + + // command "authenticatetojobmanager" [call] + AuthenticateToJobManagerCommand(client: WshClient, data: CommandAuthenticateToJobData, opts?: RpcOpts): Promise { + return client.wshRpcCall("authenticatetojobmanager", data, opts); + } + // command "authenticatetoken" [call] AuthenticateTokenCommand(client: WshClient, data: CommandAuthenticateTokenData, opts?: RpcOpts): Promise { return client.wshRpcCall("authenticatetoken", data, opts); @@ -377,6 +392,76 @@ class RpcApiType { return client.wshRpcCall("getwaveairatelimit", null, opts); } + // command "jobcmdexited" [call] + JobCmdExitedCommand(client: WshClient, data: CommandJobCmdExitedData, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcmdexited", data, opts); + } + + // command "jobcontrollerattachjob" [call] + JobControllerAttachJobCommand(client: WshClient, data: CommandJobControllerAttachJobData, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerattachjob", data, opts); + } + + // command "jobcontrollerconnectedjobs" [call] + JobControllerConnectedJobsCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerconnectedjobs", null, opts); + } + + // command "jobcontrollerdeletejob" [call] + JobControllerDeleteJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerdeletejob", data, opts); + } + + // command "jobcontrollerdetachjob" [call] + JobControllerDetachJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerdetachjob", data, opts); + } + + // command "jobcontrollerdisconnectjob" [call] + 
JobControllerDisconnectJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerdisconnectjob", data, opts); + } + + // command "jobcontrollerexitjob" [call] + JobControllerExitJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerexitjob", data, opts); + } + + // command "jobcontrollerlist" [call] + JobControllerListCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerlist", null, opts); + } + + // command "jobcontrollerreconnectjob" [call] + JobControllerReconnectJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerreconnectjob", data, opts); + } + + // command "jobcontrollerreconnectjobsforconn" [call] + JobControllerReconnectJobsForConnCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerreconnectjobsforconn", data, opts); + } + + // command "jobcontrollerstartjob" [call] + JobControllerStartJobCommand(client: WshClient, data: CommandJobControllerStartJobData, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobcontrollerstartjob", data, opts); + } + + // command "jobinput" [call] + JobInputCommand(client: WshClient, data: CommandJobInputData, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobinput", data, opts); + } + + // command "jobprepareconnect" [call] + JobPrepareConnectCommand(client: WshClient, data: CommandJobPrepareConnectData, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobprepareconnect", data, opts); + } + + // command "jobstartstream" [call] + JobStartStreamCommand(client: WshClient, data: CommandJobStartStreamData, opts?: RpcOpts): Promise { + return client.wshRpcCall("jobstartstream", data, opts); + } + // command "listallappfiles" [call] ListAllAppFilesCommand(client: WshClient, data: CommandListAllAppFilesData, opts?: RpcOpts): Promise { return 
client.wshRpcCall("listallappfiles", data, opts); @@ -432,6 +517,11 @@ class RpcApiType { return client.wshRpcCall("recordtevent", data, opts); } + // command "remotedisconnectfromjobmanager" [call] + RemoteDisconnectFromJobManagerCommand(client: WshClient, data: CommandRemoteDisconnectFromJobManagerData, opts?: RpcOpts): Promise { + return client.wshRpcCall("remotedisconnectfromjobmanager", data, opts); + } + // command "remotefilecopy" [call] RemoteFileCopyCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotefilecopy", data, opts); @@ -482,6 +572,16 @@ class RpcApiType { return client.wshRpcCall("remotemkdir", data, opts); } + // command "remotereconnecttojobmanager" [call] + RemoteReconnectToJobManagerCommand(client: WshClient, data: CommandRemoteReconnectToJobManagerData, opts?: RpcOpts): Promise { + return client.wshRpcCall("remotereconnecttojobmanager", data, opts); + } + + // command "remotestartjob" [call] + RemoteStartJobCommand(client: WshClient, data: CommandRemoteStartJobData, opts?: RpcOpts): Promise { + return client.wshRpcCall("remotestartjob", data, opts); + } + // command "remotestreamcpudata" [responsestream] RemoteStreamCpuDataCommand(client: WshClient, opts?: RpcOpts): AsyncGenerator { return client.wshRpcStream("remotestreamcpudata", null, opts); @@ -497,6 +597,11 @@ class RpcApiType { return client.wshRpcStream("remotetarstream", data, opts); } + // command "remoteterminatejobmanager" [call] + RemoteTerminateJobManagerCommand(client: WshClient, data: CommandRemoteTerminateJobManagerData, opts?: RpcOpts): Promise { + return client.wshRpcCall("remoteterminatejobmanager", data, opts); + } + // command "remotewritefile" [call] RemoteWriteFileCommand(client: WshClient, data: FileData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotewritefile", data, opts); @@ -572,6 +677,11 @@ class RpcApiType { return client.wshRpcCall("startbuilder", data, opts); } + // command "startjob" 
[call] + StartJobCommand(client: WshClient, data: CommandStartJobData, opts?: RpcOpts): Promise { + return client.wshRpcCall("startjob", data, opts); + } + // command "stopbuilder" [call] StopBuilderCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { return client.wshRpcCall("stopbuilder", data, opts); @@ -607,6 +717,11 @@ class RpcApiType { return client.wshRpcCall("termgetscrollbacklines", data, opts); } + // command "termupdateattachedjob" [call] + TermUpdateAttachedJobCommand(client: WshClient, data: CommandTermUpdateAttachedJobData, opts?: RpcOpts): Promise { + return client.wshRpcCall("termupdateattachedjob", data, opts); + } + // command "test" [call] TestCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { return client.wshRpcCall("test", data, opts); diff --git a/frontend/app/view/term/term-wsh.tsx b/frontend/app/view/term/term-wsh.tsx index 782a174913..16e31ae334 100644 --- a/frontend/app/view/term/term-wsh.tsx +++ b/frontend/app/view/term/term-wsh.tsx @@ -104,6 +104,11 @@ export class TermWshClient extends WshClient { } } + async handle_termupdateattachedjob(rh: RpcResponseHelper, data: CommandTermUpdateAttachedJobData): Promise { + console.log("term-update-attached-job", this.blockId, data); + // TODO: implement frontend logic to handle job attachment updates + } + async handle_termgetscrollbacklines( rh: RpcResponseHelper, data: CommandTermGetScrollbackLinesData diff --git a/frontend/app/view/term/term.tsx b/frontend/app/view/term/term.tsx index 10fd0fb112..d1ca981c97 100644 --- a/frontend/app/view/term/term.tsx +++ b/frontend/app/view/term/term.tsx @@ -298,6 +298,7 @@ const TerminalView = ({ blockId, model }: ViewComponentProps) => useWebGl: !termSettings?.["term:disablewebgl"], sendDataHandler: model.sendDataToController.bind(model), nodeModel: model.nodeModel, + jobId: blockData?.jobid, } ); (window as any).term = termWrap; diff --git a/frontend/app/view/term/termwrap.ts b/frontend/app/view/term/termwrap.ts index 
6393b2165a..60743db584 100644 --- a/frontend/app/view/term/termwrap.ts +++ b/frontend/app/view/term/termwrap.ts @@ -3,7 +3,6 @@ import type { BlockNodeModel } from "@/app/block/blocktypes"; import { getFileSubject } from "@/app/store/wps"; -import { sendWSCommand } from "@/app/store/ws"; import { RpcApi } from "@/app/store/wshclientapi"; import { TabRpcClient } from "@/app/store/wshrpcutil"; import { WOS, fetchWaveFile, getApi, getSettingsKeyAtom, globalStore, openLink, recordTEvent } from "@/store/global"; @@ -50,6 +49,7 @@ type TermWrapOptions = { useWebGl?: boolean; sendDataHandler?: (data: string) => void; nodeModel?: BlockNodeModel; + jobId?: string; }; // for xterm OSC handlers, we return true always because we "own" the OSC number. @@ -375,6 +375,7 @@ function handleOsc16162Command(data: string, blockId: string, loaded: boolean, t export class TermWrap { tabId: string; blockId: string; + jobId: string; ptyOffset: number; dataBytesProcessed: number; terminal: Terminal; @@ -422,6 +423,7 @@ export class TermWrap { this.loaded = false; this.tabId = tabId; this.blockId = blockId; + this.jobId = waveOptions.jobId; this.sendDataHandler = waveOptions.sendDataHandler; this.nodeModel = waveOptions.nodeModel; this.ptyOffset = 0; @@ -495,6 +497,10 @@ export class TermWrap { }); } + getZoneId(): string { + return this.jobId ?? 
this.blockId; + } + resetCompositionState() { this.isComposing = false; this.composingData = ""; @@ -566,7 +572,7 @@ export class TermWrap { }); } - this.mainFileSubject = getFileSubject(this.blockId, TermFileName); + this.mainFileSubject = getFileSubject(this.getZoneId(), TermFileName); this.mainFileSubject.subscribe(this.handleNewFileSubjectData.bind(this)); try { @@ -699,8 +705,9 @@ export class TermWrap { } async loadInitialTerminalData(): Promise { - let startTs = Date.now(); - const { data: cacheData, fileInfo: cacheFile } = await fetchWaveFile(this.blockId, TermCacheFileName); + const startTs = Date.now(); + const zoneId = this.getZoneId(); + const { data: cacheData, fileInfo: cacheFile } = await fetchWaveFile(zoneId, TermCacheFileName); let ptyOffset = 0; if (cacheFile != null) { ptyOffset = cacheFile.meta["ptyoffset"] ?? 0; @@ -722,7 +729,7 @@ export class TermWrap { } } } - const { data: mainData, fileInfo: mainFile } = await fetchWaveFile(this.blockId, TermFileName, ptyOffset); + const { data: mainData, fileInfo: mainFile } = await fetchWaveFile(zoneId, TermFileName, ptyOffset); console.log( `terminal loaded cachefile:${cacheData?.byteLength ?? 0} main:${mainData?.byteLength ?? 
0} bytes, ${Date.now() - startTs}ms` ); @@ -751,12 +758,7 @@ export class TermWrap { this.fitAddon.fit(); if (oldRows !== this.terminal.rows || oldCols !== this.terminal.cols) { const termSize: TermSize = { rows: this.terminal.rows, cols: this.terminal.cols }; - const wsCommand: SetBlockTermSizeWSCommand = { - wscommand: "setblocktermsize", - blockid: this.blockId, - termsize: termSize, - }; - sendWSCommand(wsCommand); + RpcApi.ControllerInputCommand(TabRpcClient, { blockid: this.blockId, termsize: termSize }); } dlog("resize", `${this.terminal.rows}x${this.terminal.cols}`, `${oldRows}x${oldCols}`, this.hasResized); if (!this.hasResized) { diff --git a/frontend/app/view/vdom/vdom-model.tsx b/frontend/app/view/vdom/vdom-model.tsx index fbe556daba..40877894f2 100644 --- a/frontend/app/view/vdom/vdom-model.tsx +++ b/frontend/app/view/vdom/vdom-model.tsx @@ -162,7 +162,7 @@ export class VDomModel { this.queueUpdate(true); } this.routeGoneUnsub = waveEventSubscribe({ - eventType: "route:gone", + eventType: "route:down", scope: curBackendRoute, handler: (event: WaveEvent) => { this.disposed = true; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 4658bc1af2..a4ec175c1f 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -112,6 +112,7 @@ declare global { runtimeopts?: RuntimeOpts; stickers?: StickerType[]; subblockids?: string[]; + jobid?: string; }; // blockcontroller.BlockControllerRuntimeStatus @@ -139,13 +140,6 @@ declare global { files: FileInfo[]; }; - // webcmd.BlockInputWSCommand - type BlockInputWSCommand = { - wscommand: "blockinput"; - blockid: string; - inputdata64: string; - }; - // wshrpc.BlocksListEntry type BlocksListEntry = { windowid: string; @@ -179,6 +173,7 @@ declare global { tosagreed?: number; hasoldhistory?: boolean; tempoid?: string; + installid?: string; }; // workspaceservice.CloseTabRtnType @@ -194,6 +189,12 @@ declare global { data: {[key: string]: any}; }; + // 
wshrpc.CommandAuthenticateJobManagerData + type CommandAuthenticateJobManagerData = { + jobid: string; + jobauthtoken: string; + }; + // wshrpc.CommandAuthenticateRtnData type CommandAuthenticateRtnData = { env?: {[key: string]: string}; @@ -201,6 +202,11 @@ declare global { rpccontext?: RpcContext; }; + // wshrpc.CommandAuthenticateToJobData + type CommandAuthenticateToJobData = { + jobaccesstoken: string; + }; + // wshrpc.CommandAuthenticateTokenData type CommandAuthenticateTokenData = { token: string; @@ -343,6 +349,59 @@ declare global { chatid: string; }; + // wshrpc.CommandJobCmdExitedData + type CommandJobCmdExitedData = { + jobid: string; + exitcode?: number; + exitsignal?: string; + exiterr?: string; + exitts?: number; + }; + + // wshrpc.CommandJobConnectRtnData + type CommandJobConnectRtnData = { + seq: number; + streamdone?: boolean; + streamerror?: string; + hasexited?: boolean; + exitcode?: number; + exitsignal?: string; + exiterr?: string; + }; + + // wshrpc.CommandJobControllerAttachJobData + type CommandJobControllerAttachJobData = { + jobid: string; + blockid: string; + }; + + // wshrpc.CommandJobControllerStartJobData + type CommandJobControllerStartJobData = { + connname: string; + cmd: string; + args: string[]; + env: {[key: string]: string}; + termsize?: TermSize; + }; + + // wshrpc.CommandJobInputData + type CommandJobInputData = { + jobid: string; + inputdata64?: string; + signame?: string; + termsize?: TermSize; + }; + + // wshrpc.CommandJobPrepareConnectData + type CommandJobPrepareConnectData = { + streammeta: StreamMeta; + seq: number; + }; + + // wshrpc.CommandJobStartStreamData + type CommandJobStartStreamData = { + }; + // wshrpc.CommandListAllAppFilesData type CommandListAllAppFilesData = { appid: string; @@ -397,6 +456,11 @@ declare global { modts?: number; }; + // wshrpc.CommandRemoteDisconnectFromJobManagerData + type CommandRemoteDisconnectFromJobManagerData = { + jobid: string; + }; + // wshrpc.CommandRemoteListEntriesData type 
CommandRemoteListEntriesData = { path: string; @@ -408,6 +472,36 @@ declare global { fileinfo?: FileInfo[]; }; + // wshrpc.CommandRemoteReconnectToJobManagerData + type CommandRemoteReconnectToJobManagerData = { + jobid: string; + jobauthtoken: string; + mainserverjwttoken: string; + jobmanagerpid: number; + jobmanagerstartts: number; + }; + + // wshrpc.CommandRemoteReconnectToJobManagerRtnData + type CommandRemoteReconnectToJobManagerRtnData = { + success: boolean; + jobmanagergone: boolean; + error?: string; + }; + + // wshrpc.CommandRemoteStartJobData + type CommandRemoteStartJobData = { + cmd: string; + args: string[]; + env: {[key: string]: string}; + termsize: TermSize; + streammeta?: StreamMeta; + jobauthtoken: string; + jobid: string; + mainserverjwttoken: string; + clientid: string; + publickeybase64: string; + }; + // wshrpc.CommandRemoteStreamFileData type CommandRemoteStreamFileData = { path: string; @@ -420,6 +514,13 @@ declare global { opts?: FileCopyOpts; }; + // wshrpc.CommandRemoteTerminateJobManagerData + type CommandRemoteTerminateJobManagerData = { + jobid: string; + jobmanagerpid: number; + jobmanagerstartts: number; + }; + // wshrpc.CommandRenameAppFileData type CommandRenameAppFileData = { appid: string; @@ -461,9 +562,26 @@ declare global { builderid: string; }; + // wshrpc.CommandStartJobData + type CommandStartJobData = { + cmd: string; + args: string[]; + env: {[key: string]: string}; + termsize: TermSize; + streammeta?: StreamMeta; + }; + + // wshrpc.CommandStartJobRtnData + type CommandStartJobRtnData = { + cmdpid: number; + cmdstartts: number; + jobmanagerpid: number; + jobmanagerstartts: number; + }; + // wshrpc.CommandStreamAckData type CommandStreamAckData = { - id: number; + id: string; seq: number; rwnd: number; fin?: boolean; @@ -474,7 +592,7 @@ declare global { // wshrpc.CommandStreamData type CommandStreamData = { - id: number; + id: string; seq: number; data64?: string; eof?: boolean; @@ -496,6 +614,12 @@ declare global { 
lastupdated: number; }; + // wshrpc.CommandTermUpdateAttachedJobData + type CommandTermUpdateAttachedJobData = { + blockid: string; + jobid?: string; + }; + // wshrpc.CommandVarData type CommandVarData = { key: string; @@ -793,6 +917,32 @@ declare global { configerrors: ConfigError[]; }; + // waveobj.Job + type Job = WaveObj & { + connection: string; + jobkind: string; + cmd: string; + cmdargs?: string[]; + cmdenv?: {[key: string]: string}; + jobauthtoken: string; + attachedblockid?: string; + terminateonreconnect?: boolean; + jobmanagerstatus: string; + jobmanagerdonereason?: string; + jobmanagerstartuperror?: string; + jobmanagerpid?: number; + jobmanagerstartts?: number; + cmdpid?: number; + cmdstartts?: number; + cmdtermsize: TermSize; + cmdexitts?: number; + cmdexitcode?: number; + cmdexitsignal?: string; + cmdexiterror?: string; + streamdone?: boolean; + streamerror?: string; + }; + // waveobj.LayoutActionData type LayoutActionData = { actiontype: string; @@ -1062,13 +1212,6 @@ declare global { optional: boolean; }; - // webcmd.SetBlockTermSizeWSCommand - type SetBlockTermSizeWSCommand = { - wscommand: "setblocktermsize"; - blockid: string; - termsize: TermSize; - }; - // wconfig.SettingsType type SettingsType = { "app:*"?: boolean; @@ -1186,6 +1329,14 @@ declare global { display: StickerDisplayOptsType; }; + // wshrpc.StreamMeta + type StreamMeta = { + id: string; + rwnd: number; + readerrouteid: string; + writerrouteid: string; + }; + // wps.SubscriptionRequest type SubscriptionRequest = { event: string; @@ -1640,7 +1791,7 @@ declare global { type WSCommandType = { wscommand: string; - } & ( SetBlockTermSizeWSCommand | BlockInputWSCommand | WSRpcCommand ); + } & ( WSRpcCommand ); // eventbus.WSEventType type WSEventType = { diff --git a/pkg/jobcontroller/jobcontroller.go b/pkg/jobcontroller/jobcontroller.go new file mode 100644 index 0000000000..66f928d920 --- /dev/null +++ b/pkg/jobcontroller/jobcontroller.go @@ -0,0 +1,853 @@ +// Copyright 2025, Command 
Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package jobcontroller + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "log" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/filestore" + "github.com/wavetermdev/waveterm/pkg/panichandler" + "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" + "github.com/wavetermdev/waveterm/pkg/streamclient" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wavejwt" + "github.com/wavetermdev/waveterm/pkg/waveobj" + "github.com/wavetermdev/waveterm/pkg/wps" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" + "github.com/wavetermdev/waveterm/pkg/wshutil" + "github.com/wavetermdev/waveterm/pkg/wstore" +) + +const ( + JobStatus_Init = "init" + JobStatus_Running = "running" + JobStatus_Done = "done" +) + +const ( + JobDoneReason_StartupError = "startuperror" + JobDoneReason_Gone = "gone" + JobDoneReason_Terminated = "terminated" +) + +const ( + JobConnStatus_Disconnected = "disconnected" + JobConnStatus_Connecting = "connecting" + JobConnStatus_Connected = "connected" +) + +const DefaultStreamRwnd = 64 * 1024 +const MetaKey_TotalGap = "totalgap" +const JobOutputFileName = "term" + +func isJobManagerRunning(job *waveobj.Job) bool { + return job.JobManagerStatus == JobStatus_Running +} + +var ( + jobConnStates = make(map[string]string) + jobConnStatesLock sync.Mutex +) + +func getMetaInt64(meta wshrpc.FileMeta, key string) int64 { + val, ok := meta[key] + if !ok { + return 0 + } + if intVal, ok := val.(int64); ok { + return intVal + } + if floatVal, ok := val.(float64); ok { + return int64(floatVal) + } + return 0 +} + +func InitJobController() { + rpcClient := wshclient.GetBareRpcClient() + rpcClient.EventListener.On(wps.Event_RouteUp, handleRouteUpEvent) + rpcClient.EventListener.On(wps.Event_RouteDown, handleRouteDownEvent) + wshclient.EventSubCommand(rpcClient, 
wps.SubscriptionRequest{ + Event: wps.Event_RouteUp, + AllScopes: true, + }, nil) + wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{ + Event: wps.Event_RouteDown, + AllScopes: true, + }, nil) +} + +func handleRouteUpEvent(event *wps.WaveEvent) { + handleRouteEvent(event, JobConnStatus_Connected) +} + +func handleRouteDownEvent(event *wps.WaveEvent) { + handleRouteEvent(event, JobConnStatus_Disconnected) +} + +func handleRouteEvent(event *wps.WaveEvent, newStatus string) { + for _, scope := range event.Scopes { + if strings.HasPrefix(scope, "job:") { + jobId := strings.TrimPrefix(scope, "job:") + SetJobConnStatus(jobId, newStatus) + log.Printf("[job:%s] connection status changed to %s", jobId, newStatus) + } + } +} + +func GetJobConnStatus(jobId string) string { + jobConnStatesLock.Lock() + defer jobConnStatesLock.Unlock() + status, exists := jobConnStates[jobId] + if !exists { + return JobConnStatus_Disconnected + } + return status +} + +func SetJobConnStatus(jobId string, status string) { + jobConnStatesLock.Lock() + defer jobConnStatesLock.Unlock() + if status == JobConnStatus_Disconnected { + delete(jobConnStates, jobId) + } else { + jobConnStates[jobId] = status + } +} + +func GetConnectedJobIds() []string { + jobConnStatesLock.Lock() + defer jobConnStatesLock.Unlock() + var connectedJobIds []string + for jobId, status := range jobConnStates { + if status == JobConnStatus_Connected { + connectedJobIds = append(connectedJobIds, jobId) + } + } + return connectedJobIds +} + +func ensureJobConnected(ctx context.Context, jobId string) (*waveobj.Job, error) { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId) + if err != nil { + return nil, fmt.Errorf("failed to get job: %w", err) + } + + isConnected, err := conncontroller.IsConnected(job.Connection) + if err != nil { + return nil, fmt.Errorf("error checking connection status: %w", err) + } + if !isConnected { + return nil, fmt.Errorf("connection %q is not connected", job.Connection) + } + + 
jobConnStatus := GetJobConnStatus(jobId) + if jobConnStatus != JobConnStatus_Connected { + return nil, fmt.Errorf("job is not connected (status: %s)", jobConnStatus) + } + + return job, nil +} + +type StartJobParams struct { + ConnName string + Cmd string + Args []string + Env map[string]string + TermSize *waveobj.TermSize +} + +func StartJob(ctx context.Context, params StartJobParams) (string, error) { + if params.ConnName == "" { + return "", fmt.Errorf("connection name is required") + } + if params.Cmd == "" { + return "", fmt.Errorf("command is required") + } + if params.TermSize == nil { + params.TermSize = &waveobj.TermSize{Rows: 24, Cols: 80} + } + + isConnected, err := conncontroller.IsConnected(params.ConnName) + if err != nil { + return "", fmt.Errorf("error checking connection status: %w", err) + } + if !isConnected { + return "", fmt.Errorf("connection %q is not connected", params.ConnName) + } + + jobId := uuid.New().String() + jobAuthToken, err := utilfn.RandomHexString(32) + if err != nil { + return "", fmt.Errorf("failed to generate job auth token: %w", err) + } + + jobAccessClaims := &wavejwt.WaveJwtClaims{ + MainServer: true, + JobId: jobId, + } + jobAccessToken, err := wavejwt.Sign(jobAccessClaims) + if err != nil { + return "", fmt.Errorf("failed to generate job access token: %w", err) + } + + job := &waveobj.Job{ + OID: jobId, + Connection: params.ConnName, + Cmd: params.Cmd, + CmdArgs: params.Args, + CmdEnv: params.Env, + CmdTermSize: *params.TermSize, + JobAuthToken: jobAuthToken, + JobManagerStatus: JobStatus_Init, + Meta: make(waveobj.MetaMapType), + } + + err = wstore.DBInsert(ctx, job) + if err != nil { + return "", fmt.Errorf("failed to create job in database: %w", err) + } + + bareRpc := wshclient.GetBareRpcClient() + broker := bareRpc.StreamBroker + readerRouteId := wshclient.GetBareRpcClientRouteId() + writerRouteId := wshutil.MakeJobRouteId(jobId) + reader, streamMeta := broker.CreateStreamReader(readerRouteId, writerRouteId, 
DefaultStreamRwnd) + + fileOpts := wshrpc.FileOpts{ + MaxSize: 10 * 1024 * 1024, + Circular: true, + } + err = filestore.WFS.MakeFile(ctx, jobId, JobOutputFileName, wshrpc.FileMeta{}, fileOpts) + if err != nil { + return "", fmt.Errorf("failed to create WaveFS file: %w", err) + } + + clientId, err := wstore.DBGetSingleton[*waveobj.Client](ctx) + if err != nil || clientId == nil { + return "", fmt.Errorf("failed to get client: %w", err) + } + + publicKey := wavejwt.GetPublicKey() + publicKeyBase64 := base64.StdEncoding.EncodeToString(publicKey) + + startJobData := wshrpc.CommandRemoteStartJobData{ + Cmd: params.Cmd, + Args: params.Args, + Env: params.Env, + TermSize: *params.TermSize, + StreamMeta: streamMeta, + JobAuthToken: jobAuthToken, + JobId: jobId, + MainServerJwtToken: jobAccessToken, + ClientId: clientId.OID, + PublicKeyBase64: publicKeyBase64, + } + + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeConnectionRouteId(params.ConnName), + Timeout: 30000, + } + + log.Printf("[job:%s] sending RemoteStartJobCommand to connection %s", jobId, params.ConnName) + rtnData, err := wshclient.RemoteStartJobCommand(bareRpc, startJobData, rpcOpts) + if err != nil { + log.Printf("[job:%s] RemoteStartJobCommand failed: %v", jobId, err) + errMsg := fmt.Sprintf("failed to start job: %v", err) + wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.JobManagerStatus = JobStatus_Done + job.JobManagerDoneReason = JobDoneReason_StartupError + job.JobManagerStartupError = errMsg + }) + return "", fmt.Errorf("failed to start remote job: %w", err) + } + + log.Printf("[job:%s] RemoteStartJobCommand succeeded, cmdpid=%d cmdstartts=%d jobmanagerpid=%d jobmanagerstartts=%d", jobId, rtnData.CmdPid, rtnData.CmdStartTs, rtnData.JobManagerPid, rtnData.JobManagerStartTs) + err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.CmdPid = rtnData.CmdPid + job.CmdStartTs = rtnData.CmdStartTs + job.JobManagerPid = rtnData.JobManagerPid + job.JobManagerStartTs = 
rtnData.JobManagerStartTs + job.JobManagerStatus = JobStatus_Running + }) + if err != nil { + log.Printf("[job:%s] warning: failed to update job status to running: %v", jobId, err) + } else { + log.Printf("[job:%s] job status updated to running", jobId) + } + + go func() { + defer func() { + panichandler.PanicHandler("jobcontroller:runOutputLoop", recover()) + }() + runOutputLoop(context.Background(), jobId, reader) + }() + + return jobId, nil +} + +func handleAppendJobFile(ctx context.Context, jobId string, fileName string, data []byte) error { + err := filestore.WFS.AppendData(ctx, jobId, fileName, data) + if err != nil { + return fmt.Errorf("error appending to job file: %w", err) + } + wps.Broker.Publish(wps.WaveEvent{ + Event: wps.Event_BlockFile, + Scopes: []string{ + waveobj.MakeORef(waveobj.OType_Job, jobId).String(), + }, + Data: &wps.WSFileEventData{ + ZoneId: jobId, + FileName: fileName, + FileOp: wps.FileOp_Append, + Data64: base64.StdEncoding.EncodeToString(data), + }, + }) + return nil +} + +func runOutputLoop(ctx context.Context, jobId string, reader *streamclient.Reader) { + defer func() { + log.Printf("[job:%s] output loop finished", jobId) + }() + + log.Printf("[job:%s] output loop started", jobId) + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + if n > 0 { + log.Printf("[job:%s] received %d bytes of data", jobId, n) + appendErr := handleAppendJobFile(ctx, jobId, JobOutputFileName, buf[:n]) + if appendErr != nil { + log.Printf("[job:%s] error appending data to WaveFS: %v", jobId, appendErr) + } else { + log.Printf("[job:%s] successfully appended %d bytes to WaveFS", jobId, n) + } + } + + if err == io.EOF { + log.Printf("[job:%s] stream ended (EOF)", jobId) + updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.StreamDone = true + }) + if updateErr != nil { + log.Printf("[job:%s] error updating job stream status: %v", jobId, updateErr) + } + tryTerminateJobManager(ctx, jobId) + break + } + + if err != nil { + 
log.Printf("[job:%s] stream error: %v", jobId, err) + streamErr := err.Error() + updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.StreamDone = true + job.StreamError = streamErr + }) + if updateErr != nil { + log.Printf("[job:%s] error updating job stream error: %v", jobId, updateErr) + } + tryTerminateJobManager(ctx, jobId) + break + } + } +} + +func HandleCmdJobExited(ctx context.Context, jobId string, data wshrpc.CommandJobCmdExitedData) error { + err := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.CmdExitError = data.ExitErr + job.CmdExitCode = data.ExitCode + job.CmdExitSignal = data.ExitSignal + job.CmdExitTs = data.ExitTs + }) + if err != nil { + return fmt.Errorf("failed to update job exit status: %w", err) + } + tryTerminateJobManager(ctx, jobId) + return nil +} + +func tryTerminateJobManager(ctx context.Context, jobId string) { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId) + if err != nil { + log.Printf("[job:%s] error getting job for termination check: %v", jobId, err) + return + } + + if job.JobManagerStatus != JobStatus_Running { + return + } + + cmdExited := job.CmdExitTs != 0 + + if !cmdExited || !job.StreamDone { + log.Printf("[job:%s] not ready for termination: exited=%v streamDone=%v", jobId, cmdExited, job.StreamDone) + return + } + + log.Printf("[job:%s] both job cmd exited and stream finished, terminating job manager", jobId) + + err = TerminateJobManager(ctx, jobId) + if err != nil { + log.Printf("[job:%s] error terminating job manager: %v", jobId, err) + } +} + +func TerminateJobManager(ctx context.Context, jobId string) error { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId) + if err != nil { + return fmt.Errorf("failed to get job: %w", err) + } + + return remoteTerminateJobManager(ctx, job) +} + +func DisconnectJob(ctx context.Context, jobId string) error { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId) + if err != nil { + return fmt.Errorf("failed to get job: %w", err) 
+ } + + bareRpc := wshclient.GetBareRpcClient() + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeConnectionRouteId(job.Connection), + Timeout: 5000, + } + + disconnectData := wshrpc.CommandRemoteDisconnectFromJobManagerData{ + JobId: jobId, + } + + err = wshclient.RemoteDisconnectFromJobManagerCommand(bareRpc, disconnectData, rpcOpts) + if err != nil { + return fmt.Errorf("failed to send disconnect command: %w", err) + } + + log.Printf("[job:%s] job disconnect command sent successfully", jobId) + return nil +} + +func remoteTerminateJobManager(ctx context.Context, job *waveobj.Job) error { + log.Printf("[job:%s] terminating job manager", job.OID) + + bareRpc := wshclient.GetBareRpcClient() + terminateData := wshrpc.CommandRemoteTerminateJobManagerData{ + JobId: job.OID, + JobManagerPid: job.JobManagerPid, + JobManagerStartTs: job.JobManagerStartTs, + } + + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeConnectionRouteId(job.Connection), + Timeout: 5000, + } + + err := wshclient.RemoteTerminateJobManagerCommand(bareRpc, terminateData, rpcOpts) + if err != nil { + log.Printf("[job:%s] error terminating job manager: %v", job.OID, err) + return fmt.Errorf("failed to terminate job manager: %w", err) + } + + updateErr := wstore.DBUpdateFn(ctx, job.OID, func(job *waveobj.Job) { + job.JobManagerStatus = JobStatus_Done + job.JobManagerDoneReason = JobDoneReason_Terminated + job.TerminateOnReconnect = false + if !job.StreamDone { + job.StreamDone = true + job.StreamError = "job manager terminated" + } + }) + if updateErr != nil { + log.Printf("[job:%s] error updating job status after termination: %v", job.OID, updateErr) + } + + log.Printf("[job:%s] job manager terminated successfully", job.OID) + return nil +} + +func ReconnectJob(ctx context.Context, jobId string) error { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId) + if err != nil { + return fmt.Errorf("failed to get job: %w", err) + } + isConnected, err := conncontroller.IsConnected(job.Connection) + if 
err != nil { + return fmt.Errorf("error checking connection status: %w", err) + } + if !isConnected { + return fmt.Errorf("connection %q is not connected", job.Connection) + } + + if job.TerminateOnReconnect { + return remoteTerminateJobManager(ctx, job) + } + + bareRpc := wshclient.GetBareRpcClient() + + jobAccessClaims := &wavejwt.WaveJwtClaims{ + MainServer: true, + JobId: jobId, + } + jobAccessToken, err := wavejwt.Sign(jobAccessClaims) + if err != nil { + return fmt.Errorf("failed to generate job access token: %w", err) + } + + reconnectData := wshrpc.CommandRemoteReconnectToJobManagerData{ + JobId: jobId, + JobAuthToken: job.JobAuthToken, + MainServerJwtToken: jobAccessToken, + JobManagerPid: job.JobManagerPid, + JobManagerStartTs: job.JobManagerStartTs, + } + + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeConnectionRouteId(job.Connection), + Timeout: 5000, + } + + log.Printf("[job:%s] sending RemoteReconnectToJobManagerCommand to connection %s", jobId, job.Connection) + rtnData, err := wshclient.RemoteReconnectToJobManagerCommand(bareRpc, reconnectData, rpcOpts) + if err != nil { + log.Printf("[job:%s] RemoteReconnectToJobManagerCommand failed: %v", jobId, err) + return fmt.Errorf("failed to reconnect to job manager: %w", err) + } + + if !rtnData.Success { + log.Printf("[job:%s] RemoteReconnectToJobManagerCommand returned error: %s", jobId, rtnData.Error) + if rtnData.JobManagerGone { + updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.JobManagerStatus = JobStatus_Done + job.JobManagerDoneReason = JobDoneReason_Gone + }) + if updateErr != nil { + log.Printf("[job:%s] error updating job manager running status: %v", jobId, updateErr) + } + return fmt.Errorf("job manager has exited: %s", rtnData.Error) + } + return fmt.Errorf("failed to reconnect to job manager: %s", rtnData.Error) + } + + log.Printf("[job:%s] RemoteReconnectToJobManagerCommand succeeded, waiting for route", jobId) + + routeId := wshutil.MakeJobRouteId(jobId) + 
waitCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second) + defer cancelFn() + err = wshutil.DefaultRouter.WaitForRegister(waitCtx, routeId) + if err != nil { + return fmt.Errorf("route did not establish after successful reconnection: %w", err) + } + + log.Printf("[job:%s] route established, restarting streaming", jobId) + return RestartStreaming(ctx, jobId, true) +} + +func ReconnectJobsForConn(ctx context.Context, connName string) error { + isConnected, err := conncontroller.IsConnected(connName) + if err != nil { + return fmt.Errorf("error checking connection status: %w", err) + } + if !isConnected { + return fmt.Errorf("connection %q is not connected", connName) + } + + allJobs, err := wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job) + if err != nil { + return fmt.Errorf("failed to get jobs: %w", err) + } + + var jobsToReconnect []*waveobj.Job + for _, job := range allJobs { + if job.Connection == connName && isJobManagerRunning(job) { + jobsToReconnect = append(jobsToReconnect, job) + } + } + + log.Printf("[conn:%s] found %d jobs to reconnect", connName, len(jobsToReconnect)) + + for _, job := range jobsToReconnect { + err = ReconnectJob(ctx, job.OID) + if err != nil { + log.Printf("[job:%s] error reconnecting: %v", job.OID, err) + } + } + + return nil +} + +func RestartStreaming(ctx context.Context, jobId string, knownConnected bool) error { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId) + if err != nil { + return fmt.Errorf("failed to get job: %w", err) + } + + if !knownConnected { + isConnected, err := conncontroller.IsConnected(job.Connection) + if err != nil { + return fmt.Errorf("error checking connection status: %w", err) + } + if !isConnected { + return fmt.Errorf("connection %q is not connected", job.Connection) + } + + jobConnStatus := GetJobConnStatus(jobId) + if jobConnStatus != JobConnStatus_Connected { + return fmt.Errorf("job manager is not connected (status: %s)", jobConnStatus) + } + } + + var currentSeq int64 = 0 + 
var totalGap int64 = 0 + waveFile, err := filestore.WFS.Stat(ctx, jobId, JobOutputFileName) + if err == nil { + currentSeq = waveFile.Size + totalGap = getMetaInt64(waveFile.Meta, MetaKey_TotalGap) + currentSeq += totalGap + } + + bareRpc := wshclient.GetBareRpcClient() + broker := bareRpc.StreamBroker + readerRouteId := wshclient.GetBareRpcClientRouteId() + writerRouteId := wshutil.MakeJobRouteId(jobId) + + reader, streamMeta := broker.CreateStreamReaderWithSeq(readerRouteId, writerRouteId, DefaultStreamRwnd, currentSeq) + + prepareData := wshrpc.CommandJobPrepareConnectData{ + StreamMeta: *streamMeta, + Seq: currentSeq, + } + + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeJobRouteId(jobId), + Timeout: 5000, + } + + log.Printf("[job:%s] sending JobPrepareConnectCommand with seq=%d (fileSize=%d, totalGap=%d)", jobId, currentSeq, waveFile.Size, totalGap) + rtnData, err := wshclient.JobPrepareConnectCommand(bareRpc, prepareData, rpcOpts) + if err != nil { + reader.Close() + return fmt.Errorf("failed to prepare connect: %w", err) + } + + if rtnData.HasExited { + exitCodeStr := "nil" + if rtnData.ExitCode != nil { + exitCodeStr = fmt.Sprintf("%d", *rtnData.ExitCode) + } + log.Printf("[job:%s] job has already exited: code=%s signal=%q err=%q", jobId, exitCodeStr, rtnData.ExitSignal, rtnData.ExitErr) + updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.JobManagerStatus = JobStatus_Done + job.CmdExitCode = rtnData.ExitCode + job.CmdExitSignal = rtnData.ExitSignal + job.CmdExitError = rtnData.ExitErr + }) + if updateErr != nil { + log.Printf("[job:%s] error updating job exit status: %v", jobId, updateErr) + } + } + + if rtnData.StreamDone { + log.Printf("[job:%s] stream is already done: error=%q", jobId, rtnData.StreamError) + updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + if !job.StreamDone { + job.StreamDone = true + if rtnData.StreamError != "" { + job.StreamError = rtnData.StreamError + } + } + }) + if updateErr != nil 
{ + log.Printf("[job:%s] error updating job stream status: %v", jobId, updateErr) + } + } + + if rtnData.StreamDone && rtnData.HasExited { + reader.Close() + log.Printf("[job:%s] both stream done and job exited, calling tryExitJobManager", jobId) + tryTerminateJobManager(ctx, jobId) + return nil + } + + if rtnData.StreamDone { + reader.Close() + log.Printf("[job:%s] stream already done, no need to restart streaming", jobId) + return nil + } + + if rtnData.Seq > currentSeq { + gap := rtnData.Seq - currentSeq + totalGap += gap + log.Printf("[job:%s] detected gap: our seq=%d, server seq=%d, gap=%d, new totalGap=%d", jobId, currentSeq, rtnData.Seq, gap, totalGap) + + metaErr := filestore.WFS.WriteMeta(ctx, jobId, JobOutputFileName, wshrpc.FileMeta{ + MetaKey_TotalGap: totalGap, + }, true) + if metaErr != nil { + log.Printf("[job:%s] error updating totalgap metadata: %v", jobId, metaErr) + } + + reader.UpdateNextSeq(rtnData.Seq) + } + + log.Printf("[job:%s] sending JobStartStreamCommand", jobId) + startStreamData := wshrpc.CommandJobStartStreamData{} + err = wshclient.JobStartStreamCommand(bareRpc, startStreamData, rpcOpts) + if err != nil { + reader.Close() + return fmt.Errorf("failed to start stream: %w", err) + } + + go func() { + defer func() { + panichandler.PanicHandler("jobcontroller:RestartStreaming:runOutputLoop", recover()) + }() + runOutputLoop(context.Background(), jobId, reader) + }() + + log.Printf("[job:%s] streaming restarted successfully", jobId) + return nil +} + +func DeleteJob(ctx context.Context, jobId string) error { + SetJobConnStatus(jobId, JobConnStatus_Disconnected) + err := filestore.WFS.DeleteZone(ctx, jobId) + if err != nil { + log.Printf("[job:%s] warning: error deleting WaveFS zone: %v", jobId, err) + } + return wstore.DBDelete(ctx, waveobj.OType_Job, jobId) +} + +func AttachJobToBlock(ctx context.Context, jobId string, blockId string) error { + err := wstore.WithTx(ctx, func(tx *wstore.TxWrap) error { + err := 
wstore.DBUpdateFn(tx.Context(), blockId, func(block *waveobj.Block) { + block.JobId = jobId + }) + if err != nil { + return fmt.Errorf("failed to update block: %w", err) + } + + err = wstore.DBUpdateFnErr(tx.Context(), jobId, func(job *waveobj.Job) error { + if job.AttachedBlockId != "" { + return fmt.Errorf("job %s already attached to block %s", jobId, job.AttachedBlockId) + } + job.AttachedBlockId = blockId + return nil + }) + if err != nil { + return fmt.Errorf("failed to update job: %w", err) + } + + log.Printf("[job:%s] attached to block:%s", jobId, blockId) + return nil + }) + if err != nil { + return err + } + + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeFeBlockRouteId(blockId), + NoResponse: true, + } + bareRpc := wshclient.GetBareRpcClient() + wshclient.TermUpdateAttachedJobCommand(bareRpc, wshrpc.CommandTermUpdateAttachedJobData{ + BlockId: blockId, + JobId: jobId, + }, rpcOpts) + + return nil +} + +func DetachJobFromBlock(ctx context.Context, jobId string, updateBlock bool) error { + var blockId string + err := wstore.WithTx(ctx, func(tx *wstore.TxWrap) error { + job, err := wstore.DBMustGet[*waveobj.Job](tx.Context(), jobId) + if err != nil { + return fmt.Errorf("failed to get job: %w", err) + } + + blockId = job.AttachedBlockId + if blockId == "" { + return nil + } + + if updateBlock { + block, err := wstore.DBGet[*waveobj.Block](tx.Context(), blockId) + if err == nil && block != nil { + err = wstore.DBUpdateFn(tx.Context(), blockId, func(block *waveobj.Block) { + block.JobId = "" + }) + if err != nil { + log.Printf("[job:%s] warning: failed to clear JobId from block:%s: %v", jobId, blockId, err) + } + } + } + + err = wstore.DBUpdateFn(tx.Context(), jobId, func(job *waveobj.Job) { + job.AttachedBlockId = "" + }) + if err != nil { + return fmt.Errorf("failed to update job: %w", err) + } + + log.Printf("[job:%s] detached from block:%s", jobId, blockId) + return nil + }) + if err != nil { + return err + } + + if blockId != "" { + rpcOpts := 
&wshrpc.RpcOpts{ + Route: wshutil.MakeFeBlockRouteId(blockId), + NoResponse: true, + } + bareRpc := wshclient.GetBareRpcClient() + wshclient.TermUpdateAttachedJobCommand(bareRpc, wshrpc.CommandTermUpdateAttachedJobData{ + BlockId: blockId, + JobId: "", + }, rpcOpts) + } + + return nil +} + +func SendInput(ctx context.Context, data wshrpc.CommandJobInputData) error { + jobId := data.JobId + _, err := ensureJobConnected(ctx, jobId) + if err != nil { + return err + } + + rpcOpts := &wshrpc.RpcOpts{ + Route: wshutil.MakeJobRouteId(jobId), + Timeout: 5000, + NoResponse: false, + } + + bareRpc := wshclient.GetBareRpcClient() + err = wshclient.JobInputCommand(bareRpc, data, rpcOpts) + if err != nil { + return fmt.Errorf("failed to send input to job: %w", err) + } + + if data.TermSize != nil { + err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) { + job.CmdTermSize = *data.TermSize + }) + if err != nil { + log.Printf("[job:%s] warning: failed to update termsize in DB: %v", jobId, err) + } + } + + return nil +} diff --git a/pkg/jobmanager/cirbuf.go b/pkg/jobmanager/cirbuf.go new file mode 100644 index 0000000000..fae4063b85 --- /dev/null +++ b/pkg/jobmanager/cirbuf.go @@ -0,0 +1,218 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package jobmanager + +import ( + "context" + "fmt" + "sync" +) + +type CirBuf struct { + lock sync.Mutex + waiterChan chan chan struct{} + buf []byte + readPos int + writePos int + count int + totalSize int64 + syncMode bool + windowSize int +} + +func MakeCirBuf(maxSize int, initSyncMode bool) *CirBuf { + cb := &CirBuf{ + buf: make([]byte, maxSize), + syncMode: initSyncMode, + waiterChan: make(chan chan struct{}, 1), + windowSize: maxSize, + } + return cb +} + +// SetEffectiveWindow changes the sync mode and effective window size for flow control. +// The windowSize is capped at the buffer size. +// When window shrinks: sync mode blocks new writes, async mode truncates old data to enforce limit. 
+// When window increases: blocked writers are woken up if space becomes available. +func (cb *CirBuf) SetEffectiveWindow(syncMode bool, windowSize int) { + cb.lock.Lock() + defer cb.lock.Unlock() + + maxSize := len(cb.buf) + if windowSize > maxSize { + windowSize = maxSize + } + + oldSyncMode := cb.syncMode + oldWindowSize := cb.windowSize + cb.windowSize = windowSize + cb.syncMode = syncMode + + // In async mode, enforce window size by truncating buffer if needed + if !syncMode && cb.count > windowSize { + excess := cb.count - windowSize + cb.readPos = (cb.readPos + excess) % maxSize + cb.count = windowSize + } + + // Only sync mode blocks writers, so only wake if we were in sync mode. + // Wake when window grows (more space available) or switching to async (no longer blocking). + if oldSyncMode && (windowSize > oldWindowSize || !syncMode) { + cb.tryWakeWriter() + } +} + +// Write will never block if syncMode is false +// If syncMode is true, write will block until enough data is consumed to allow the write to finish +// to cancel a write in progress use WriteCtx +func (cb *CirBuf) Write(data []byte) (int, error) { + return cb.WriteCtx(context.Background(), data) +} + +// WriteCtx writes data to the circular buffer with context support for cancellation. +// In sync mode, blocks when buffer is full until space is available or context is cancelled. +// Returns partial byte count and context error if cancelled mid-write. +// NOTE: Only one concurrent blocked write is allowed. Multiple blocked writes will panic. 
+func (cb *CirBuf) WriteCtx(ctx context.Context, data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } + + bytesWritten := 0 + for bytesWritten < len(data) { + if err := ctx.Err(); err != nil { + return bytesWritten, err + } + + n, spaceAvailable := cb.writeAvailable(data[bytesWritten:]) + bytesWritten += n + + if spaceAvailable != nil { + select { + case <-spaceAvailable: + continue + case <-ctx.Done(): + tryReadCh(cb.waiterChan) + return bytesWritten, ctx.Err() + } + } + } + + return bytesWritten, nil +} + +func (cb *CirBuf) writeAvailable(data []byte) (int, chan struct{}) { + cb.lock.Lock() + defer cb.lock.Unlock() + + size := len(cb.buf) + written := 0 + + for i := 0; i < len(data); i++ { + if cb.syncMode && cb.count >= cb.windowSize { + spaceAvailable := make(chan struct{}) + if !tryWriteCh(cb.waiterChan, spaceAvailable) { + panic("CirBuf: multiple concurrent blocked writes not allowed") + } + return written, spaceAvailable + } + + cb.buf[cb.writePos] = data[i] + cb.writePos = (cb.writePos + 1) % size + if cb.count < cb.windowSize { + cb.count++ + } else { + cb.readPos = (cb.readPos + 1) % size + } + cb.totalSize++ + written++ + } + + return written, nil +} + +func (cb *CirBuf) PeekData(data []byte) int { + return cb.PeekDataAt(0, data) +} + +func (cb *CirBuf) PeekDataAt(offset int, data []byte) int { + cb.lock.Lock() + defer cb.lock.Unlock() + + if cb.count == 0 || offset >= cb.count { + return 0 + } + + size := len(cb.buf) + pos := (cb.readPos + offset) % size + maxRead := cb.count - offset + read := 0 + + for i := 0; i < len(data) && i < maxRead; i++ { + data[i] = cb.buf[pos] + pos = (pos + 1) % size + read++ + } + + return read +} + +func (cb *CirBuf) Consume(numBytes int) error { + cb.lock.Lock() + defer cb.lock.Unlock() + + if numBytes > cb.count { + return fmt.Errorf("cannot consume %d bytes, only %d available", numBytes, cb.count) + } + + size := len(cb.buf) + cb.readPos = (cb.readPos + numBytes) % size + cb.count -= numBytes + + 
cb.tryWakeWriter() + + return nil +} + +func (cb *CirBuf) HeadPos() int64 { + cb.lock.Lock() + defer cb.lock.Unlock() + return cb.totalSize - int64(cb.count) +} + +func (cb *CirBuf) Size() int { + cb.lock.Lock() + defer cb.lock.Unlock() + return cb.count +} + +func (cb *CirBuf) TotalSize() int64 { + cb.lock.Lock() + defer cb.lock.Unlock() + return cb.totalSize +} + +func tryWriteCh[T any](ch chan<- T, val T) bool { + select { + case ch <- val: + return true + default: + return false + } +} + +func tryReadCh[T any](ch <-chan T) (*T, bool) { + select { + case rtn := <-ch: + return &rtn, true + default: + return nil, false + } +} + +func (cb *CirBuf) tryWakeWriter() { + if waiterCh, ok := tryReadCh(cb.waiterChan); ok { + close(*waiterCh) + } +} diff --git a/pkg/jobmanager/jobcmd.go b/pkg/jobmanager/jobcmd.go new file mode 100644 index 0000000000..2349e69b35 --- /dev/null +++ b/pkg/jobmanager/jobcmd.go @@ -0,0 +1,208 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package jobmanager + +import ( + "encoding/base64" + "fmt" + "log" + "os" + "os/exec" + "sync" + "syscall" + "time" + + "github.com/creack/pty" + "github.com/wavetermdev/waveterm/pkg/waveobj" + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type CmdDef struct { + Cmd string + Args []string + Env map[string]string + TermSize waveobj.TermSize +} + +type JobCmd struct { + jobId string + lock sync.Mutex + cmd *exec.Cmd + cmdPty pty.Pty + ptsName string + cleanedUp bool + ptyClosed bool + processExited bool + exitCode *int + exitSignal string + exitErr error + exitTs int64 +} + +func MakeJobCmd(jobId string, cmdDef CmdDef) (*JobCmd, error) { + jm := &JobCmd{ + jobId: jobId, + } + if cmdDef.TermSize.Rows == 0 || cmdDef.TermSize.Cols == 0 { + cmdDef.TermSize.Rows = 25 + cmdDef.TermSize.Cols = 80 + } + if cmdDef.TermSize.Rows <= 0 || cmdDef.TermSize.Cols <= 0 { + return nil, fmt.Errorf("invalid term size: %v", cmdDef.TermSize) + } + ecmd := exec.Command(cmdDef.Cmd, 
cmdDef.Args...) + if len(cmdDef.Env) > 0 { + ecmd.Env = os.Environ() + for key, val := range cmdDef.Env { + ecmd.Env = append(ecmd.Env, fmt.Sprintf("%s=%s", key, val)) + } + } + cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(cmdDef.TermSize.Rows), Cols: uint16(cmdDef.TermSize.Cols)}) + if err != nil { + return nil, fmt.Errorf("failed to start command: %w", err) + } + setCloseOnExec(int(cmdPty.Fd())) + jm.cmd = ecmd + jm.cmdPty = cmdPty + jm.ptsName = jm.cmdPty.Name() + go jm.waitForProcess() + return jm, nil +} + +func (jm *JobCmd) waitForProcess() { + if jm.cmd == nil || jm.cmd.Process == nil { + return + } + err := jm.cmd.Wait() + jm.lock.Lock() + defer jm.lock.Unlock() + + jm.processExited = true + jm.exitTs = time.Now().UnixMilli() + jm.exitErr = err + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + if status.Signaled() { + jm.exitSignal = status.Signal().String() + } else if status.Exited() { + code := status.ExitStatus() + jm.exitCode = &code + } else { + log.Printf("Invalid WaitStatus, not exited or signaled: %v", status) + } + } + } + } else { + code := 0 + jm.exitCode = &code + } + exitCodeStr := "nil" + if jm.exitCode != nil { + exitCodeStr = fmt.Sprintf("%d", *jm.exitCode) + } + log.Printf("process exited: exitcode=%s, signal=%s, err=%v\n", exitCodeStr, jm.exitSignal, jm.exitErr) + + go WshCmdJobManager.sendJobExited() +} + +func (jm *JobCmd) GetCmd() (*exec.Cmd, pty.Pty) { + jm.lock.Lock() + defer jm.lock.Unlock() + return jm.cmd, jm.cmdPty +} + +func (jm *JobCmd) GetPGID() (int, error) { + jm.lock.Lock() + defer jm.lock.Unlock() + if jm.cmd == nil || jm.cmd.Process == nil { + return 0, fmt.Errorf("no active process") + } + if jm.processExited { + return 0, fmt.Errorf("process already exited") + } + pgid, err := getProcessGroupId(jm.cmd.Process.Pid) + if err != nil { + return 0, fmt.Errorf("failed to get pgid: %w", err) + } + if pgid <= 0 { + return 0, 
fmt.Errorf("invalid pgid returned: %d", pgid) + } + return pgid, nil +} + +func (jm *JobCmd) GetExitInfo() (bool, *wshrpc.CommandJobCmdExitedData) { + jm.lock.Lock() + defer jm.lock.Unlock() + if !jm.processExited { + return false, nil + } + exitData := &wshrpc.CommandJobCmdExitedData{ + JobId: WshCmdJobManager.JobId, + ExitCode: jm.exitCode, + ExitSignal: jm.exitSignal, + ExitTs: jm.exitTs, + } + if jm.exitErr != nil { + exitData.ExitErr = jm.exitErr.Error() + } + return true, exitData +} + +// TODO set up a single input handler loop + queue so we dont need to hold the lock but still get synchronized in-order execution +func (jm *JobCmd) HandleInput(data wshrpc.CommandJobInputData) error { + jm.lock.Lock() + defer jm.lock.Unlock() + + if jm.cmd == nil || jm.cmdPty == nil { + return fmt.Errorf("no active process") + } + + if len(data.InputData64) > 0 { + inputBuf := make([]byte, base64.StdEncoding.DecodedLen(len(data.InputData64))) + nw, err := base64.StdEncoding.Decode(inputBuf, []byte(data.InputData64)) + if err != nil { + return fmt.Errorf("error decoding input data: %w", err) + } + _, err = jm.cmdPty.Write(inputBuf[:nw]) + if err != nil { + return fmt.Errorf("error writing to pty: %w", err) + } + } + + if data.SigName != "" { + sig := normalizeSignal(data.SigName) + if sig != nil && jm.cmd.Process != nil { + err := jm.cmd.Process.Signal(sig) + if err != nil { + return fmt.Errorf("error sending signal: %w", err) + } + } + } + + if data.TermSize != nil { + err := pty.Setsize(jm.cmdPty, &pty.Winsize{ + Rows: uint16(data.TermSize.Rows), + Cols: uint16(data.TermSize.Cols), + }) + if err != nil { + return fmt.Errorf("error setting terminal size: %w", err) + } + } + + return nil +} + +func (jm *JobCmd) TerminateByClosingPtyMaster() { + jm.lock.Lock() + defer jm.lock.Unlock() + if jm.ptyClosed { + return + } + if jm.cmdPty != nil { + jm.cmdPty.Close() + jm.ptyClosed = true + log.Printf("pty closed for job %s\n", jm.jobId) + } +} diff --git 
a/pkg/jobmanager/jobmanager.go b/pkg/jobmanager/jobmanager.go new file mode 100644 index 0000000000..afa015304f --- /dev/null +++ b/pkg/jobmanager/jobmanager.go @@ -0,0 +1,246 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package jobmanager + +import ( + "fmt" + "log" + "net" + "os" + "path/filepath" + "runtime" + "sync" + + "github.com/wavetermdev/waveterm/pkg/baseds" + "github.com/wavetermdev/waveterm/pkg/panichandler" + "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wavejwt" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" + "github.com/wavetermdev/waveterm/pkg/wshutil" +) + +const JobAccessTokenLabel = "Wave-JobAccessToken" +const JobManagerStartLabel = "Wave-JobManagerStart" + +var WshCmdJobManager JobManager + +type JobManager struct { + ClientId string + JobId string + Cmd *JobCmd + JwtPublicKey []byte + JobAuthToken string + StreamManager *StreamManager + lock sync.Mutex + attachedClient *MainServerConn + connectedStreamClient *MainServerConn + pendingStreamMeta *wshrpc.StreamMeta +} + +func SetupJobManager(clientId string, jobId string, publicKeyBytes []byte, jobAuthToken string, readyFile *os.File) error { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + return fmt.Errorf("job manager only supported on unix systems, not %s", runtime.GOOS) + } + WshCmdJobManager.ClientId = clientId + WshCmdJobManager.JobId = jobId + WshCmdJobManager.JwtPublicKey = publicKeyBytes + WshCmdJobManager.JobAuthToken = jobAuthToken + WshCmdJobManager.StreamManager = MakeStreamManager() + err := wavejwt.SetPublicKey(publicKeyBytes) + if err != nil { + return fmt.Errorf("failed to set public key: %w", err) + } + err = MakeJobDomainSocket(clientId, jobId) + if err != nil { + return err + } + fmt.Fprintf(readyFile, JobManagerStartLabel+"\n") + readyFile.Close() + + err = daemonize(clientId, jobId) + if err != nil { + return fmt.Errorf("failed 
to daemonize: %w", err) + } + + return nil +} + +func (jm *JobManager) GetCmd() *JobCmd { + jm.lock.Lock() + defer jm.lock.Unlock() + return jm.Cmd +} + +func (jm *JobManager) sendJobExited() { + jm.lock.Lock() + attachedClient := jm.attachedClient + cmd := jm.Cmd + jm.lock.Unlock() + + if attachedClient == nil { + log.Printf("sendJobExited: no attached client, exit notification not sent\n") + return + } + if attachedClient.WshRpc == nil { + log.Printf("sendJobExited: no wsh rpc connection, exit notification not sent\n") + return + } + if cmd == nil { + log.Printf("sendJobExited: no cmd, exit notification not sent\n") + return + } + + exited, exitData := cmd.GetExitInfo() + if !exited || exitData == nil { + log.Printf("sendJobExited: process not exited yet\n") + return + } + + exitCodeStr := "nil" + if exitData.ExitCode != nil { + exitCodeStr = fmt.Sprintf("%d", *exitData.ExitCode) + } + log.Printf("sendJobExited: sending exit notification to main server exitcode=%s signal=%s\n", exitCodeStr, exitData.ExitSignal) + err := wshclient.JobCmdExitedCommand(attachedClient.WshRpc, *exitData, nil) + if err != nil { + log.Printf("sendJobExited: error sending exit notification: %v\n", err) + } +} + +func (jm *JobManager) GetJobAuthInfo() (string, string) { + jm.lock.Lock() + defer jm.lock.Unlock() + return jm.JobId, jm.JobAuthToken +} + +func (jm *JobManager) IsJobStarted() bool { + jm.lock.Lock() + defer jm.lock.Unlock() + return jm.Cmd != nil +} + +func (jm *JobManager) connectToStreamHelper_withlock(mainServerConn *MainServerConn, streamMeta wshrpc.StreamMeta, seq int64) (int64, error) { + rwndSize := int(streamMeta.RWnd) + if rwndSize < 0 { + return 0, fmt.Errorf("invalid rwnd size: %d", rwndSize) + } + + if jm.connectedStreamClient != nil { + log.Printf("connectToStreamHelper: disconnecting existing client\n") + oldStreamId := jm.StreamManager.GetStreamId() + jm.StreamManager.ClientDisconnected() + if oldStreamId != "" { + 
mainServerConn.WshRpc.StreamBroker.DetachStreamWriter(oldStreamId) + log.Printf("connectToStreamHelper: detached old stream id=%s\n", oldStreamId) + } + jm.connectedStreamClient = nil + } + dataSender := &routedDataSender{ + wshRpc: mainServerConn.WshRpc, + route: streamMeta.ReaderRouteId, + } + serverSeq, err := jm.StreamManager.ClientConnected( + streamMeta.Id, + dataSender, + rwndSize, + seq, + ) + if err != nil { + return 0, fmt.Errorf("failed to connect client: %w", err) + } + jm.connectedStreamClient = mainServerConn + return serverSeq, nil +} + +func (jm *JobManager) disconnectFromStreamHelper(mainServerConn *MainServerConn) { + jm.lock.Lock() + defer jm.lock.Unlock() + if jm.connectedStreamClient == nil || jm.connectedStreamClient != mainServerConn { + return + } + jm.StreamManager.ClientDisconnected() + jm.connectedStreamClient = nil +} + +func GetJobSocketPath(jobId string) string { + socketDir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid())) + return filepath.Join(socketDir, fmt.Sprintf("%s.sock", jobId)) +} + +func GetJobFilePath(clientId string, jobId string, extension string) string { + homeDir := wavebase.GetHomeDir() + jobDir := filepath.Join(homeDir, ".waveterm", "jobs", clientId) + return filepath.Join(jobDir, fmt.Sprintf("%s.%s", jobId, extension)) +} + +func MakeJobDomainSocket(clientId string, jobId string) error { + socketDir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid())) + err := os.MkdirAll(socketDir, 0700) + if err != nil { + return fmt.Errorf("failed to create socket directory: %w", err) + } + + socketPath := GetJobSocketPath(jobId) + + os.Remove(socketPath) + + listener, err := net.Listen("unix", socketPath) + if err != nil { + return fmt.Errorf("failed to listen on domain socket: %w", err) + } + + go func() { + defer func() { + panichandler.PanicHandler("MakeJobDomainSocket:accept", recover()) + listener.Close() + os.Remove(socketPath) + }() + for { + conn, err := listener.Accept() + if err != nil { + 
log.Printf("error accepting connection: %v\n", err) + return + } + go handleJobDomainSocketClient(conn) + } + }() + + return nil +} + +func handleJobDomainSocketClient(conn net.Conn) { + inputCh := make(chan baseds.RpcInputChType, wshutil.DefaultInputChSize) + outputCh := make(chan []byte, wshutil.DefaultOutputChSize) + + serverImpl := &MainServerConn{ + Conn: conn, + inputCh: inputCh, + } + rpcCtx := wshrpc.RpcContext{} + wshRpc := wshutil.MakeWshRpcWithChannels(inputCh, outputCh, rpcCtx, serverImpl, "job-domain") + serverImpl.WshRpc = wshRpc + defer WshCmdJobManager.disconnectFromStreamHelper(serverImpl) + + go func() { + defer func() { + panichandler.PanicHandler("handleJobDomainSocketClient:AdaptOutputChToStream", recover()) + }() + defer serverImpl.Close() + writeErr := wshutil.AdaptOutputChToStream(outputCh, conn) + if writeErr != nil { + log.Printf("error writing to domain socket: %v\n", writeErr) + } + }() + + go func() { + defer func() { + panichandler.PanicHandler("handleJobDomainSocketClient:AdaptStreamToMsgCh", recover()) + }() + defer serverImpl.Close() + wshutil.AdaptStreamToMsgCh(conn, inputCh) + }() + + _ = wshRpc +} diff --git a/pkg/jobmanager/jobmanager_unix.go b/pkg/jobmanager/jobmanager_unix.go new file mode 100644 index 0000000000..bddbea3987 --- /dev/null +++ b/pkg/jobmanager/jobmanager_unix.go @@ -0,0 +1,95 @@ +// Copyright 2026, Command Line Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +//go:build unix + +package jobmanager + +import ( + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + + "golang.org/x/sys/unix" +) + +func getProcessGroupId(pid int) (int, error) { + pgid, err := syscall.Getpgid(pid) + if err != nil { + return 0, err + } + return pgid, nil +} + +func normalizeSignal(sigName string) os.Signal { + sigName = strings.ToUpper(sigName) + sigName = strings.TrimPrefix(sigName, "SIG") + + switch sigName { + case "HUP": + return syscall.SIGHUP + case "INT": + return syscall.SIGINT + case "QUIT": + return syscall.SIGQUIT + case "KILL": + return syscall.SIGKILL + case "TERM": + return syscall.SIGTERM + case "USR1": + return syscall.SIGUSR1 + case "USR2": + return syscall.SIGUSR2 + case "STOP": + return syscall.SIGSTOP + case "CONT": + return syscall.SIGCONT + default: + return nil + } +} + +func daemonize(clientId string, jobId string) error { + _, err := unix.Setsid() + if err != nil { + return fmt.Errorf("failed to setsid: %w", err) + } + + devNull, err := os.OpenFile("/dev/null", os.O_RDWR, 0) + if err != nil { + return fmt.Errorf("failed to open /dev/null: %w", err) + } + err = unix.Dup2(int(devNull.Fd()), int(os.Stdin.Fd())) + if err != nil { + return fmt.Errorf("failed to dup2 stdin: %w", err) + } + devNull.Close() + + logPath := GetJobFilePath(clientId, jobId, "log") + logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + err = unix.Dup2(int(logFile.Fd()), int(os.Stdout.Fd())) + if err != nil { + return fmt.Errorf("failed to dup2 stdout: %w", err) + } + err = unix.Dup2(int(logFile.Fd()), int(os.Stderr.Fd())) + if err != nil { + return fmt.Errorf("failed to dup2 stderr: %w", err) + } + + log.SetOutput(logFile) + log.Printf("job manager daemonized, logging to %s\n", logPath) + + signal.Ignore(syscall.SIGHUP) + + return nil +} + +func setCloseOnExec(fd int) { + unix.CloseOnExec(fd) 
+} diff --git a/pkg/jobmanager/jobmanager_windows.go b/pkg/jobmanager/jobmanager_windows.go new file mode 100644 index 0000000000..356bfcb66e --- /dev/null +++ b/pkg/jobmanager/jobmanager_windows.go @@ -0,0 +1,29 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +//go:build windows + +package jobmanager + +import ( + "fmt" + "os" +) + +func getProcessGroupId(pid int) (int, error) { + return 0, fmt.Errorf("process group id not supported on windows") +} + +func normalizeSignal(sigName string) os.Signal { + return nil +} + +func daemonize(clientId string, jobId string) error { + return fmt.Errorf("daemonize not supported on windows") +} + +func setupJobManagerSignalHandlers() { +} + +func setCloseOnExec(fd int) { +} diff --git a/pkg/jobmanager/mainserverconn.go b/pkg/jobmanager/mainserverconn.go new file mode 100644 index 0000000000..8f10eed20c --- /dev/null +++ b/pkg/jobmanager/mainserverconn.go @@ -0,0 +1,292 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package jobmanager + +import ( + "context" + "fmt" + "log" + "net" + "os" + "sync" + "sync/atomic" + + "github.com/shirou/gopsutil/v4/process" + "github.com/wavetermdev/waveterm/pkg/baseds" + "github.com/wavetermdev/waveterm/pkg/wavejwt" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" + "github.com/wavetermdev/waveterm/pkg/wshutil" +) + +type MainServerConn struct { + PeerAuthenticated atomic.Bool + SelfAuthenticated atomic.Bool + WshRpc *wshutil.WshRpc + Conn net.Conn + inputCh chan baseds.RpcInputChType + closeOnce sync.Once +} + +func (*MainServerConn) WshServerImpl() {} + +func (msc *MainServerConn) Close() { + msc.closeOnce.Do(func() { + msc.Conn.Close() + close(msc.inputCh) + }) +} + +type routedDataSender struct { + wshRpc *wshutil.WshRpc + route string +} + +func (rds *routedDataSender) SendData(dataPk wshrpc.CommandStreamData) { + log.Printf("SendData: sending seq=%d, len=%d, 
eof=%t, error=%s, route=%s", + dataPk.Seq, len(dataPk.Data64), dataPk.Eof, dataPk.Error, rds.route) + err := wshclient.StreamDataCommand(rds.wshRpc, dataPk, &wshrpc.RpcOpts{NoResponse: true, Route: rds.route}) + if err != nil { + log.Printf("SendData: error sending stream data: %v\n", err) + } +} + +func (msc *MainServerConn) authenticateSelfToServer(jobAuthToken string) error { + jobId, _ := WshCmdJobManager.GetJobAuthInfo() + authData := wshrpc.CommandAuthenticateJobManagerData{ + JobId: jobId, + JobAuthToken: jobAuthToken, + } + err := wshclient.AuthenticateJobManagerCommand(msc.WshRpc, authData, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) + if err != nil { + log.Printf("authenticateSelfToServer: failed to authenticate to server: %v\n", err) + return fmt.Errorf("failed to authenticate to server: %w", err) + } + msc.SelfAuthenticated.Store(true) + log.Printf("authenticateSelfToServer: successfully authenticated to server\n") + return nil +} + +func (msc *MainServerConn) AuthenticateToJobManagerCommand(ctx context.Context, data wshrpc.CommandAuthenticateToJobData) error { + jobId, jobAuthToken := WshCmdJobManager.GetJobAuthInfo() + + claims, err := wavejwt.ValidateAndExtract(data.JobAccessToken) + if err != nil { + log.Printf("AuthenticateToJobManager: failed to validate token: %v\n", err) + return fmt.Errorf("failed to validate token: %w", err) + } + if !claims.MainServer { + log.Printf("AuthenticateToJobManager: MainServer claim not set\n") + return fmt.Errorf("MainServer claim not set") + } + if claims.JobId != jobId { + log.Printf("AuthenticateToJobManager: JobId mismatch: expected %s, got %s\n", jobId, claims.JobId) + return fmt.Errorf("JobId mismatch") + } + msc.PeerAuthenticated.Store(true) + log.Printf("AuthenticateToJobManager: authentication successful for JobId=%s\n", claims.JobId) + + err = msc.authenticateSelfToServer(jobAuthToken) + if err != nil { + msc.PeerAuthenticated.Store(false) + return err + } + + WshCmdJobManager.lock.Lock() + defer 
WshCmdJobManager.lock.Unlock() + + if WshCmdJobManager.attachedClient != nil { + log.Printf("AuthenticateToJobManager: kicking out existing client\n") + WshCmdJobManager.attachedClient.Close() + } + WshCmdJobManager.attachedClient = msc + return nil +} + +func (msc *MainServerConn) StartJobCommand(ctx context.Context, data wshrpc.CommandStartJobData) (*wshrpc.CommandStartJobRtnData, error) { + log.Printf("StartJobCommand: received command=%s args=%v", data.Cmd, data.Args) + if !msc.PeerAuthenticated.Load() { + log.Printf("StartJobCommand: not authenticated") + return nil, fmt.Errorf("not authenticated") + } + if WshCmdJobManager.IsJobStarted() { + log.Printf("StartJobCommand: job already started") + return nil, fmt.Errorf("job already started") + } + + WshCmdJobManager.lock.Lock() + defer WshCmdJobManager.lock.Unlock() + + if WshCmdJobManager.Cmd != nil { + log.Printf("StartJobCommand: job already started (double check)") + return nil, fmt.Errorf("job already started") + } + + cmdDef := CmdDef{ + Cmd: data.Cmd, + Args: data.Args, + Env: data.Env, + TermSize: data.TermSize, + } + log.Printf("StartJobCommand: creating job cmd for jobid=%s", WshCmdJobManager.JobId) + jobCmd, err := MakeJobCmd(WshCmdJobManager.JobId, cmdDef) + if err != nil { + log.Printf("StartJobCommand: failed to make job cmd: %v", err) + return nil, fmt.Errorf("failed to start job: %w", err) + } + WshCmdJobManager.Cmd = jobCmd + log.Printf("StartJobCommand: job cmd created successfully") + + if data.StreamMeta != nil { + serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, *data.StreamMeta, 0) + if err != nil { + return nil, fmt.Errorf("failed to connect stream: %w", err) + } + err = msc.WshRpc.StreamBroker.AttachStreamWriter(data.StreamMeta, WshCmdJobManager.StreamManager) + if err != nil { + return nil, fmt.Errorf("failed to attach stream writer: %w", err) + } + log.Printf("StartJob: connected stream streamid=%s serverSeq=%d\n", data.StreamMeta.Id, serverSeq) + } + + _, cmdPty 
:= jobCmd.GetCmd() + if cmdPty != nil { + log.Printf("StartJobCommand: attaching pty reader to stream manager") + err = WshCmdJobManager.StreamManager.AttachReader(cmdPty) + if err != nil { + log.Printf("StartJobCommand: failed to attach reader: %v", err) + return nil, fmt.Errorf("failed to attach reader to stream manager: %w", err) + } + log.Printf("StartJobCommand: pty reader attached successfully") + } else { + log.Printf("StartJobCommand: no pty to attach") + } + + cmd, _ := jobCmd.GetCmd() + if cmd == nil || cmd.Process == nil { + log.Printf("StartJobCommand: cmd or process is nil") + return nil, fmt.Errorf("cmd or process is nil") + } + cmdPid := cmd.Process.Pid + cmdProc, err := process.NewProcess(int32(cmdPid)) + if err != nil { + log.Printf("StartJobCommand: failed to get cmd process: %v", err) + return nil, fmt.Errorf("failed to get cmd process: %w", err) + } + cmdStartTs, err := cmdProc.CreateTime() + if err != nil { + log.Printf("StartJobCommand: failed to get cmd start time: %v", err) + return nil, fmt.Errorf("failed to get cmd start time: %w", err) + } + + jobManagerPid := os.Getpid() + jobManagerProc, err := process.NewProcess(int32(jobManagerPid)) + if err != nil { + log.Printf("StartJobCommand: failed to get job manager process: %v", err) + return nil, fmt.Errorf("failed to get job manager process: %w", err) + } + jobManagerStartTs, err := jobManagerProc.CreateTime() + if err != nil { + log.Printf("StartJobCommand: failed to get job manager start time: %v", err) + return nil, fmt.Errorf("failed to get job manager start time: %w", err) + } + + log.Printf("StartJobCommand: job started successfully cmdPid=%d cmdStartTs=%d jobManagerPid=%d jobManagerStartTs=%d", cmdPid, cmdStartTs, jobManagerPid, jobManagerStartTs) + return &wshrpc.CommandStartJobRtnData{ + CmdPid: cmdPid, + CmdStartTs: cmdStartTs, + JobManagerPid: jobManagerPid, + JobManagerStartTs: jobManagerStartTs, + }, nil +} + +func (msc *MainServerConn) JobPrepareConnectCommand(ctx 
context.Context, data wshrpc.CommandJobPrepareConnectData) (*wshrpc.CommandJobConnectRtnData, error) { + WshCmdJobManager.lock.Lock() + defer WshCmdJobManager.lock.Unlock() + + if !msc.PeerAuthenticated.Load() { + return nil, fmt.Errorf("peer not authenticated") + } + if !msc.SelfAuthenticated.Load() { + return nil, fmt.Errorf("not authenticated to server") + } + if WshCmdJobManager.Cmd == nil { + return nil, fmt.Errorf("job not started") + } + + rtnData := &wshrpc.CommandJobConnectRtnData{} + streamDone, streamError := WshCmdJobManager.StreamManager.GetStreamDoneInfo() + + if streamDone { + log.Printf("JobPrepareConnect: stream already done, skipping connection streamError=%q\n", streamError) + rtnData.Seq = data.Seq + rtnData.StreamDone = true + rtnData.StreamError = streamError + } else { + corkedStreamMeta := data.StreamMeta + corkedStreamMeta.RWnd = 0 + serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, corkedStreamMeta, data.Seq) + if err != nil { + return nil, err + } + WshCmdJobManager.pendingStreamMeta = &data.StreamMeta + rtnData.Seq = serverSeq + rtnData.StreamDone = false + } + + hasExited, exitData := WshCmdJobManager.Cmd.GetExitInfo() + if hasExited && exitData != nil { + rtnData.HasExited = true + rtnData.ExitCode = exitData.ExitCode + rtnData.ExitSignal = exitData.ExitSignal + rtnData.ExitErr = exitData.ExitErr + } + + log.Printf("JobPrepareConnect: streamid=%s clientSeq=%d serverSeq=%d streamDone=%v streamError=%q hasExited=%v\n", data.StreamMeta.Id, data.Seq, rtnData.Seq, rtnData.StreamDone, rtnData.StreamError, hasExited) + return rtnData, nil +} + +func (msc *MainServerConn) JobStartStreamCommand(ctx context.Context, data wshrpc.CommandJobStartStreamData) error { + WshCmdJobManager.lock.Lock() + defer WshCmdJobManager.lock.Unlock() + + if !msc.PeerAuthenticated.Load() { + return fmt.Errorf("not authenticated") + } + if WshCmdJobManager.Cmd == nil { + return fmt.Errorf("job not started") + } + if 
WshCmdJobManager.pendingStreamMeta == nil { + return fmt.Errorf("no pending stream (call JobPrepareConnect first)") + } + + err := msc.WshRpc.StreamBroker.AttachStreamWriter(WshCmdJobManager.pendingStreamMeta, WshCmdJobManager.StreamManager) + if err != nil { + return fmt.Errorf("failed to attach stream writer: %w", err) + } + + err = WshCmdJobManager.StreamManager.SetRwndSize(int(WshCmdJobManager.pendingStreamMeta.RWnd)) + if err != nil { + return fmt.Errorf("failed to set rwnd size: %w", err) + } + + log.Printf("JobStartStream: streamid=%s rwnd=%d streaming started\n", WshCmdJobManager.pendingStreamMeta.Id, WshCmdJobManager.pendingStreamMeta.RWnd) + WshCmdJobManager.pendingStreamMeta = nil + return nil +} + +func (msc *MainServerConn) JobInputCommand(ctx context.Context, data wshrpc.CommandJobInputData) error { + WshCmdJobManager.lock.Lock() + defer WshCmdJobManager.lock.Unlock() + + if !msc.PeerAuthenticated.Load() { + return fmt.Errorf("not authenticated") + } + if WshCmdJobManager.Cmd == nil { + return fmt.Errorf("job not started") + } + + return WshCmdJobManager.Cmd.HandleInput(data) +} + diff --git a/pkg/jobmanager/streammanager.go b/pkg/jobmanager/streammanager.go new file mode 100644 index 0000000000..43861449b7 --- /dev/null +++ b/pkg/jobmanager/streammanager.go @@ -0,0 +1,419 @@ +// Copyright 2025, Command Line Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package jobmanager + +import ( + "encoding/base64" + "fmt" + "io" + "log" + "sync" + + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +const ( + CwndSize = 64 * 1024 // 64 KB window for connected mode + CirBufSize = 2 * 1024 * 1024 // 2 MB max buffer size + DisconnReadSz = 4 * 1024 // 4 KB read chunks when disconnected + MaxPacketSize = 4 * 1024 // 4 KB max data per packet +) + +type DataSender interface { + SendData(dataPk wshrpc.CommandStreamData) +} + +type streamTerminalEvent struct { + isEof bool + err string +} + +// StreamManager handles PTY output buffering with ACK-based flow control +type StreamManager struct { + lock sync.Mutex + drainCond *sync.Cond + + streamId string + + // this is the data read from the attached reader + buf *CirBuf + terminalEvent *streamTerminalEvent + eofPos int64 // fixed position when EOF/error occurs (-1 if not yet) + + reader io.Reader + + cwndSize int + rwndSize int + // invariant: if connected is true, dataSender is non-nil + connected bool + dataSender DataSender + + // unacked state (reset on disconnect) + sentNotAcked int64 + terminalEventSent bool + + // terminal state - once true, stream is complete + terminalEventAcked bool + closed bool +} + +func MakeStreamManager() *StreamManager { + return MakeStreamManagerWithSizes(CwndSize, CirBufSize) +} + +func MakeStreamManagerWithSizes(cwndSize, cirbufSize int) *StreamManager { + sm := &StreamManager{ + buf: MakeCirBuf(cirbufSize, true), + eofPos: -1, + cwndSize: cwndSize, + rwndSize: cwndSize, + } + sm.drainCond = sync.NewCond(&sm.lock) + go sm.senderLoop() + return sm +} + +// AttachReader starts reading from the given reader +func (sm *StreamManager) AttachReader(r io.Reader) error { + sm.lock.Lock() + defer sm.lock.Unlock() + + if sm.reader != nil { + return fmt.Errorf("reader already attached") + } + + sm.reader = r + go sm.readLoop() + + return nil +} + +// ClientConnected transitions to CONNECTED mode +func (sm 
*StreamManager) ClientConnected(streamId string, dataSender DataSender, rwndSize int, clientSeq int64) (int64, error) { + sm.lock.Lock() + defer sm.lock.Unlock() + + if sm.closed || sm.terminalEventAcked { + return 0, fmt.Errorf("stream is closed") + } + + if sm.connected { + return 0, fmt.Errorf("client already connected") + } + + if dataSender == nil { + return 0, fmt.Errorf("dataSender cannot be nil") + } + + headPos := sm.buf.HeadPos() + if clientSeq > headPos { + bytesToConsume := int(clientSeq - headPos) + available := sm.buf.Size() + if bytesToConsume > available { + return 0, fmt.Errorf("client seq %d is beyond our stream end (head=%d, size=%d)", clientSeq, headPos, available) + } + if bytesToConsume > 0 { + if err := sm.buf.Consume(bytesToConsume); err != nil { + return 0, fmt.Errorf("failed to consume buffer: %w", err) + } + headPos = sm.buf.HeadPos() + } + } + + sm.streamId = streamId + sm.dataSender = dataSender + sm.connected = true + sm.rwndSize = rwndSize + sm.sentNotAcked = 0 + effectiveWindow := sm.cwndSize + if sm.rwndSize < effectiveWindow { + effectiveWindow = sm.rwndSize + } + sm.buf.SetEffectiveWindow(true, effectiveWindow) + sm.drainCond.Signal() + + startSeq := headPos + if clientSeq > startSeq { + startSeq = clientSeq + } + + return startSeq, nil +} + +// GetStreamId returns the current stream ID (safe to call with lock held by caller) +func (sm *StreamManager) GetStreamId() string { + sm.lock.Lock() + defer sm.lock.Unlock() + return sm.streamId +} + +// GetStreamDoneInfo returns whether the stream is done and the error if there was one. +// The error is only meaningful if done=true, as the error is delivered as part of the stream otherwise. 
+func (sm *StreamManager) GetStreamDoneInfo() (done bool, streamError string) { + sm.lock.Lock() + defer sm.lock.Unlock() + if !sm.terminalEventAcked { + return false, "" + } + if sm.terminalEvent != nil && !sm.terminalEvent.isEof { + return true, sm.terminalEvent.err + } + return true, "" +} + +// ClientDisconnected transitions to DISCONNECTED mode +func (sm *StreamManager) ClientDisconnected() { + sm.lock.Lock() + defer sm.lock.Unlock() + + if !sm.connected { + return + } + + sm.connected = false + sm.dataSender = nil + sm.sentNotAcked = 0 + if !sm.terminalEventAcked { + sm.terminalEventSent = false + } + sm.buf.SetEffectiveWindow(false, CirBufSize) + sm.drainCond.Signal() +} + +// RecvAck processes an ACK from the client +// must be connected, and streamid must match +func (sm *StreamManager) RecvAck(ackPk wshrpc.CommandStreamAckData) { + sm.lock.Lock() + defer sm.lock.Unlock() + + if !sm.connected || ackPk.Id != sm.streamId { + return + } + + if ackPk.Fin { + sm.terminalEventAcked = true + sm.drainCond.Signal() + return + } + + seq := ackPk.Seq + headPos := sm.buf.HeadPos() + if seq < headPos { + return + } + + ackedBytes := seq - headPos + if ackedBytes > sm.sentNotAcked { + return + } + + if ackedBytes > 0 { + if err := sm.buf.Consume(int(ackedBytes)); err != nil { + return + } + sm.sentNotAcked -= ackedBytes + } + + prevRwnd := sm.rwndSize + sm.rwndSize = int(ackPk.RWnd) + effectiveWindow := sm.cwndSize + if sm.rwndSize < effectiveWindow { + effectiveWindow = sm.rwndSize + } + sm.buf.SetEffectiveWindow(true, effectiveWindow) + + if sm.rwndSize > prevRwnd || ackedBytes > 0 { + sm.drainCond.Signal() + } +} + +// SetRwndSize dynamically updates the receive window size +func (sm *StreamManager) SetRwndSize(rwndSize int) error { + sm.lock.Lock() + defer sm.lock.Unlock() + if rwndSize < 0 { + return fmt.Errorf("rwndSize cannot be negative") + } + if !sm.connected { + return fmt.Errorf("not connected") + } + sm.rwndSize = rwndSize + effectiveWindow := sm.cwndSize + 
if sm.rwndSize < effectiveWindow { + effectiveWindow = sm.rwndSize + } + sm.buf.SetEffectiveWindow(true, effectiveWindow) + sm.drainCond.Signal() + return nil +} + +// Close shuts down the sender loop. The reader loop will exit on its next iteration +// or when the underlying reader is closed. +func (sm *StreamManager) Close() { + sm.lock.Lock() + defer sm.lock.Unlock() + sm.closed = true + sm.drainCond.Signal() +} + +// readLoop is the main read goroutine +func (sm *StreamManager) readLoop() { + readBuf := make([]byte, MaxPacketSize) + for { + sm.lock.Lock() + closed := sm.closed + sm.lock.Unlock() + + if closed { + return + } + + n, err := sm.reader.Read(readBuf) + log.Printf("readLoop: read %d bytes from PTY, err=%v", n, err) + + if n > 0 { + sm.handleReadData(readBuf[:n]) + } + + if err != nil { + if err == io.EOF { + sm.handleEOF() + } else { + sm.handleError(err) + } + return + } + } +} + +func (sm *StreamManager) handleReadData(data []byte) { + log.Printf("handleReadData: writing %d bytes to buffer", len(data)) + sm.buf.Write(data) + sm.lock.Lock() + defer sm.lock.Unlock() + log.Printf("handleReadData: buffer size=%d, connected=%t, signaling=%t", sm.buf.Size(), sm.connected, sm.connected) + if sm.connected { + sm.drainCond.Signal() + } +} + +func (sm *StreamManager) handleEOF() { + sm.lock.Lock() + defer sm.lock.Unlock() + + log.Printf("handleEOF: PTY reached EOF, totalSize=%d", sm.buf.TotalSize()) + sm.eofPos = sm.buf.TotalSize() + sm.terminalEvent = &streamTerminalEvent{isEof: true} + sm.drainCond.Signal() +} + +func (sm *StreamManager) handleError(err error) { + sm.lock.Lock() + defer sm.lock.Unlock() + + log.Printf("handleError: PTY error=%v, totalSize=%d", err, sm.buf.TotalSize()) + sm.eofPos = sm.buf.TotalSize() + sm.terminalEvent = &streamTerminalEvent{err: err.Error()} + sm.drainCond.Signal() +} + +func (sm *StreamManager) senderLoop() { + for { + done, pkt, sender := sm.prepareNextPacket() + if done { + return + } + if pkt == nil { + continue + } + 
sender.SendData(*pkt) + } +} + +func (sm *StreamManager) prepareNextPacket() (done bool, pkt *wshrpc.CommandStreamData, sender DataSender) { + sm.lock.Lock() + defer sm.lock.Unlock() + + available := sm.buf.Size() + log.Printf("prepareNextPacket: connected=%t, available=%d, closed=%t, terminalEventAcked=%t, terminalEvent=%v", + sm.connected, available, sm.closed, sm.terminalEventAcked, sm.terminalEvent != nil) + + if sm.closed || sm.terminalEventAcked { + return true, nil, nil + } + + if !sm.connected { + log.Printf("prepareNextPacket: waiting for connection") + sm.drainCond.Wait() + return false, nil, nil + } + + if available == 0 { + if sm.terminalEvent != nil && !sm.terminalEventSent { + log.Printf("prepareNextPacket: preparing terminal packet") + return false, sm.prepareTerminalPacket(), sm.dataSender + } + log.Printf("prepareNextPacket: no data available, waiting") + sm.drainCond.Wait() + return false, nil, nil + } + + effectiveRwnd := sm.rwndSize + if sm.cwndSize < effectiveRwnd { + effectiveRwnd = sm.cwndSize + } + availableToSend := int64(effectiveRwnd) - sm.sentNotAcked + + if availableToSend <= 0 { + sm.drainCond.Wait() + return false, nil, nil + } + + peekSize := int(availableToSend) + if peekSize > MaxPacketSize { + peekSize = MaxPacketSize + } + if peekSize > available { + peekSize = available + } + + data := make([]byte, peekSize) + n := sm.buf.PeekDataAt(int(sm.sentNotAcked), data) + if n == 0 { + log.Printf("prepareNextPacket: PeekDataAt returned 0 bytes, waiting for ACK") + sm.drainCond.Wait() + return false, nil, nil + } + data = data[:n] + + seq := sm.buf.HeadPos() + sm.sentNotAcked + sm.sentNotAcked += int64(n) + + log.Printf("prepareNextPacket: sending packet seq=%d, len=%d bytes", seq, n) + return false, &wshrpc.CommandStreamData{ + Id: sm.streamId, + Seq: seq, + Data64: base64.StdEncoding.EncodeToString(data), + }, sm.dataSender +} + +func (sm *StreamManager) prepareTerminalPacket() *wshrpc.CommandStreamData { + if sm.terminalEventSent || 
sm.terminalEvent == nil { + return nil + } + + pkt := &wshrpc.CommandStreamData{ + Id: sm.streamId, + Seq: sm.eofPos, + } + + if sm.terminalEvent.isEof { + pkt.Eof = true + } else { + pkt.Error = sm.terminalEvent.err + } + + sm.terminalEventSent = true + return pkt +} diff --git a/pkg/jobmanager/streammanager_test.go b/pkg/jobmanager/streammanager_test.go new file mode 100644 index 0000000000..9a0e3c895e --- /dev/null +++ b/pkg/jobmanager/streammanager_test.go @@ -0,0 +1,348 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package jobmanager + +import ( + "encoding/base64" + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type testWriter struct { + mu sync.Mutex + packets []wshrpc.CommandStreamData +} + +func (tw *testWriter) SendData(pkt wshrpc.CommandStreamData) { + tw.mu.Lock() + defer tw.mu.Unlock() + tw.packets = append(tw.packets, pkt) +} + +func (tw *testWriter) GetPackets() []wshrpc.CommandStreamData { + tw.mu.Lock() + defer tw.mu.Unlock() + result := make([]wshrpc.CommandStreamData, len(tw.packets)) + copy(result, tw.packets) + return result +} + +func (tw *testWriter) Clear() { + tw.mu.Lock() + defer tw.mu.Unlock() + tw.packets = nil +} + +func decodeData(data64 string) string { + decoded, _ := base64.StdEncoding.DecodeString(data64) + return string(decoded) +} + +func TestBasicDisconnectedMode(t *testing.T) { + tw := &testWriter{} + sm := MakeStreamManager() + + reader := strings.NewReader("hello world") + err := sm.AttachReader(reader) + if err != nil { + t.Fatalf("AttachReader failed: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + packets := tw.GetPackets() + if len(packets) > 0 { + t.Errorf("Expected no packets in DISCONNECTED mode without client, got %d", len(packets)) + } + + sm.Close() +} + +func TestConnectedModeBasicFlow(t *testing.T) { + tw := &testWriter{} + sm := MakeStreamManager() + + reader := strings.NewReader("hello") + err := 
sm.AttachReader(reader) + if err != nil { + t.Fatalf("AttachReader failed: %v", err) + } + + _, err = sm.ClientConnected("1", tw, CwndSize, 0) + if err != nil { + t.Fatalf("ClientConnected failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + packets := tw.GetPackets() + if len(packets) == 0 { + t.Fatal("Expected packets after ClientConnected") + } + + // Verify we got the data + allData := "" + for _, pkt := range packets { + if pkt.Data64 != "" { + allData += decodeData(pkt.Data64) + } + } + + if allData != "hello" { + t.Errorf("Expected 'hello', got '%s'", allData) + } + + // Send ACK + sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: 5, RWnd: CwndSize}) + + time.Sleep(50 * time.Millisecond) + + // Check for EOF packet + packets = tw.GetPackets() + hasEof := false + for _, pkt := range packets { + if pkt.Eof { + hasEof = true + } + } + + if !hasEof { + t.Error("Expected EOF packet after ACKing all data") + } + + sm.Close() +} + +func TestDisconnectedToConnectedTransition(t *testing.T) { + tw := &testWriter{} + sm := MakeStreamManager() + + reader := strings.NewReader("test data") + err := sm.AttachReader(reader) + if err != nil { + t.Fatalf("AttachReader failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + _, err = sm.ClientConnected("1", tw, CwndSize, 0) + if err != nil { + t.Fatalf("ClientConnected failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + packets := tw.GetPackets() + if len(packets) == 0 { + t.Fatal("Expected cirbuf drain after connect") + } + + allData := "" + for _, pkt := range packets { + if pkt.Data64 != "" { + allData += decodeData(pkt.Data64) + } + } + + if allData != "test data" { + t.Errorf("Expected 'test data', got '%s'", allData) + } + + sm.Close() +} + +func TestConnectedToDisconnectedTransition(t *testing.T) { + tw := &testWriter{} + sm := MakeStreamManager() + + reader := &slowReader{data: []byte("slow data"), delay: 50 * time.Millisecond} + err := sm.AttachReader(reader) + if err != nil { + 
t.Fatalf("AttachReader failed: %v", err) + } + + _, err = sm.ClientConnected("1", tw, CwndSize, 0) + if err != nil { + t.Fatalf("ClientConnected failed: %v", err) + } + + time.Sleep(150 * time.Millisecond) + + sm.ClientDisconnected() + + time.Sleep(100 * time.Millisecond) + + sm.Close() +} + +func TestFlowControl(t *testing.T) { + cwndSize := 1024 + tw := &testWriter{} + sm := MakeStreamManagerWithSizes(cwndSize, 8*1024) + + largeData := strings.Repeat("x", cwndSize+500) + reader := strings.NewReader(largeData) + + err := sm.AttachReader(reader) + if err != nil { + t.Fatalf("AttachReader failed: %v", err) + } + + _, err = sm.ClientConnected("1", tw, cwndSize, 0) + if err != nil { + t.Fatalf("ClientConnected failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + packets := tw.GetPackets() + totalData := 0 + for _, pkt := range packets { + if pkt.Data64 != "" { + decoded, _ := base64.StdEncoding.DecodeString(pkt.Data64) + totalData += len(decoded) + } + } + + if totalData > cwndSize { + t.Errorf("Sent %d bytes without ACK, exceeds cwnd size %d", totalData, cwndSize) + } + + sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: int64(totalData), RWnd: int64(cwndSize)}) + + time.Sleep(100 * time.Millisecond) + + sm.Close() +} + +func TestSequenceNumbering(t *testing.T) { + tw := &testWriter{} + sm := MakeStreamManager() + + reader := strings.NewReader("abcdefghij") + err := sm.AttachReader(reader) + if err != nil { + t.Fatalf("AttachReader failed: %v", err) + } + + _, err = sm.ClientConnected("1", tw, CwndSize, 0) + if err != nil { + t.Fatalf("ClientConnected failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + packets := tw.GetPackets() + if len(packets) == 0 { + t.Fatal("Expected packets") + } + + expectedSeq := int64(0) + for _, pkt := range packets { + if pkt.Data64 == "" { + continue + } + + if pkt.Seq != expectedSeq { + t.Errorf("Expected seq %d, got %d", expectedSeq, pkt.Seq) + } + + decoded, _ := base64.StdEncoding.DecodeString(pkt.Data64) 
+ expectedSeq += int64(len(decoded)) + } + + sm.Close() +} + +func TestTerminalEventOrdering(t *testing.T) { + tw := &testWriter{} + sm := MakeStreamManager() + + reader := strings.NewReader("data") + err := sm.AttachReader(reader) + if err != nil { + t.Fatalf("AttachReader failed: %v", err) + } + + _, err = sm.ClientConnected("1", tw, CwndSize, 0) + if err != nil { + t.Fatalf("ClientConnected failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + packets := tw.GetPackets() + if len(packets) == 0 { + t.Fatal("Expected data packets") + } + + hasData := false + hasEof := false + eofSeq := int64(-1) + + for _, pkt := range packets { + if pkt.Data64 != "" { + hasData = true + } + if pkt.Eof { + hasEof = true + eofSeq = pkt.Seq + } + } + + if !hasData { + t.Error("Expected data packet") + } + + if hasEof { + t.Error("Should not have EOF before ACK") + } + + sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: 4, RWnd: CwndSize}) + + time.Sleep(50 * time.Millisecond) + + packets = tw.GetPackets() + hasEof = false + for _, pkt := range packets { + if pkt.Eof { + hasEof = true + eofSeq = pkt.Seq + } + } + + if !hasEof { + t.Error("Expected EOF after ACKing all data") + } + + if eofSeq != 4 { + t.Errorf("Expected EOF at seq 4, got %d", eofSeq) + } + + sm.Close() +} + +type slowReader struct { + data []byte + pos int + delay time.Duration +} + +func (sr *slowReader) Read(p []byte) (n int, err error) { + if sr.pos >= len(sr.data) { + return 0, io.EOF + } + + time.Sleep(sr.delay) + + n = copy(p, sr.data[sr.pos:]) + sr.pos += n + + return n, nil +} diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index 34d2033b43..b042eb9693 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -85,7 +85,7 @@ type SSHConn struct { var ConnServerCmdTemplate = strings.TrimSpace( strings.Join([]string{ "%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm; exit 0);", 
- "exec %s connserver --conn %s %s", + "exec %s connserver --conn %s %s %s", }, "\n")) func IsLocalConnName(connName string) bool { @@ -285,8 +285,9 @@ func (conn *SSHConn) GetConfigShellPath() string { // returns (needsInstall, clientVersion, osArchStr, error) // if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr // if clientVersion is set, then no osArchStr will be returned -func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) { - conn.Infof(ctx, "running StartConnServer...\n") +// if useRouterMode is true, will start connserver with --router-domainsocket flag +func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool, useRouterMode bool) (bool, string, string, error) { + conn.Infof(ctx, "running StartConnServer (routerMode=%v)...\n", useRouterMode) allowed := WithLockRtn(conn, func() bool { return conn.Status == Status_Connecting }) @@ -296,10 +297,19 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo client := conn.GetClient() wshPath := conn.getWshPath() sockName := conn.GetDomainSocketName() - rpcCtx := wshrpc.RpcContext{ - RouteId: wshutil.MakeConnectionRouteId(conn.GetName()), - SockName: sockName, - Conn: conn.GetName(), + var rpcCtx wshrpc.RpcContext + if useRouterMode { + rpcCtx = wshrpc.RpcContext{ + IsRouter: true, + SockName: sockName, + Conn: conn.GetName(), + } + } else { + rpcCtx = wshrpc.RpcContext{ + RouteId: wshutil.MakeConnectionRouteId(conn.GetName()), + SockName: sockName, + Conn: conn.GetName(), + } } jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx) if err != nil { @@ -321,7 +331,11 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo if wavebase.IsDevMode() { devFlag = "--dev" } - cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag) + routerFlag := "" + if useRouterMode { + routerFlag = 
"--router-domainsocket" + } + cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag, routerFlag) log.Printf("starting conn controller: %q\n", cmdStr) shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr)) blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr) @@ -702,7 +716,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) err = fmt.Errorf("error opening domain socket listener: %w", err) return WshCheckResult{NoWshReason: "error opening domain socket", NoWshCode: NoWshCode_DomainSocketError, WshError: err} } - needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false) + needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false, false) if err != nil { conn.Infof(ctx, "ERROR starting conn server: %v\n", err) err = fmt.Errorf("error starting conn server: %w", err) @@ -716,7 +730,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) err = fmt.Errorf("error installing wsh: %w", err) return WshCheckResult{NoWshReason: "error installing wsh/connserver", NoWshCode: NoWshCode_InstallError, WshError: err} } - needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true) + needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true, false) if err != nil { conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err) err = fmt.Errorf("error starting conn server (after install): %w", err) @@ -842,23 +856,38 @@ func (conn *SSHConn) ClearWshError() { }) } -func getConnInternal(opts *remote.SSHOpts) *SSHConn { +func getConnInternal(opts *remote.SSHOpts, createIfNotExists bool) *SSHConn { globalLock.Lock() defer globalLock.Unlock() rtn := clientControllerMap[*opts] - if rtn == nil { + if rtn == nil && createIfNotExists { rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, WshEnabled: &atomic.Bool{}, Opts: opts, HasWaiter: &atomic.Bool{}} 
clientControllerMap[*opts] = rtn } return rtn } -// does NOT connect, can return nil if connection does not exist +// does NOT connect, does not return nil func GetConn(opts *remote.SSHOpts) *SSHConn { - conn := getConnInternal(opts) + conn := getConnInternal(opts, true) return conn } +func IsConnected(connName string) (bool, error) { + if IsLocalConnName(connName) { + return true, nil + } + connOpts, err := remote.ParseOpts(connName) + if err != nil { + return false, fmt.Errorf("error parsing connection name: %w", err) + } + conn := getConnInternal(connOpts, false) + if conn == nil { + return false, nil + } + return conn.GetStatus() == Status_Connected, nil +} + // Convenience function for ensuring a connection is established func EnsureConnection(ctx context.Context, connName string) error { if IsLocalConnName(connName) { @@ -888,7 +917,7 @@ func EnsureConnection(ctx context.Context, connName string) error { } func DisconnectClient(opts *remote.SSHOpts) error { - conn := getConnInternal(opts) + conn := getConnInternal(opts, false) if conn == nil { return fmt.Errorf("client %q not found", opts.String()) } diff --git a/pkg/streamclient/stream_test.go b/pkg/streamclient/stream_test.go index be1d2a1149..67fb3bc057 100644 --- a/pkg/streamclient/stream_test.go +++ b/pkg/streamclient/stream_test.go @@ -2,6 +2,7 @@ package streamclient import ( "bytes" + "encoding/base64" "io" "testing" "time" @@ -32,8 +33,8 @@ func (ft *fakeTransport) SendAck(ackPk wshrpc.CommandStreamAckData) { func TestBasicReadWrite(t *testing.T) { transport := newFakeTransport() - reader := NewReader(1, 1024, transport) - writer := NewWriter(1, 1024, transport) + reader := NewReader("1", 1024, transport) + writer := NewWriter("1", 1024, transport) go func() { for dataPk := range transport.dataChan { @@ -72,8 +73,8 @@ func TestBasicReadWrite(t *testing.T) { func TestEOF(t *testing.T) { transport := newFakeTransport() - reader := NewReader(1, 1024, transport) - writer := NewWriter(1, 1024, transport) 
+ reader := NewReader("1", 1024, transport) + writer := NewWriter("1", 1024, transport) go func() { for dataPk := range transport.dataChan { @@ -110,8 +111,8 @@ func TestFlowControl(t *testing.T) { smallWindow := int64(10) transport := newFakeTransport() - reader := NewReader(1, smallWindow, transport) - writer := NewWriter(1, smallWindow, transport) + reader := NewReader("1", smallWindow, transport) + writer := NewWriter("1", smallWindow, transport) go func() { for dataPk := range transport.dataChan { @@ -163,8 +164,8 @@ func TestFlowControl(t *testing.T) { func TestError(t *testing.T) { transport := newFakeTransport() - reader := NewReader(1, 1024, transport) - writer := NewWriter(1, 1024, transport) + reader := NewReader("1", 1024, transport) + writer := NewWriter("1", 1024, transport) go func() { for dataPk := range transport.dataChan { @@ -194,8 +195,8 @@ func TestError(t *testing.T) { func TestCancel(t *testing.T) { transport := newFakeTransport() - reader := NewReader(1, 1024, transport) - writer := NewWriter(1, 1024, transport) + reader := NewReader("1", 1024, transport) + writer := NewWriter("1", 1024, transport) go func() { for dataPk := range transport.dataChan { @@ -227,8 +228,8 @@ func TestCancel(t *testing.T) { func TestMultipleWrites(t *testing.T) { transport := newFakeTransport() - reader := NewReader(1, 1024, transport) - writer := NewWriter(1, 1024, transport) + reader := NewReader("1", 1024, transport) + writer := NewWriter("1", 1024, transport) go func() { for dataPk := range transport.dataChan { @@ -265,3 +266,258 @@ func TestMultipleWrites(t *testing.T) { t.Fatalf("Expected %q, got %q", expected, string(buf)) } } + +func TestOutOfOrderPackets(t *testing.T) { + transport := newFakeTransport() + reader := NewReader("test-ooo", 1024, transport) + + packet0 := wshrpc.CommandStreamData{ + Id: "test-ooo", + Seq: 0, + Data64: base64.StdEncoding.EncodeToString([]byte("AAAAA")), + } + packet5 := wshrpc.CommandStreamData{ + Id: "test-ooo", + Seq: 5, + 
Data64: base64.StdEncoding.EncodeToString([]byte("BBBBB")), + } + packet10 := wshrpc.CommandStreamData{ + Id: "test-ooo", + Seq: 10, + Data64: base64.StdEncoding.EncodeToString([]byte("CCCCC")), + } + packet15 := wshrpc.CommandStreamData{ + Id: "test-ooo", + Seq: 15, + Data64: base64.StdEncoding.EncodeToString([]byte("DDDDD")), + } + + // Send packets out of order: 0, 10, 15, 5 + reader.RecvData(packet0) + reader.RecvData(packet10) // OOO - should be buffered + reader.RecvData(packet15) // OOO - should be buffered + reader.RecvData(packet5) // fills the gap - should trigger processing + + // Read all data + buf := make([]byte, 1024) + totalRead := 0 + expectedLen := 20 // 4 packets * 5 bytes each + + readDone := make(chan struct{}) + go func() { + for totalRead < expectedLen { + n, err := reader.Read(buf[totalRead:]) + if err != nil { + t.Errorf("Read failed: %v", err) + return + } + totalRead += n + } + close(readDone) + }() + + select { + case <-readDone: + // Success + case <-time.After(2 * time.Second): + t.Fatalf("Read didn't complete in time. 
Read %d bytes, expected %d", totalRead, expectedLen) + } + + if totalRead != expectedLen { + t.Fatalf("Expected to read %d bytes, got %d", expectedLen, totalRead) + } +} + +func TestOutOfOrderWithDuplicates(t *testing.T) { + transport := newFakeTransport() + reader := NewReader("test-dup", 1024, transport) + + packet0 := wshrpc.CommandStreamData{ + Id: "test-dup", + Seq: 0, + Data64: base64.StdEncoding.EncodeToString([]byte("aaaaa")), + } + packet10 := wshrpc.CommandStreamData{ + Id: "test-dup", + Seq: 10, + Data64: base64.StdEncoding.EncodeToString([]byte("ccccc")), + } + packet5First := wshrpc.CommandStreamData{ + Id: "test-dup", + Seq: 5, + Data64: base64.StdEncoding.EncodeToString([]byte("xxxxx")), + } + packet5Second := wshrpc.CommandStreamData{ + Id: "test-dup", + Seq: 5, + Data64: base64.StdEncoding.EncodeToString([]byte("bbbbb")), + } + + reader.RecvData(packet0) + reader.RecvData(packet10) // OOO - buffered + reader.RecvData(packet5First) // OOO - buffered + reader.RecvData(packet5First) // Duplicate - should be ignored + reader.RecvData(packet5Second) // Duplicate with different data - should be ignored + + // Read all data - should get all 3 packets in order + buf := make([]byte, 20) + n, err := reader.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + // Should get all 15 bytes (3 packets * 5 bytes) + if n != 15 { + t.Fatalf("Expected to read 15 bytes, got %d", n) + } + + // Should be "aaaaaxxxxxccccc" (first packet received for each seq wins) + expected := "aaaaaxxxxxccccc" + if string(buf[:n]) != expected { + t.Fatalf("Expected %q, got %q", expected, string(buf[:n])) + } +} + +func TestOutOfOrderWithGaps(t *testing.T) { + transport := newFakeTransport() + reader := NewReader("test-gaps", 1024, transport) + + packet0 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 0, + Data64: base64.StdEncoding.EncodeToString([]byte("aaaaa")), + } + packet20 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 20, + Data64: 
base64.StdEncoding.EncodeToString([]byte("eeeee")), + } + packet40 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 40, + Data64: base64.StdEncoding.EncodeToString([]byte("iiiii")), + } + packet5 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 5, + Data64: base64.StdEncoding.EncodeToString([]byte("bbbbb")), + } + + reader.RecvData(packet0) + reader.RecvData(packet40) // Way ahead - should be buffered + reader.RecvData(packet20) // Still ahead - should be buffered + + // Read first packet + buf := make([]byte, 10) + n, err := reader.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != 5 || string(buf[:n]) != "aaaaa" { + t.Fatalf("Expected 'aaaaa', got %q", string(buf[:n])) + } + + // Send packet to partially fill gap + reader.RecvData(packet5) + + // Should be able to read it now + n, err = reader.Read(buf) + if err != nil { + t.Fatalf("Second read failed: %v", err) + } + if n != 5 || string(buf[:n]) != "bbbbb" { + t.Fatalf("Expected 'bbbbb', got %q", string(buf[:n])) + } + + packet10 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 10, + Data64: base64.StdEncoding.EncodeToString([]byte("ccccc")), + } + packet15 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 15, + Data64: base64.StdEncoding.EncodeToString([]byte("ddddd")), + } + packet25 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 25, + Data64: base64.StdEncoding.EncodeToString([]byte("fffff")), + } + packet30 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 30, + Data64: base64.StdEncoding.EncodeToString([]byte("ggggg")), + } + packet35 := wshrpc.CommandStreamData{ + Id: "test-gaps", + Seq: 35, + Data64: base64.StdEncoding.EncodeToString([]byte("hhhhh")), + } + + reader.RecvData(packet10) + reader.RecvData(packet15) + reader.RecvData(packet25) + reader.RecvData(packet30) + reader.RecvData(packet35) + + // Read all remaining data at once + allData := make([]byte, 100) + totalRead := 0 + for totalRead < 35 { + n, err = 
reader.Read(allData[totalRead:]) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + totalRead += n + } + + expected := "cccccdddddeeeeefffffggggghhhhhiiiii" + if string(allData[:totalRead]) != expected { + t.Fatalf("Expected %q, got %q", expected, string(allData[:totalRead])) + } +} + +func TestOutOfOrderWithEOF(t *testing.T) { + transport := newFakeTransport() + reader := NewReader("test-eof", 1024, transport) + + packet0 := wshrpc.CommandStreamData{ + Id: "test-eof", + Seq: 0, + Data64: base64.StdEncoding.EncodeToString([]byte("first")), + } + packet11 := wshrpc.CommandStreamData{ + Id: "test-eof", + Seq: 11, + Data64: base64.StdEncoding.EncodeToString([]byte("third")), + Eof: true, + } + packet5 := wshrpc.CommandStreamData{ + Id: "test-eof", + Seq: 5, + Data64: base64.StdEncoding.EncodeToString([]byte("second")), + } + + reader.RecvData(packet0) + reader.RecvData(packet11) // OOO with EOF + reader.RecvData(packet5) // Fill the gap + + // Read all data + buf := make([]byte, 20) + n, err := reader.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + expected := "firstsecondthird" + if string(buf[:n]) != expected { + t.Fatalf("Expected %q, got %q", expected, string(buf[:n])) + } + + // Should get EOF now + _, err = reader.Read(buf) + if err != io.EOF { + t.Fatalf("Expected EOF, got %v", err) + } +} diff --git a/pkg/streamclient/streambroker.go b/pkg/streamclient/streambroker.go index 4d35c9d367..9f3ec173d9 100644 --- a/pkg/streamclient/streambroker.go +++ b/pkg/streamclient/streambroker.go @@ -5,10 +5,9 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/utilds" "github.com/wavetermdev/waveterm/pkg/wshrpc" - "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" - "github.com/wavetermdev/waveterm/pkg/wshutil" ) type workItem struct { @@ -17,36 +16,23 @@ type workItem struct { dataPk wshrpc.CommandStreamData } +type StreamWriter interface { + RecvAck(ackPk wshrpc.CommandStreamAckData) +} + type 
StreamRpcInterface interface { StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error } -type wshRpcAdapter struct { - rpc *wshutil.WshRpc -} - -func (a *wshRpcAdapter) StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error { - return wshclient.StreamDataAckCommand(a.rpc, data, opts) -} - -func (a *wshRpcAdapter) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error { - return wshclient.StreamDataCommand(a.rpc, data, opts) -} - -func AdaptWshRpc(rpc *wshutil.WshRpc) StreamRpcInterface { - return &wshRpcAdapter{rpc: rpc} -} - type Broker struct { lock sync.Mutex rpcClient StreamRpcInterface - streamIdCounter int64 - readers map[int64]*Reader - writers map[int64]*Writer - readerRoutes map[int64]string - writerRoutes map[int64]string - readerErrorSentTime map[int64]time.Time + readers map[string]*Reader + writers map[string]StreamWriter + readerRoutes map[string]string + writerRoutes map[string]string + readerErrorSentTime map[string]time.Time sendQueue *utilds.WorkQueue[workItem] recvQueue *utilds.WorkQueue[workItem] } @@ -54,12 +40,11 @@ type Broker struct { func NewBroker(rpcClient StreamRpcInterface) *Broker { b := &Broker{ rpcClient: rpcClient, - streamIdCounter: 0, - readers: make(map[int64]*Reader), - writers: make(map[int64]*Writer), - readerRoutes: make(map[int64]string), - writerRoutes: make(map[int64]string), - readerErrorSentTime: make(map[int64]time.Time), + readers: make(map[string]*Reader), + writers: make(map[string]StreamWriter), + readerRoutes: make(map[string]string), + writerRoutes: make(map[string]string), + readerErrorSentTime: make(map[string]time.Time), } b.sendQueue = utilds.NewWorkQueue(b.processSendWork) b.recvQueue = utilds.NewWorkQueue(b.processRecvWork) @@ -67,13 +52,16 @@ func NewBroker(rpcClient StreamRpcInterface) *Broker { } func (b *Broker) CreateStreamReader(readerRoute string, 
writerRoute string, rwnd int64) (*Reader, *wshrpc.StreamMeta) { + return b.CreateStreamReaderWithSeq(readerRoute, writerRoute, rwnd, 0) +} + +func (b *Broker) CreateStreamReaderWithSeq(readerRoute string, writerRoute string, rwnd int64, startSeq int64) (*Reader, *wshrpc.StreamMeta) { b.lock.Lock() defer b.lock.Unlock() - b.streamIdCounter++ - streamId := b.streamIdCounter + streamId := uuid.New().String() - reader := NewReader(streamId, rwnd, b) + reader := NewReaderWithSeq(streamId, rwnd, startSeq, b) b.readers[streamId] = reader b.readerRoutes[streamId] = readerRoute b.writerRoutes[streamId] = writerRoute @@ -88,19 +76,35 @@ func (b *Broker) CreateStreamReader(readerRoute string, writerRoute string, rwnd return reader, meta } -func (b *Broker) AttachStreamWriter(meta *wshrpc.StreamMeta) (*Writer, error) { +func (b *Broker) AttachStreamWriter(meta *wshrpc.StreamMeta, writer StreamWriter) error { b.lock.Lock() defer b.lock.Unlock() if _, exists := b.writers[meta.Id]; exists { - return nil, fmt.Errorf("writer already registered for stream id %d", meta.Id) + return fmt.Errorf("writer already registered for stream id %s", meta.Id) } - writer := NewWriter(meta.Id, meta.RWnd, b) b.writers[meta.Id] = writer b.readerRoutes[meta.Id] = meta.ReaderRouteId b.writerRoutes[meta.Id] = meta.WriterRouteId + return nil +} + +func (b *Broker) DetachStreamWriter(streamId string) { + b.lock.Lock() + defer b.lock.Unlock() + + delete(b.writers, streamId) + delete(b.writerRoutes, streamId) +} + +func (b *Broker) CreateStreamWriter(meta *wshrpc.StreamMeta) (*Writer, error) { + writer := NewWriter(meta.Id, meta.RWnd, b) + err := b.AttachStreamWriter(meta, writer) + if err != nil { + return nil, err + } return writer, nil } @@ -112,6 +116,9 @@ func (b *Broker) SendData(dataPk wshrpc.CommandStreamData) { b.sendQueue.Enqueue(workItem{workType: "senddata", dataPk: dataPk}) } +// RecvData and RecvAck are designed to be non-blocking and must remain so to prevent deadlock. 
+// They only enqueue work items to be processed asynchronously by the work queue's goroutine. +// These methods are called from the main RPC runServer loop, so blocking here would stall all RPC processing. func (b *Broker) RecvData(dataPk wshrpc.CommandStreamData) { b.recvQueue.Enqueue(workItem{workType: "recvdata", dataPk: dataPk}) } @@ -220,7 +227,7 @@ func (b *Broker) Close() { b.recvQueue.Wait() } -func (b *Broker) cleanupReader(streamId int64) { +func (b *Broker) cleanupReader(streamId string) { b.lock.Lock() defer b.lock.Unlock() @@ -229,7 +236,7 @@ func (b *Broker) cleanupReader(streamId int64) { delete(b.readerErrorSentTime, streamId) } -func (b *Broker) cleanupWriter(streamId int64) { +func (b *Broker) cleanupWriter(streamId string) { b.lock.Lock() defer b.lock.Unlock() diff --git a/pkg/streamclient/streambroker_test.go b/pkg/streamclient/streambroker_test.go index 42871caf80..146816ce79 100644 --- a/pkg/streamclient/streambroker_test.go +++ b/pkg/streamclient/streambroker_test.go @@ -68,9 +68,9 @@ func TestBrokerBasicReadWrite(t *testing.T) { broker1, broker2 := setupBrokerPair() reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } testData := []byte("Hello, World!") @@ -105,9 +105,9 @@ func TestBrokerEOF(t *testing.T) { broker1, broker2 := setupBrokerPair() reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } testData := []byte("Test data") @@ -134,9 +134,9 @@ func TestBrokerFlowControl(t *testing.T) { smallWindow := int64(10) reader, meta := broker1.CreateStreamReader("reader1", "writer1", 
smallWindow) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } largeData := make([]byte, 100) @@ -180,9 +180,9 @@ func TestBrokerError(t *testing.T) { broker1, broker2 := setupBrokerPair() reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } testErr := io.ErrUnexpectedEOF @@ -202,9 +202,9 @@ func TestBrokerCancel(t *testing.T) { broker1, broker2 := setupBrokerPair() reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } reader.Close() @@ -226,9 +226,9 @@ func TestBrokerMultipleWrites(t *testing.T) { broker1, broker2 := setupBrokerPair() reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } messages := []string{"First", "Second", "Third"} @@ -261,9 +261,9 @@ func TestBrokerCleanup(t *testing.T) { broker1, broker2 := setupBrokerPair() reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024) - writer, err := broker2.AttachStreamWriter(meta) + writer, err := broker2.CreateStreamWriter(meta) if err != nil { - t.Fatalf("AttachStreamWriter failed: %v", err) + t.Fatalf("CreateStreamWriter failed: %v", err) } testData := []byte("cleanup test") diff --git a/pkg/streamclient/streamreader.go 
b/pkg/streamclient/streamreader.go index e1fb7bc10a..541d5c866d 100644 --- a/pkg/streamclient/streamreader.go +++ b/pkg/streamclient/streamreader.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "fmt" "io" + "sort" "sync" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -16,7 +17,7 @@ type AckSender interface { type Reader struct { lock sync.Mutex cond *sync.Cond - id int64 + id string ackSender AckSender readWindow int64 nextSeq int64 @@ -25,14 +26,19 @@ type Reader struct { err error closed bool lastRwndSent int64 + oooPackets []wshrpc.CommandStreamData // out-of-order packets awaiting delivery } -func NewReader(id int64, readWindow int64, ackSender AckSender) *Reader { +func NewReader(id string, readWindow int64, ackSender AckSender) *Reader { + return NewReaderWithSeq(id, readWindow, 0, ackSender) +} + +func NewReaderWithSeq(id string, readWindow int64, startSeq int64, ackSender AckSender) *Reader { r := &Reader{ id: id, readWindow: readWindow, ackSender: ackSender, - nextSeq: 0, + nextSeq: startSeq, lastRwndSent: readWindow, } r.cond = sync.NewCond(&r.lock) @@ -43,7 +49,7 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) { r.lock.Lock() defer r.lock.Unlock() - if r.closed { + if r.closed || r.eof || r.err != nil { return } @@ -59,18 +65,25 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) { return } - if dataPk.Seq != r.nextSeq { - r.err = fmt.Errorf("stream sequence mismatch: expected %d, got %d", r.nextSeq, dataPk.Seq) - r.cond.Broadcast() - r.sendAckLocked(false, true, "sequence mismatch error") + if dataPk.Seq < r.nextSeq { + return + } + if dataPk.Seq > r.nextSeq { + r.addOOOPacketLocked(dataPk) return } + r.recvDataOrderedLocked(dataPk) + r.processOOOPacketsLocked() + r.cond.Broadcast() + r.sendAckLocked(r.eof, false, "") +} + +func (r *Reader) recvDataOrderedLocked(dataPk wshrpc.CommandStreamData) { if dataPk.Data64 != "" { data, err := base64.StdEncoding.DecodeString(dataPk.Data64) if err != nil { r.err = err - r.cond.Broadcast() 
r.sendAckLocked(false, true, "base64 decode error") return } @@ -80,13 +93,40 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) { if dataPk.Eof { r.eof = true - r.cond.Broadcast() - r.sendAckLocked(true, false, "") - return } +} - r.cond.Broadcast() - r.sendAckLocked(false, false, "") +func (r *Reader) addOOOPacketLocked(dataPk wshrpc.CommandStreamData) { + for _, pkt := range r.oooPackets { + if pkt.Seq == dataPk.Seq { + // this handles duplicates + return + } + } + r.oooPackets = append(r.oooPackets, dataPk) +} + +func (r *Reader) processOOOPacketsLocked() { + if len(r.oooPackets) == 0 { + return + } + sort.Slice(r.oooPackets, func(i, j int) bool { + return r.oooPackets[i].Seq < r.oooPackets[j].Seq + }) + consumed := 0 + for _, pkt := range r.oooPackets { + if r.eof || r.err != nil { + // we're done, so we can clear any pending ooo packets + r.oooPackets = nil + return + } + if pkt.Seq != r.nextSeq { + break + } + r.recvDataOrderedLocked(pkt) + consumed++ + } + r.oooPackets = r.oooPackets[consumed:] } func (r *Reader) sendAckLocked(fin bool, cancel bool, errStr string) { @@ -146,6 +186,12 @@ func (r *Reader) Read(p []byte) (int, error) { return n, nil } +func (r *Reader) UpdateNextSeq(newSeq int64) { + r.lock.Lock() + defer r.lock.Unlock() + r.nextSeq = newSeq +} + func (r *Reader) Close() error { r.lock.Lock() defer r.lock.Unlock() diff --git a/pkg/streamclient/streamwriter.go b/pkg/streamclient/streamwriter.go index aef0a3df4e..862f0c9cfb 100644 --- a/pkg/streamclient/streamwriter.go +++ b/pkg/streamclient/streamwriter.go @@ -16,7 +16,7 @@ type DataSender interface { type Writer struct { lock sync.Mutex cond *sync.Cond - id int64 + id string dataSender DataSender readWindow int64 nextSeq int64 @@ -31,7 +31,7 @@ type Writer struct { closed bool } -func NewWriter(id int64, readWindow int64, dataSender DataSender) *Writer { +func NewWriter(id string, readWindow int64, dataSender DataSender) *Writer { w := &Writer{ id: id, readWindow: readWindow, diff 
--git a/pkg/wavejwt/wavejwt.go b/pkg/wavejwt/wavejwt.go index 45a621a9a3..9e91003c58 100644 --- a/pkg/wavejwt/wavejwt.go +++ b/pkg/wavejwt/wavejwt.go @@ -26,11 +26,13 @@ var ( type WaveJwtClaims struct { jwt.RegisteredClaims - Sock string `json:"sock,omitempty"` - RouteId string `json:"routeid,omitempty"` - BlockId string `json:"blockid,omitempty"` - Conn string `json:"conn,omitempty"` - Router bool `json:"router,omitempty"` + MainServer bool `json:"mainserver,omitempty"` + Sock string `json:"sock,omitempty"` + RouteId string `json:"routeid,omitempty"` + BlockId string `json:"blockid,omitempty"` + JobId string `json:"jobid,omitempty"` + Conn string `json:"conn,omitempty"` + Router bool `json:"router,omitempty"` } type KeyPair struct { diff --git a/pkg/waveobj/wtype.go b/pkg/waveobj/wtype.go index 2f7e7e0a1f..8df86d3766 100644 --- a/pkg/waveobj/wtype.go +++ b/pkg/waveobj/wtype.go @@ -29,6 +29,7 @@ const ( OType_LayoutState = "layout" OType_Block = "block" OType_MainServer = "mainserver" + OType_Job = "job" OType_Temp = "temp" OType_Builder = "builder" // not persisted to DB ) @@ -41,6 +42,7 @@ var ValidOTypes = map[string]bool{ OType_LayoutState: true, OType_Block: true, OType_MainServer: true, + OType_Job: true, OType_Temp: true, OType_Builder: true, } @@ -134,6 +136,7 @@ type Client struct { TosAgreed int64 `json:"tosagreed,omitempty"` // unix milli HasOldHistory bool `json:"hasoldhistory,omitempty"` TempOID string `json:"tempoid,omitempty"` + InstallId string `json:"installid,omitempty"` } func (*Client) GetOType() string { @@ -288,6 +291,7 @@ type Block struct { Stickers []*StickerType `json:"stickers,omitempty"` Meta MetaMapType `json:"meta"` SubBlockIds []string `json:"subblockids,omitempty"` + JobId string `json:"jobid,omitempty"` // if set, the block will render this jobid's pty output } func (*Block) GetOType() string { @@ -306,6 +310,49 @@ func (*MainServer) GetOType() string { return OType_MainServer } +type Job struct { + OID string `json:"oid"` + 
Version int `json:"version"` + + // job metadata + Connection string `json:"connection"` + JobKind string `json:"jobkind"` // shell, task + Cmd string `json:"cmd"` + CmdArgs []string `json:"cmdargs,omitempty"` + CmdEnv map[string]string `json:"cmdenv,omitempty"` + JobAuthToken string `json:"jobauthtoken"` // job manger -> wave + AttachedBlockId string `json:"attachedblockid,omitempty"` + + // reconnect option (e.g. orphaned, so we need to kill on connect) + TerminateOnReconnect bool `json:"terminateonreconnect,omitempty"` + + // job manager state + JobManagerStatus string `json:"jobmanagerstatus"` // init, running, done + JobManagerDoneReason string `json:"jobmanagerdonereason,omitempty"` // startuperror, gone, terminated + JobManagerStartupError string `json:"jobmanagerstartuperror,omitempty"` + JobManagerPid int `json:"jobmanagerpid,omitempty"` + JobManagerStartTs int64 `json:"jobmanagerstartts,omitempty"` // exact process start time (milliseconds) + + // cmd/process runtime info + CmdPid int `json:"cmdpid,omitempty"` // command process id + CmdStartTs int64 `json:"cmdstartts,omitempty"` // exact command process start time (milliseconds from epoch) + CmdTermSize TermSize `json:"cmdtermsize"` + CmdExitTs int64 `json:"cmdexitts,omitempty"` // timestamp (milliseconds) -- use CmdExitTs > 0 to check if command has exited + CmdExitCode *int `json:"cmdexitcode,omitempty"` // nil when CmdExitSignal is set. 
success exit is when CmdExitCode is 0 + CmdExitSignal string `json:"cmdexitsignal,omitempty"` // empty string if CmdExitCode is set + CmdExitError string `json:"cmdexiterror,omitempty"` + + // output info + StreamDone bool `json:"streamdone,omitempty"` + StreamError string `json:"streamerror,omitempty"` + + Meta MetaMapType `json:"meta"` +} + +func (*Job) GetOType() string { + return OType_Job +} + func AllWaveObjTypes() []reflect.Type { return []reflect.Type{ reflect.TypeOf(&Client{}), @@ -315,6 +362,7 @@ func AllWaveObjTypes() []reflect.Type { reflect.TypeOf(&Block{}), reflect.TypeOf(&LayoutState{}), reflect.TypeOf(&MainServer{}), + reflect.TypeOf(&Job{}), } } diff --git a/pkg/wcore/block.go b/pkg/wcore/block.go index 3c6e2e197b..fc66232a32 100644 --- a/pkg/wcore/block.go +++ b/pkg/wcore/block.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/blockcontroller" "github.com/wavetermdev/waveterm/pkg/filestore" + "github.com/wavetermdev/waveterm/pkg/jobcontroller" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata" @@ -167,6 +168,19 @@ func DeleteBlock(ctx context.Context, blockId string, recursive bool) error { } } } + if block.JobId != "" { + go func() { + defer func() { + panichandler.PanicHandler("DetachJobFromBlock", recover()) + }() + detachCtx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + err := jobcontroller.DetachJobFromBlock(detachCtx, block.JobId, false) + if err != nil { + log.Printf("error detaching job from block %s: %v", blockId, err) + } + }() + } parentBlockCount, err := deleteBlockObj(ctx, blockId) if err != nil { return fmt.Errorf("error deleting block: %w", err) diff --git a/pkg/wcore/wcore.go b/pkg/wcore/wcore.go index d8603f5caf..0c3fb905eb 100644 --- a/pkg/wcore/wcore.go +++ b/pkg/wcore/wcore.go @@ -50,6 +50,14 @@ func EnsureInitialData() (bool, error) 
{ return firstLaunch, fmt.Errorf("error updating client: %w", err) } } + if client.InstallId == "" { + log.Println("client.InstallId is empty") + client.InstallId = uuid.NewString() + err = wstore.DBUpdate(ctx, client) + if err != nil { + return firstLaunch, fmt.Errorf("error updating client: %w", err) + } + } log.Printf("clientid: %s\n", client.OID) if len(client.WindowIds) == 1 { log.Println("client has one window") diff --git a/pkg/web/webcmd/webcmd.go b/pkg/web/webcmd/webcmd.go index bf732de0c6..b86934ce7a 100644 --- a/pkg/web/webcmd/webcmd.go +++ b/pkg/web/webcmd/webcmd.go @@ -9,14 +9,11 @@ import ( "github.com/wavetermdev/waveterm/pkg/tsgen/tsgenmeta" "github.com/wavetermdev/waveterm/pkg/util/utilfn" - "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshutil" ) const ( - WSCommand_SetBlockTermSize = "setblocktermsize" - WSCommand_BlockInput = "blockinput" - WSCommand_Rpc = "rpc" + WSCommand_Rpc = "rpc" ) type WSCommandType interface { @@ -28,8 +25,6 @@ func WSCommandTypeUnionMeta() tsgenmeta.TypeUnionMeta { BaseType: reflect.TypeOf((*WSCommandType)(nil)).Elem(), TypeFieldName: "wscommand", Types: []reflect.Type{ - reflect.TypeOf(SetBlockTermSizeWSCommand{}), - reflect.TypeOf(BlockInputWSCommand{}), reflect.TypeOf(WSRpcCommand{}), }, } @@ -44,46 +39,12 @@ func (cmd *WSRpcCommand) GetWSCommand() string { return cmd.WSCommand } -type SetBlockTermSizeWSCommand struct { - WSCommand string `json:"wscommand" tstype:"\"setblocktermsize\""` - BlockId string `json:"blockid"` - TermSize waveobj.TermSize `json:"termsize"` -} - -func (cmd *SetBlockTermSizeWSCommand) GetWSCommand() string { - return cmd.WSCommand -} - -type BlockInputWSCommand struct { - WSCommand string `json:"wscommand" tstype:"\"blockinput\""` - BlockId string `json:"blockid"` - InputData64 string `json:"inputdata64"` -} - -func (cmd *BlockInputWSCommand) GetWSCommand() string { - return cmd.WSCommand -} - func ParseWSCommandMap(cmdMap map[string]any) (WSCommandType, 
error) { cmdType, ok := cmdMap["wscommand"].(string) if !ok { return nil, fmt.Errorf("no wscommand field in command map") } switch cmdType { - case WSCommand_SetBlockTermSize: - var cmd SetBlockTermSizeWSCommand - err := utilfn.DoMapStructure(&cmd, cmdMap) - if err != nil { - return nil, fmt.Errorf("error decoding SetBlockTermSizeWSCommand: %w", err) - } - return &cmd, nil - case WSCommand_BlockInput: - var cmd BlockInputWSCommand - err := utilfn.DoMapStructure(&cmd, cmdMap) - if err != nil { - return nil, fmt.Errorf("error decoding BlockInputWSCommand: %w", err) - } - return &cmd, nil case WSCommand_Rpc: var cmd WSRpcCommand err := utilfn.DoMapStructure(&cmd, cmdMap) @@ -94,5 +55,4 @@ func ParseWSCommandMap(cmdMap map[string]any) (WSCommandType, error) { default: return nil, fmt.Errorf("unknown wscommand type %q", cmdType) } - } diff --git a/pkg/web/ws.go b/pkg/web/ws.go index 0e6f0b0f9b..719753ba3c 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -20,7 +20,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/eventbus" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/web/webcmd" - "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) @@ -110,40 +109,6 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan ba } cmdType = wsCommand.GetWSCommand() switch cmd := wsCommand.(type) { - case *webcmd.SetBlockTermSizeWSCommand: - data := wshrpc.CommandBlockInputData{ - BlockId: cmd.BlockId, - TermSize: &cmd.TermSize, - } - rpcMsg := wshutil.RpcMessage{ - Command: wshrpc.Command_ControllerInput, - Data: data, - } - msgBytes, err := json.Marshal(rpcMsg) - if err != nil { - // this really should never fail since we just unmarshalled this value - log.Printf("[websocket] error marshalling rpc message: %v\n", err) - return - } - rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes} - - case *webcmd.BlockInputWSCommand: - data := wshrpc.CommandBlockInputData{ - BlockId: 
cmd.BlockId, - InputData64: cmd.InputData64, - } - rpcMsg := wshutil.RpcMessage{ - Command: wshrpc.Command_ControllerInput, - Data: data, - } - msgBytes, err := json.Marshal(rpcMsg) - if err != nil { - // this really should never fail since we just unmarshalled this value - log.Printf("[websocket] error marshalling rpc message: %v\n", err) - return - } - rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes} - case *webcmd.WSRpcCommand: rpcMsg := cmd.Message if rpcMsg == nil { diff --git a/pkg/wps/wpstypes.go b/pkg/wps/wpstypes.go index 4f16295e64..4ec65070b5 100644 --- a/pkg/wps/wpstypes.go +++ b/pkg/wps/wpstypes.go @@ -16,7 +16,8 @@ const ( Event_BlockFile = "blockfile" Event_Config = "config" Event_UserInput = "userinput" - Event_RouteGone = "route:gone" + Event_RouteDown = "route:down" + Event_RouteUp = "route:up" Event_WorkspaceUpdate = "workspace:update" Event_WaveAIRateLimit = "waveai:ratelimit" Event_WaveAppAppGoUpdated = "waveapp:appgoupdated" diff --git a/pkg/wshrpc/wshclient/barerpcclient.go b/pkg/wshrpc/wshclient/barerpcclient.go index 62d1f27ea7..d430266372 100644 --- a/pkg/wshrpc/wshclient/barerpcclient.go +++ b/pkg/wshrpc/wshclient/barerpcclient.go @@ -4,8 +4,10 @@ package wshclient import ( + "fmt" "sync" + "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" @@ -17,21 +19,22 @@ func (*WshServer) WshServerImpl() {} var WshServerImpl = WshServer{} -const ( - DefaultOutputChSize = 32 - DefaultInputChSize = 32 -) - var waveSrvClient_Singleton *wshutil.WshRpc var waveSrvClient_Once = &sync.Once{} - -const BareClientRoute = "bare" +var waveSrvClient_RouteId string func GetBareRpcClient() *wshutil.WshRpc { waveSrvClient_Once.Do(func() { waveSrvClient_Singleton = wshutil.MakeWshRpc(wshrpc.RpcContext{}, &WshServerImpl, "bare-client") - wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, BareClientRoute) + waveSrvClient_RouteId = 
fmt.Sprintf("bare:%s", uuid.New().String()) + // we can safely ignore the error from RegisterTrustedLeaf since the route is valid + wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, waveSrvClient_RouteId) wps.Broker.SetClient(wshutil.DefaultRouter) }) return waveSrvClient_Singleton } + +func GetBareRpcClientRouteId() string { + GetBareRpcClient() + return waveSrvClient_RouteId +} diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index a8f6c46e0d..62cd66d90c 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -35,6 +35,24 @@ func AuthenticateCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) ( return resp, err } +// command "authenticatejobmanager", wshserver.AuthenticateJobManagerCommand +func AuthenticateJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateJobManagerData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "authenticatejobmanager", data, opts) + return err +} + +// command "authenticatejobmanagerverify", wshserver.AuthenticateJobManagerVerifyCommand +func AuthenticateJobManagerVerifyCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateJobManagerData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "authenticatejobmanagerverify", data, opts) + return err +} + +// command "authenticatetojobmanager", wshserver.AuthenticateToJobManagerCommand +func AuthenticateToJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateToJobData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "authenticatetojobmanager", data, opts) + return err +} + // command "authenticatetoken", wshserver.AuthenticateTokenCommand func AuthenticateTokenCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateTokenData, opts *wshrpc.RpcOpts) (wshrpc.CommandAuthenticateRtnData, error) { resp, err := sendRpcRequestCallHelper[wshrpc.CommandAuthenticateRtnData](w, "authenticatetoken", 
data, opts) @@ -458,6 +476,90 @@ func GetWaveAIRateLimitCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (*uctype return resp, err } +// command "jobcmdexited", wshserver.JobCmdExitedCommand +func JobCmdExitedCommand(w *wshutil.WshRpc, data wshrpc.CommandJobCmdExitedData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcmdexited", data, opts) + return err +} + +// command "jobcontrollerattachjob", wshserver.JobControllerAttachJobCommand +func JobControllerAttachJobCommand(w *wshutil.WshRpc, data wshrpc.CommandJobControllerAttachJobData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerattachjob", data, opts) + return err +} + +// command "jobcontrollerconnectedjobs", wshserver.JobControllerConnectedJobsCommand +func JobControllerConnectedJobsCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) { + resp, err := sendRpcRequestCallHelper[[]string](w, "jobcontrollerconnectedjobs", nil, opts) + return resp, err +} + +// command "jobcontrollerdeletejob", wshserver.JobControllerDeleteJobCommand +func JobControllerDeleteJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdeletejob", data, opts) + return err +} + +// command "jobcontrollerdetachjob", wshserver.JobControllerDetachJobCommand +func JobControllerDetachJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdetachjob", data, opts) + return err +} + +// command "jobcontrollerdisconnectjob", wshserver.JobControllerDisconnectJobCommand +func JobControllerDisconnectJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdisconnectjob", data, opts) + return err +} + +// command "jobcontrollerexitjob", wshserver.JobControllerExitJobCommand +func JobControllerExitJobCommand(w *wshutil.WshRpc, data string, 
opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerexitjob", data, opts) + return err +} + +// command "jobcontrollerlist", wshserver.JobControllerListCommand +func JobControllerListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]*waveobj.Job, error) { + resp, err := sendRpcRequestCallHelper[[]*waveobj.Job](w, "jobcontrollerlist", nil, opts) + return resp, err +} + +// command "jobcontrollerreconnectjob", wshserver.JobControllerReconnectJobCommand +func JobControllerReconnectJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerreconnectjob", data, opts) + return err +} + +// command "jobcontrollerreconnectjobsforconn", wshserver.JobControllerReconnectJobsForConnCommand +func JobControllerReconnectJobsForConnCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobcontrollerreconnectjobsforconn", data, opts) + return err +} + +// command "jobcontrollerstartjob", wshserver.JobControllerStartJobCommand +func JobControllerStartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandJobControllerStartJobData, opts *wshrpc.RpcOpts) (string, error) { + resp, err := sendRpcRequestCallHelper[string](w, "jobcontrollerstartjob", data, opts) + return resp, err +} + +// command "jobinput", wshserver.JobInputCommand +func JobInputCommand(w *wshutil.WshRpc, data wshrpc.CommandJobInputData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobinput", data, opts) + return err +} + +// command "jobprepareconnect", wshserver.JobPrepareConnectCommand +func JobPrepareConnectCommand(w *wshutil.WshRpc, data wshrpc.CommandJobPrepareConnectData, opts *wshrpc.RpcOpts) (*wshrpc.CommandJobConnectRtnData, error) { + resp, err := sendRpcRequestCallHelper[*wshrpc.CommandJobConnectRtnData](w, "jobprepareconnect", data, opts) + return resp, err +} + +// command "jobstartstream", 
wshserver.JobStartStreamCommand +func JobStartStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandJobStartStreamData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "jobstartstream", data, opts) + return err +} + // command "listallappfiles", wshserver.ListAllAppFilesCommand func ListAllAppFilesCommand(w *wshutil.WshRpc, data wshrpc.CommandListAllAppFilesData, opts *wshrpc.RpcOpts) (*wshrpc.CommandListAllAppFilesRtnData, error) { resp, err := sendRpcRequestCallHelper[*wshrpc.CommandListAllAppFilesRtnData](w, "listallappfiles", data, opts) @@ -524,6 +626,12 @@ func RecordTEventCommand(w *wshutil.WshRpc, data telemetrydata.TEvent, opts *wsh return err } +// command "remotedisconnectfromjobmanager", wshserver.RemoteDisconnectFromJobManagerCommand +func RemoteDisconnectFromJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteDisconnectFromJobManagerData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "remotedisconnectfromjobmanager", data, opts) + return err +} + // command "remotefilecopy", wshserver.RemoteFileCopyCommand func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) (bool, error) { resp, err := sendRpcRequestCallHelper[bool](w, "remotefilecopy", data, opts) @@ -583,6 +691,18 @@ func RemoteMkdirCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) er return err } +// command "remotereconnecttojobmanager", wshserver.RemoteReconnectToJobManagerCommand +func RemoteReconnectToJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteReconnectToJobManagerData, opts *wshrpc.RpcOpts) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) { + resp, err := sendRpcRequestCallHelper[*wshrpc.CommandRemoteReconnectToJobManagerRtnData](w, "remotereconnecttojobmanager", data, opts) + return resp, err +} + +// command "remotestartjob", wshserver.RemoteStartJobCommand +func RemoteStartJobCommand(w *wshutil.WshRpc, data 
wshrpc.CommandRemoteStartJobData, opts *wshrpc.RpcOpts) (*wshrpc.CommandStartJobRtnData, error) { + resp, err := sendRpcRequestCallHelper[*wshrpc.CommandStartJobRtnData](w, "remotestartjob", data, opts) + return resp, err +} + // command "remotestreamcpudata", wshserver.RemoteStreamCpuDataCommand func RemoteStreamCpuDataCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.TimeSeriesData] { return sendRpcRequestResponseStreamHelper[wshrpc.TimeSeriesData](w, "remotestreamcpudata", nil, opts) @@ -598,6 +718,12 @@ func RemoteTarStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTa return sendRpcRequestResponseStreamHelper[iochantypes.Packet](w, "remotetarstream", data, opts) } +// command "remoteterminatejobmanager", wshserver.RemoteTerminateJobManagerCommand +func RemoteTerminateJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteTerminateJobManagerData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "remoteterminatejobmanager", data, opts) + return err +} + // command "remotewritefile", wshserver.RemoteWriteFileCommand func RemoteWriteFileCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "remotewritefile", data, opts) @@ -688,6 +814,12 @@ func StartBuilderCommand(w *wshutil.WshRpc, data wshrpc.CommandStartBuilderData, return err } +// command "startjob", wshserver.StartJobCommand +func StartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandStartJobData, opts *wshrpc.RpcOpts) (*wshrpc.CommandStartJobRtnData, error) { + resp, err := sendRpcRequestCallHelper[*wshrpc.CommandStartJobRtnData](w, "startjob", data, opts) + return resp, err +} + // command "stopbuilder", wshserver.StopBuilderCommand func StopBuilderCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "stopbuilder", data, opts) @@ -727,6 +859,12 @@ func TermGetScrollbackLinesCommand(w 
*wshutil.WshRpc, data wshrpc.CommandTermGet return resp, err } +// command "termupdateattachedjob", wshserver.TermUpdateAttachedJobCommand +func TermUpdateAttachedJobCommand(w *wshutil.WshRpc, data wshrpc.CommandTermUpdateAttachedJobData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "termupdateattachedjob", data, opts) + return err +} + // command "test", wshserver.TestCommand func TestCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "test", data, opts) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index ec90e367c7..f0cfbb145c 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -4,35 +4,44 @@ package wshremote import ( - "archive/tar" "context" - "encoding/base64" - "errors" "fmt" "io" - "io/fs" "log" - "os" + "net" "path/filepath" - "strings" - "time" + "sync" - "github.com/wavetermdev/waveterm/pkg/remote/connparse" - "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" - "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/suggestion" - "github.com/wavetermdev/waveterm/pkg/util/fileutil" - "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" - "github.com/wavetermdev/waveterm/pkg/util/tarcopy" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wshrpc" - "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshutil" ) +type JobManagerConnection struct { + JobId string + Conn net.Conn + WshRpc *wshutil.WshRpc + CleanupFn func() +} + type ServerImpl struct { - LogWriter io.Writer + LogWriter io.Writer + Router *wshutil.WshRouter + RpcClient *wshutil.WshRpc + IsLocal bool + JobManagerMap map[string]*JobManagerConnection + Lock sync.Mutex +} + +func MakeRemoteRpcServerImpl(logWriter io.Writer, router *wshutil.WshRouter, 
rpcClient *wshutil.WshRpc, isLocal bool) *ServerImpl { + return &ServerImpl{ + LogWriter: logWriter, + Router: router, + RpcClient: rpcClient, + IsLocal: isLocal, + JobManagerMap: make(map[string]*JobManagerConnection), + } } func (*ServerImpl) WshServerImpl() {} @@ -66,785 +75,6 @@ func (impl *ServerImpl) StreamTestCommand(ctx context.Context) chan wshrpc.RespO return ch } -type ByteRangeType struct { - All bool - Start int64 - End int64 -} - -func parseByteRange(rangeStr string) (ByteRangeType, error) { - if rangeStr == "" { - return ByteRangeType{All: true}, nil - } - var start, end int64 - _, err := fmt.Sscanf(rangeStr, "%d-%d", &start, &end) - if err != nil { - return ByteRangeType{}, errors.New("invalid byte range") - } - if start < 0 || end < 0 || start > end { - return ByteRangeType{}, errors.New("invalid byte range") - } - return ByteRangeType{Start: start, End: end}, nil -} - -func (impl *ServerImpl) remoteStreamFileDir(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error { - innerFilesEntries, err := os.ReadDir(path) - if err != nil { - return fmt.Errorf("cannot open dir %q: %w", path, err) - } - if byteRange.All { - if len(innerFilesEntries) > wshrpc.MaxDirSize { - innerFilesEntries = innerFilesEntries[:wshrpc.MaxDirSize] - } - } else { - if byteRange.Start < int64(len(innerFilesEntries)) { - realEnd := byteRange.End - if realEnd > int64(len(innerFilesEntries)) { - realEnd = int64(len(innerFilesEntries)) - } - innerFilesEntries = innerFilesEntries[byteRange.Start:realEnd] - } else { - innerFilesEntries = []os.DirEntry{} - } - } - var fileInfoArr []*wshrpc.FileInfo - for _, innerFileEntry := range innerFilesEntries { - if ctx.Err() != nil { - return ctx.Err() - } - innerFileInfoInt, err := innerFileEntry.Info() - if err != nil { - continue - } - innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false) - 
fileInfoArr = append(fileInfoArr, innerFileInfo) - if len(fileInfoArr) >= wshrpc.DirChunkSize { - dataCallback(fileInfoArr, nil, byteRange) - fileInfoArr = nil - } - } - if len(fileInfoArr) > 0 { - dataCallback(fileInfoArr, nil, byteRange) - } - return nil -} - -func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error { - fd, err := os.Open(path) - if err != nil { - return fmt.Errorf("cannot open file %q: %w", path, err) - } - defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path) - var filePos int64 - if !byteRange.All && byteRange.Start > 0 { - _, err := fd.Seek(byteRange.Start, io.SeekStart) - if err != nil { - return fmt.Errorf("seeking file %q: %w", path, err) - } - filePos = byteRange.Start - } - buf := make([]byte, wshrpc.FileChunkSize) - for { - if ctx.Err() != nil { - return ctx.Err() - } - n, err := fd.Read(buf) - if n > 0 { - if !byteRange.All && filePos+int64(n) > byteRange.End { - n = int(byteRange.End - filePos) - } - filePos += int64(n) - dataCallback(nil, buf[:n], byteRange) - } - if !byteRange.All && filePos >= byteRange.End { - break - } - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return fmt.Errorf("reading file %q: %w", path, err) - } - } - return nil -} - -func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrpc.CommandRemoteStreamFileData, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error { - byteRange, err := parseByteRange(data.ByteRange) - if err != nil { - return err - } - path, err := wavebase.ExpandHomeDir(data.Path) - if err != nil { - return err - } - finfo, err := impl.fileInfoInternal(path, true) - if err != nil { - return fmt.Errorf("cannot stat file %q: %w", path, err) - } - dataCallback([]*wshrpc.FileInfo{finfo}, nil, byteRange) - if finfo.NotFound { - return nil - } - if finfo.IsDir { - return 
impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback) - } else { - return impl.remoteStreamFileRegular(ctx, path, byteRange, dataCallback) - } -} - -func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc.CommandRemoteStreamFileData) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { - ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16) - go func() { - defer close(ch) - firstPk := true - err := impl.remoteStreamFileInternal(ctx, data, func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType) { - resp := wshrpc.FileData{} - fileInfoLen := len(fileInfo) - if fileInfoLen > 1 || !firstPk { - resp.Entries = fileInfo - } else if fileInfoLen == 1 { - resp.Info = fileInfo[0] - } - if firstPk { - firstPk = false - } - if len(data) > 0 { - resp.Data64 = base64.StdEncoding.EncodeToString(data) - resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)} - } - ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp} - }) - if err != nil { - ch <- wshutil.RespErr[wshrpc.FileData](err) - } - }() - return ch -} - -func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { - path := data.Path - opts := data.Opts - if opts == nil { - opts = &wshrpc.FileCopyOpts{} - } - log.Printf("RemoteTarStreamCommand: path=%s\n", path) - srcHasSlash := strings.HasSuffix(path, "/") - path, err := wavebase.ExpandHomeDir(path) - if err != nil { - return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err)) - } - cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) - finfo, err := os.Stat(cleanedPath) - if err != nil { - return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err)) - } - - var pathPrefix string - singleFile := !finfo.IsDir() - if !singleFile && srcHasSlash { - pathPrefix = cleanedPath - } else { - pathPrefix = 
filepath.Dir(cleanedPath) - } - - timeout := fstype.DefaultTimeout - if opts.Timeout > 0 { - timeout = time.Duration(opts.Timeout) * time.Millisecond - } - readerCtx, cancel := context.WithTimeout(ctx, timeout) - rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix) - - go func() { - defer func() { - tarClose() - cancel() - }() - walkFunc := func(path string, info fs.FileInfo, err error) error { - if readerCtx.Err() != nil { - return readerCtx.Err() - } - if err != nil { - return err - } - if err = writeHeader(info, path, singleFile); err != nil { - return err - } - // if not a dir, write file content - if !info.IsDir() { - data, err := os.Open(path) - if err != nil { - return err - } - defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path) - if _, err := io.Copy(fileWriter, data); err != nil { - return err - } - } - return nil - } - log.Printf("RemoteTarStreamCommand: starting\n") - err = nil - if singleFile { - err = walkFunc(cleanedPath, finfo, nil) - } else { - err = filepath.Walk(cleanedPath, walkFunc) - } - if err != nil { - rtn <- wshutil.RespErr[iochantypes.Packet](err) - } - log.Printf("RemoteTarStreamCommand: done\n") - }() - log.Printf("RemoteTarStreamCommand: returning channel\n") - return rtn -} - -func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) (bool, error) { - log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri) - opts := data.Opts - if opts == nil { - opts = &wshrpc.FileCopyOpts{} - } - destUri := data.DestUri - srcUri := data.SrcUri - merge := opts.Merge - overwrite := opts.Overwrite - if overwrite && merge { - return false, fmt.Errorf("cannot specify both overwrite and merge") - } - - destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) - if err != nil { - return false, fmt.Errorf("cannot parse destination URI %q: %w", destUri, err) - } - destPathCleaned := 
filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path)) - destinfo, err := os.Stat(destPathCleaned) - if err != nil { - if !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) - } - } - - destExists := destinfo != nil - destIsDir := destExists && destinfo.IsDir() - destHasSlash := strings.HasSuffix(destUri, "/") - - if destExists && !destIsDir { - if !overwrite { - return false, fmt.Errorf(fstype.OverwriteRequiredError, destPathCleaned) - } else { - err := os.Remove(destPathCleaned) - if err != nil { - return false, fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err) - } - } - } - srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) - if err != nil { - return false, fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) - } - - copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) { - nextinfo, err := os.Stat(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return 0, fmt.Errorf("cannot stat file %q: %w", path, err) - } - - if nextinfo != nil { - if nextinfo.IsDir() { - if !finfo.IsDir() { - // try to create file in directory - path = filepath.Join(path, filepath.Base(finfo.Name())) - newdestinfo, err := os.Stat(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return 0, fmt.Errorf("cannot stat file %q: %w", path, err) - } - if newdestinfo != nil && !overwrite { - return 0, fmt.Errorf(fstype.OverwriteRequiredError, path) - } - } else if overwrite { - err := os.RemoveAll(path) - if err != nil { - return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) - } - } else if !merge { - return 0, fmt.Errorf(fstype.MergeRequiredError, path) - } - } else { - if !overwrite { - return 0, fmt.Errorf(fstype.OverwriteRequiredError, path) - } else if finfo.IsDir() { - err := os.RemoveAll(path) - if err != nil { - return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) - } - } - } - } - - if finfo.IsDir() { - err := 
os.MkdirAll(path, finfo.Mode()) - if err != nil { - return 0, fmt.Errorf("cannot create directory %q: %w", path, err) - } - return 0, nil - } else { - err := os.MkdirAll(filepath.Dir(path), 0755) - if err != nil { - return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err) - } - } - - file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) - if err != nil { - return 0, fmt.Errorf("cannot create new file %q: %w", path, err) - } - defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path) - _, err = io.Copy(file, srcFile) - if err != nil { - return 0, fmt.Errorf("cannot write file %q: %w", path, err) - } - - return finfo.Size(), nil - } - - srcIsDir := false - if srcConn.Host == destConn.Host { - srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - - srcFileStat, err := os.Stat(srcPathCleaned) - if err != nil { - return false, fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) - } - - if srcFileStat.IsDir() { - srcIsDir = true - var srcPathPrefix string - if destIsDir { - srcPathPrefix = filepath.Dir(srcPathCleaned) - } else { - srcPathPrefix = srcPathCleaned - } - err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error { - if err != nil { - return err - } - srcFilePath := path - destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix)) - var file *os.File - if !info.IsDir() { - file, err = os.Open(srcFilePath) - if err != nil { - return fmt.Errorf("cannot open file %q: %w", srcFilePath, err) - } - defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath) - } - _, err = copyFileFunc(destFilePath, info, file) - return err - }) - if err != nil { - return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) - } - } else { - file, err := os.Open(srcPathCleaned) - if err != nil { - return false, fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) - } - defer utilfn.GracefulClose(file, 
"RemoteFileCopyCommand", srcPathCleaned) - var destFilePath string - if destHasSlash { - destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned)) - } else { - destFilePath = destPathCleaned - } - _, err = copyFileFunc(destFilePath, srcFileStat, file) - if err != nil { - return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) - } - } - } else { - timeout := fstype.DefaultTimeout - if opts.Timeout > 0 { - timeout = time.Duration(opts.Timeout) * time.Millisecond - } - readCtx, cancel := context.WithCancelCause(ctx) - readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri)) - defer timeoutCancel() - copyStart := time.Now() - ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout}) - numFiles := 0 - numSkipped := 0 - totalBytes := int64(0) - - err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error { - numFiles++ - nextpath := filepath.Join(destPathCleaned, next.Name) - srcIsDir = !singleFile - if singleFile && !destHasSlash { - // custom flag to indicate that the source is a single file, not a directory the contents of a directory - nextpath = destPathCleaned - } - finfo := next.FileInfo() - n, err := copyFileFunc(nextpath, finfo, reader) - if err != nil { - return fmt.Errorf("cannot copy file %q: %w", next.Name, err) - } - totalBytes += n - return nil - }) - if err != nil { - return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) - } - totalTime := time.Since(copyStart).Seconds() - totalMegaBytes := float64(totalBytes) / 1024 / 1024 - rate := float64(0) - if totalTime > 0 { - rate = totalMegaBytes / totalTime - } - log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, 
numSkipped) - } - return srcIsDir, nil -} - -func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrpc.CommandRemoteListEntriesData) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] { - ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16) - go func() { - defer close(ch) - path, err := wavebase.ExpandHomeDir(data.Path) - if err != nil { - ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err) - return - } - innerFilesEntries := []os.DirEntry{} - seen := 0 - if data.Opts.Limit == 0 { - data.Opts.Limit = wshrpc.MaxDirSize - } - if data.Opts.All { - fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error { - defer func() { - seen++ - }() - if seen < data.Opts.Offset { - return nil - } - if seen >= data.Opts.Offset+data.Opts.Limit { - return io.EOF - } - if err != nil { - return err - } - if d.IsDir() { - return nil - } - innerFilesEntries = append(innerFilesEntries, d) - return nil - }) - } else { - innerFilesEntries, err = os.ReadDir(path) - if err != nil { - ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot open dir %q: %w", path, err)) - return - } - } - var fileInfoArr []*wshrpc.FileInfo - for _, innerFileEntry := range innerFilesEntries { - if ctx.Err() != nil { - ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](ctx.Err()) - return - } - innerFileInfoInt, err := innerFileEntry.Info() - if err != nil { - log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err) - continue - } - innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false) - fileInfoArr = append(fileInfoArr, innerFileInfo) - if len(fileInfoArr) >= wshrpc.DirChunkSize { - resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr} - ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp} - fileInfoArr = nil - } - } - if len(fileInfoArr) > 0 { - resp := 
wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr} - ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp} - } - }() - return ch -} - -func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo { - mimeType := fileutil.DetectMimeType(fullPath, finfo, extended) - rtn := &wshrpc.FileInfo{ - Path: wavebase.ReplaceHomeDir(fullPath), - Dir: computeDirPart(fullPath), - Name: finfo.Name(), - Size: finfo.Size(), - Mode: finfo.Mode(), - ModeStr: finfo.Mode().String(), - ModTime: finfo.ModTime().UnixMilli(), - IsDir: finfo.IsDir(), - MimeType: mimeType, - SupportsMkdir: true, - } - if finfo.IsDir() { - rtn.Size = -1 - } - return rtn -} - -// fileInfo might be null -func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool { - if !exists || fileInfo.Mode().IsDir() { - dirName := filepath.Dir(path) - randHexStr, err := utilfn.RandomHexString(12) - if err != nil { - // we're not sure, just return false - return false - } - tmpFileName := filepath.Join(dirName, "wsh-tmp-"+randHexStr) - fd, err := os.Create(tmpFileName) - if err != nil { - return true - } - utilfn.GracefulClose(fd, "checkIsReadOnly", tmpFileName) - os.Remove(tmpFileName) - return false - } - // try to open for writing, if this fails then it is read-only - file, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - return true - } - utilfn.GracefulClose(file, "checkIsReadOnly", path) - return false -} - -func computeDirPart(path string) string { - path = filepath.Clean(wavebase.ExpandHomeDirSafe(path)) - path = filepath.ToSlash(path) - if path == "/" { - return "/" - } - return filepath.Dir(path) -} - -func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInfo, error) { - cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) - finfo, err := os.Stat(cleanedPath) - if os.IsNotExist(err) { - return &wshrpc.FileInfo{ - Path: wavebase.ReplaceHomeDir(path), - Dir: 
computeDirPart(path), - NotFound: true, - ReadOnly: checkIsReadOnly(cleanedPath, finfo, false), - SupportsMkdir: true, - }, nil - } - if err != nil { - return nil, fmt.Errorf("cannot stat file %q: %w", path, err) - } - rtn := statToFileInfo(cleanedPath, finfo, extended) - if extended { - rtn.ReadOnly = checkIsReadOnly(cleanedPath, finfo, true) - } - return rtn, nil -} - -func resolvePaths(paths []string) string { - if len(paths) == 0 { - return wavebase.ExpandHomeDirSafe("~") - } - rtnPath := wavebase.ExpandHomeDirSafe(paths[0]) - for _, path := range paths[1:] { - path = wavebase.ExpandHomeDirSafe(path) - if filepath.IsAbs(path) { - rtnPath = path - continue - } - rtnPath = filepath.Join(rtnPath, path) - } - return rtnPath -} - -func (impl *ServerImpl) RemoteFileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) { - rtnPath := resolvePaths(paths) - return impl.fileInfoInternal(rtnPath, true) -} - -func (impl *ServerImpl) RemoteFileInfoCommand(ctx context.Context, path string) (*wshrpc.FileInfo, error) { - return impl.fileInfoInternal(path, true) -} - -func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) error { - cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) - if _, err := os.Stat(cleanedPath); err == nil { - return fmt.Errorf("file %q already exists", path) - } - if err := os.MkdirAll(filepath.Dir(cleanedPath), 0755); err != nil { - return fmt.Errorf("cannot create directory %q: %w", filepath.Dir(cleanedPath), err) - } - if err := os.WriteFile(cleanedPath, []byte{}, 0644); err != nil { - return fmt.Errorf("cannot create file %q: %w", cleanedPath, err) - } - return nil -} - -func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { - opts := data.Opts - destUri := data.DestUri - srcUri := data.SrcUri - overwrite := opts != nil && opts.Overwrite - recursive := opts != nil && opts.Recursive - - destConn, err := 
connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) - if err != nil { - return fmt.Errorf("cannot parse destination URI %q: %w", srcUri, err) - } - destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path)) - destinfo, err := os.Stat(destPathCleaned) - if err == nil { - if !destinfo.IsDir() { - if !overwrite { - return fmt.Errorf("destination %q already exists, use overwrite option", destUri) - } else { - err := os.Remove(destPathCleaned) - if err != nil { - return fmt.Errorf("cannot remove file %q: %w", destUri, err) - } - } - } - } else if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("cannot stat destination %q: %w", destUri, err) - } - srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) - if err != nil { - return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) - } - if srcConn.Host == destConn.Host { - srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - finfo, err := os.Stat(srcPathCleaned) - if err != nil { - return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) - } - if finfo.IsDir() && !recursive { - return fmt.Errorf(fstype.RecursiveRequiredError) - } - err = os.Rename(srcPathCleaned, destPathCleaned) - if err != nil { - return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err) - } - } else { - return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri) - } - return nil -} - -func (impl *ServerImpl) RemoteMkdirCommand(ctx context.Context, path string) error { - cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) - if stat, err := os.Stat(cleanedPath); err == nil { - if stat.IsDir() { - return fmt.Errorf("directory %q already exists", path) - } else { - return fmt.Errorf("cannot create directory %q, file exists at path", path) - } - } - if err := os.MkdirAll(cleanedPath, 0755); err != nil { - return fmt.Errorf("cannot create directory %q: %w", cleanedPath, err) - } - return nil -} -func (*ServerImpl) 
RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileData) error { - var truncate, append bool - var atOffset int64 - if data.Info != nil && data.Info.Opts != nil { - truncate = data.Info.Opts.Truncate - append = data.Info.Opts.Append - } - if data.At != nil { - atOffset = data.At.Offset - } - if truncate && atOffset > 0 { - return fmt.Errorf("cannot specify non-zero offset with truncate option") - } - if append && atOffset > 0 { - return fmt.Errorf("cannot specify non-zero offset with append option") - } - path, err := wavebase.ExpandHomeDir(data.Info.Path) - if err != nil { - return err - } - createMode := os.FileMode(0644) - if data.Info != nil && data.Info.Mode > 0 { - createMode = data.Info.Mode - } - dataSize := base64.StdEncoding.DecodedLen(len(data.Data64)) - dataBytes := make([]byte, dataSize) - n, err := base64.StdEncoding.Decode(dataBytes, []byte(data.Data64)) - if err != nil { - return fmt.Errorf("cannot decode base64 data: %w", err) - } - finfo, err := os.Stat(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("cannot stat file %q: %w", path, err) - } - fileSize := int64(0) - if finfo != nil { - fileSize = finfo.Size() - } - if atOffset > fileSize { - return fmt.Errorf("cannot write at offset %d, file size is %d", atOffset, fileSize) - } - openFlags := os.O_CREATE | os.O_WRONLY - if truncate { - openFlags |= os.O_TRUNC - } - if append { - openFlags |= os.O_APPEND - } - - file, err := os.OpenFile(path, openFlags, createMode) - if err != nil { - return fmt.Errorf("cannot open file %q: %w", path, err) - } - defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path) - if atOffset > 0 && !append { - n, err = file.WriteAt(dataBytes[:n], atOffset) - } else { - n, err = file.Write(dataBytes[:n]) - } - if err != nil { - return fmt.Errorf("cannot write to file %q: %w", path, err) - } - return nil -} - -func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error { - 
expandedPath, err := wavebase.ExpandHomeDir(data.Path) - if err != nil { - return fmt.Errorf("cannot delete file %q: %w", data.Path, err) - } - cleanedPath := filepath.Clean(expandedPath) - - err = os.Remove(cleanedPath) - if err != nil { - finfo, _ := os.Stat(cleanedPath) - if finfo != nil && finfo.IsDir() { - if !data.Recursive { - return fmt.Errorf(fstype.RecursiveRequiredError) - } - err = os.RemoveAll(cleanedPath) - if err != nil { - return fmt.Errorf("cannot delete directory %q: %w", data.Path, err) - } - } else { - return fmt.Errorf("cannot delete file %q: %w", data.Path, err) - } - } - return nil -} - func (*ServerImpl) RemoteGetInfoCommand(ctx context.Context) (wshrpc.RemoteInfo, error) { return wshutil.GetInfo(), nil } @@ -861,3 +91,14 @@ func (*ServerImpl) DisposeSuggestionsCommand(ctx context.Context, widgetId strin suggestion.DisposeSuggestions(ctx, widgetId) return nil } + +func (impl *ServerImpl) getWshPath() (string, error) { + if impl.IsLocal { + return filepath.Join(wavebase.GetWaveDataDir(), "bin", "wsh"), nil + } + wshPath, err := wavebase.ExpandHomeDir("~/.waveterm/bin/wsh") + if err != nil { + return "", fmt.Errorf("cannot expand wsh path: %w", err) + } + return wshPath, nil +} diff --git a/pkg/wshrpc/wshremote/wshremote_file.go b/pkg/wshrpc/wshremote/wshremote_file.go new file mode 100644 index 0000000000..c83ae60cfa --- /dev/null +++ b/pkg/wshrpc/wshremote/wshremote_file.go @@ -0,0 +1,810 @@ +// Copyright 2026, Command Line Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package wshremote + +import ( + "archive/tar" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "io/fs" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" + "github.com/wavetermdev/waveterm/pkg/util/fileutil" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/tarcopy" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" + "github.com/wavetermdev/waveterm/pkg/wshutil" +) + +type ByteRangeType struct { + All bool + Start int64 + End int64 +} + +func parseByteRange(rangeStr string) (ByteRangeType, error) { + if rangeStr == "" { + return ByteRangeType{All: true}, nil + } + var start, end int64 + _, err := fmt.Sscanf(rangeStr, "%d-%d", &start, &end) + if err != nil { + return ByteRangeType{}, errors.New("invalid byte range") + } + if start < 0 || end < 0 || start > end { + return ByteRangeType{}, errors.New("invalid byte range") + } + return ByteRangeType{Start: start, End: end}, nil +} + +func (impl *ServerImpl) remoteStreamFileDir(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error { + innerFilesEntries, err := os.ReadDir(path) + if err != nil { + return fmt.Errorf("cannot open dir %q: %w", path, err) + } + if byteRange.All { + if len(innerFilesEntries) > wshrpc.MaxDirSize { + innerFilesEntries = innerFilesEntries[:wshrpc.MaxDirSize] + } + } else { + if byteRange.Start < int64(len(innerFilesEntries)) { + realEnd := byteRange.End + if realEnd > int64(len(innerFilesEntries)) { + realEnd = int64(len(innerFilesEntries)) + } + 
innerFilesEntries = innerFilesEntries[byteRange.Start:realEnd] + } else { + innerFilesEntries = []os.DirEntry{} + } + } + var fileInfoArr []*wshrpc.FileInfo + for _, innerFileEntry := range innerFilesEntries { + if ctx.Err() != nil { + return ctx.Err() + } + innerFileInfoInt, err := innerFileEntry.Info() + if err != nil { + continue + } + innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false) + fileInfoArr = append(fileInfoArr, innerFileInfo) + if len(fileInfoArr) >= wshrpc.DirChunkSize { + dataCallback(fileInfoArr, nil, byteRange) + fileInfoArr = nil + } + } + if len(fileInfoArr) > 0 { + dataCallback(fileInfoArr, nil, byteRange) + } + return nil +} + +func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error { + fd, err := os.Open(path) + if err != nil { + return fmt.Errorf("cannot open file %q: %w", path, err) + } + defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path) + var filePos int64 + if !byteRange.All && byteRange.Start > 0 { + _, err := fd.Seek(byteRange.Start, io.SeekStart) + if err != nil { + return fmt.Errorf("seeking file %q: %w", path, err) + } + filePos = byteRange.Start + } + buf := make([]byte, wshrpc.FileChunkSize) + for { + if ctx.Err() != nil { + return ctx.Err() + } + n, err := fd.Read(buf) + if n > 0 { + if !byteRange.All && filePos+int64(n) > byteRange.End { + n = int(byteRange.End - filePos) + } + filePos += int64(n) + dataCallback(nil, buf[:n], byteRange) + } + if !byteRange.All && filePos >= byteRange.End { + break + } + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("reading file %q: %w", path, err) + } + } + return nil +} + +func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrpc.CommandRemoteStreamFileData, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange 
ByteRangeType)) error { + byteRange, err := parseByteRange(data.ByteRange) + if err != nil { + return err + } + path, err := wavebase.ExpandHomeDir(data.Path) + if err != nil { + return err + } + finfo, err := impl.fileInfoInternal(path, true) + if err != nil { + return fmt.Errorf("cannot stat file %q: %w", path, err) + } + dataCallback([]*wshrpc.FileInfo{finfo}, nil, byteRange) + if finfo.NotFound { + return nil + } + if finfo.IsDir { + return impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback) + } else { + return impl.remoteStreamFileRegular(ctx, path, byteRange, dataCallback) + } +} + +func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc.CommandRemoteStreamFileData) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { + ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16) + go func() { + defer close(ch) + firstPk := true + err := impl.remoteStreamFileInternal(ctx, data, func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType) { + resp := wshrpc.FileData{} + fileInfoLen := len(fileInfo) + if fileInfoLen > 1 || !firstPk { + resp.Entries = fileInfo + } else if fileInfoLen == 1 { + resp.Info = fileInfo[0] + } + if firstPk { + firstPk = false + } + if len(data) > 0 { + resp.Data64 = base64.StdEncoding.EncodeToString(data) + resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)} + } + ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp} + }) + if err != nil { + ch <- wshutil.RespErr[wshrpc.FileData](err) + } + }() + return ch +} + +func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { + path := data.Path + opts := data.Opts + if opts == nil { + opts = &wshrpc.FileCopyOpts{} + } + log.Printf("RemoteTarStreamCommand: path=%s\n", path) + srcHasSlash := strings.HasSuffix(path, "/") + path, err := wavebase.ExpandHomeDir(path) + if err != nil { + return 
wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err)) + } + cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) + finfo, err := os.Stat(cleanedPath) + if err != nil { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err)) + } + + var pathPrefix string + singleFile := !finfo.IsDir() + if !singleFile && srcHasSlash { + pathPrefix = cleanedPath + } else { + pathPrefix = filepath.Dir(cleanedPath) + } + + timeout := fstype.DefaultTimeout + if opts.Timeout > 0 { + timeout = time.Duration(opts.Timeout) * time.Millisecond + } + readerCtx, cancel := context.WithTimeout(ctx, timeout) + rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix) + + go func() { + defer func() { + tarClose() + cancel() + }() + walkFunc := func(path string, info fs.FileInfo, err error) error { + if readerCtx.Err() != nil { + return readerCtx.Err() + } + if err != nil { + return err + } + if err = writeHeader(info, path, singleFile); err != nil { + return err + } + // if not a dir, write file content + if !info.IsDir() { + data, err := os.Open(path) + if err != nil { + return err + } + defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path) + if _, err := io.Copy(fileWriter, data); err != nil { + return err + } + } + return nil + } + log.Printf("RemoteTarStreamCommand: starting\n") + err = nil + if singleFile { + err = walkFunc(cleanedPath, finfo, nil) + } else { + err = filepath.Walk(cleanedPath, walkFunc) + } + if err != nil { + rtn <- wshutil.RespErr[iochantypes.Packet](err) + } + log.Printf("RemoteTarStreamCommand: done\n") + }() + log.Printf("RemoteTarStreamCommand: returning channel\n") + return rtn +} + +func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) (bool, error) { + log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri) + opts := data.Opts + if opts == nil { + opts = 
&wshrpc.FileCopyOpts{} + } + destUri := data.DestUri + srcUri := data.SrcUri + merge := opts.Merge + overwrite := opts.Overwrite + if overwrite && merge { + return false, fmt.Errorf("cannot specify both overwrite and merge") + } + + destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) + if err != nil { + return false, fmt.Errorf("cannot parse destination URI %q: %w", destUri, err) + } + destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path)) + destinfo, err := os.Stat(destPathCleaned) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) + } + } + + destExists := destinfo != nil + destIsDir := destExists && destinfo.IsDir() + destHasSlash := strings.HasSuffix(destUri, "/") + + if destExists && !destIsDir { + if !overwrite { + return false, fmt.Errorf(fstype.OverwriteRequiredError, destPathCleaned) + } else { + err := os.Remove(destPathCleaned) + if err != nil { + return false, fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err) + } + } + } + srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) + if err != nil { + return false, fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) + } + + copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) { + nextinfo, err := os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return 0, fmt.Errorf("cannot stat file %q: %w", path, err) + } + + if nextinfo != nil { + if nextinfo.IsDir() { + if !finfo.IsDir() { + // try to create file in directory + path = filepath.Join(path, filepath.Base(finfo.Name())) + newdestinfo, err := os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return 0, fmt.Errorf("cannot stat file %q: %w", path, err) + } + if newdestinfo != nil && !overwrite { + return 0, fmt.Errorf(fstype.OverwriteRequiredError, path) + } + } else if overwrite { + err := os.RemoveAll(path) + if err != nil { + 
return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) + } + } else if !merge { + return 0, fmt.Errorf(fstype.MergeRequiredError, path) + } + } else { + if !overwrite { + return 0, fmt.Errorf(fstype.OverwriteRequiredError, path) + } else if finfo.IsDir() { + err := os.RemoveAll(path) + if err != nil { + return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) + } + } + } + } + + if finfo.IsDir() { + err := os.MkdirAll(path, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create directory %q: %w", path, err) + } + return 0, nil + } else { + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err) + } + } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create new file %q: %w", path, err) + } + defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path) + _, err = io.Copy(file, srcFile) + if err != nil { + return 0, fmt.Errorf("cannot write file %q: %w", path, err) + } + + return finfo.Size(), nil + } + + srcIsDir := false + if srcConn.Host == destConn.Host { + srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) + + srcFileStat, err := os.Stat(srcPathCleaned) + if err != nil { + return false, fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) + } + + if srcFileStat.IsDir() { + srcIsDir = true + var srcPathPrefix string + if destIsDir { + srcPathPrefix = filepath.Dir(srcPathCleaned) + } else { + srcPathPrefix = srcPathCleaned + } + err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + srcFilePath := path + destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix)) + var file *os.File + if !info.IsDir() { + file, err = os.Open(srcFilePath) + if err != nil { + return fmt.Errorf("cannot open file %q: %w", srcFilePath, err) + 
} + defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath) + } + _, err = copyFileFunc(destFilePath, info, file) + return err + }) + if err != nil { + return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } + } else { + file, err := os.Open(srcPathCleaned) + if err != nil { + return false, fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) + } + defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcPathCleaned) + var destFilePath string + if destHasSlash { + destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned)) + } else { + destFilePath = destPathCleaned + } + _, err = copyFileFunc(destFilePath, srcFileStat, file) + if err != nil { + return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } + } + } else { + timeout := fstype.DefaultTimeout + if opts.Timeout > 0 { + timeout = time.Duration(opts.Timeout) * time.Millisecond + } + readCtx, cancel := context.WithCancelCause(ctx) + readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri)) + defer timeoutCancel() + copyStart := time.Now() + ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout}) + numFiles := 0 + numSkipped := 0 + totalBytes := int64(0) + + err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error { + numFiles++ + nextpath := filepath.Join(destPathCleaned, next.Name) + srcIsDir = !singleFile + if singleFile && !destHasSlash { + // custom flag to indicate that the source is a single file, not a directory the contents of a directory + nextpath = destPathCleaned + } + finfo := next.FileInfo() + n, err := copyFileFunc(nextpath, finfo, reader) + if err != nil { + return fmt.Errorf("cannot copy file %q: %w", next.Name, err) + } + totalBytes += n + return nil + }) + if err != nil { + 
return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } + totalTime := time.Since(copyStart).Seconds() + totalMegaBytes := float64(totalBytes) / 1024 / 1024 + rate := float64(0) + if totalTime > 0 { + rate = totalMegaBytes / totalTime + } + log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, numSkipped) + } + return srcIsDir, nil +} + +func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrpc.CommandRemoteListEntriesData) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] { + ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16) + go func() { + defer close(ch) + path, err := wavebase.ExpandHomeDir(data.Path) + if err != nil { + ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err) + return + } + innerFilesEntries := []os.DirEntry{} + seen := 0 + if data.Opts.Limit == 0 { + data.Opts.Limit = wshrpc.MaxDirSize + } + if data.Opts.All { + fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error { + defer func() { + seen++ + }() + if seen < data.Opts.Offset { + return nil + } + if seen >= data.Opts.Offset+data.Opts.Limit { + return io.EOF + } + if err != nil { + return err + } + if d.IsDir() { + return nil + } + innerFilesEntries = append(innerFilesEntries, d) + return nil + }) + } else { + innerFilesEntries, err = os.ReadDir(path) + if err != nil { + ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot open dir %q: %w", path, err)) + return + } + } + var fileInfoArr []*wshrpc.FileInfo + for _, innerFileEntry := range innerFilesEntries { + if ctx.Err() != nil { + ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](ctx.Err()) + return + } + innerFileInfoInt, err := innerFileEntry.Info() + if err != nil { + log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err) + continue + } + 
innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false) + fileInfoArr = append(fileInfoArr, innerFileInfo) + if len(fileInfoArr) >= wshrpc.DirChunkSize { + resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr} + ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp} + fileInfoArr = nil + } + } + if len(fileInfoArr) > 0 { + resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr} + ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp} + } + }() + return ch +} + +func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo { + mimeType := fileutil.DetectMimeType(fullPath, finfo, extended) + rtn := &wshrpc.FileInfo{ + Path: wavebase.ReplaceHomeDir(fullPath), + Dir: computeDirPart(fullPath), + Name: finfo.Name(), + Size: finfo.Size(), + Mode: finfo.Mode(), + ModeStr: finfo.Mode().String(), + ModTime: finfo.ModTime().UnixMilli(), + IsDir: finfo.IsDir(), + MimeType: mimeType, + SupportsMkdir: true, + } + if finfo.IsDir() { + rtn.Size = -1 + } + return rtn +} + +// fileInfo might be null +func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool { + if !exists || fileInfo.Mode().IsDir() { + dirName := filepath.Dir(path) + randHexStr, err := utilfn.RandomHexString(12) + if err != nil { + // we're not sure, just return false + return false + } + tmpFileName := filepath.Join(dirName, "wsh-tmp-"+randHexStr) + fd, err := os.Create(tmpFileName) + if err != nil { + return true + } + utilfn.GracefulClose(fd, "checkIsReadOnly", tmpFileName) + os.Remove(tmpFileName) + return false + } + // try to open for writing, if this fails then it is read-only + file, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return true + } + utilfn.GracefulClose(file, "checkIsReadOnly", path) + return false +} + +func computeDirPart(path string) string { + path = 
filepath.Clean(wavebase.ExpandHomeDirSafe(path)) + path = filepath.ToSlash(path) + if path == "/" { + return "/" + } + return filepath.Dir(path) +} + +func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInfo, error) { + cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) + finfo, err := os.Stat(cleanedPath) + if os.IsNotExist(err) { + return &wshrpc.FileInfo{ + Path: wavebase.ReplaceHomeDir(path), + Dir: computeDirPart(path), + NotFound: true, + ReadOnly: checkIsReadOnly(cleanedPath, finfo, false), + SupportsMkdir: true, + }, nil + } + if err != nil { + return nil, fmt.Errorf("cannot stat file %q: %w", path, err) + } + rtn := statToFileInfo(cleanedPath, finfo, extended) + if extended { + rtn.ReadOnly = checkIsReadOnly(cleanedPath, finfo, true) + } + return rtn, nil +} + +func resolvePaths(paths []string) string { + if len(paths) == 0 { + return wavebase.ExpandHomeDirSafe("~") + } + rtnPath := wavebase.ExpandHomeDirSafe(paths[0]) + for _, path := range paths[1:] { + path = wavebase.ExpandHomeDirSafe(path) + if filepath.IsAbs(path) { + rtnPath = path + continue + } + rtnPath = filepath.Join(rtnPath, path) + } + return rtnPath +} + +func (impl *ServerImpl) RemoteFileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) { + rtnPath := resolvePaths(paths) + return impl.fileInfoInternal(rtnPath, true) +} + +func (impl *ServerImpl) RemoteFileInfoCommand(ctx context.Context, path string) (*wshrpc.FileInfo, error) { + return impl.fileInfoInternal(path, true) +} + +func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) error { + cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) + if _, err := os.Stat(cleanedPath); err == nil { + return fmt.Errorf("file %q already exists", path) + } + if err := os.MkdirAll(filepath.Dir(cleanedPath), 0755); err != nil { + return fmt.Errorf("cannot create directory %q: %w", filepath.Dir(cleanedPath), err) + } + if err := os.WriteFile(cleanedPath, 
[]byte{}, 0644); err != nil { + return fmt.Errorf("cannot create file %q: %w", cleanedPath, err) + } + return nil +} + +func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { + opts := data.Opts + destUri := data.DestUri + srcUri := data.SrcUri + overwrite := opts != nil && opts.Overwrite + recursive := opts != nil && opts.Recursive + + destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) + if err != nil { + return fmt.Errorf("cannot parse destination URI %q: %w", srcUri, err) + } + destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path)) + destinfo, err := os.Stat(destPathCleaned) + if err == nil { + if !destinfo.IsDir() { + if !overwrite { + return fmt.Errorf("destination %q already exists, use overwrite option", destUri) + } else { + err := os.Remove(destPathCleaned) + if err != nil { + return fmt.Errorf("cannot remove file %q: %w", destUri, err) + } + } + } + } else if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("cannot stat destination %q: %w", destUri, err) + } + srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) + if err != nil { + return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) + } + if srcConn.Host == destConn.Host { + srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) + finfo, err := os.Stat(srcPathCleaned) + if err != nil { + return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) + } + if finfo.IsDir() && !recursive { + return fmt.Errorf(fstype.RecursiveRequiredError) + } + err = os.Rename(srcPathCleaned, destPathCleaned) + if err != nil { + return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err) + } + } else { + return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri) + } + return nil +} + +func (impl *ServerImpl) RemoteMkdirCommand(ctx context.Context, path string) error { + cleanedPath := 
filepath.Clean(wavebase.ExpandHomeDirSafe(path)) + if stat, err := os.Stat(cleanedPath); err == nil { + if stat.IsDir() { + return fmt.Errorf("directory %q already exists", path) + } else { + return fmt.Errorf("cannot create directory %q, file exists at path", path) + } + } + if err := os.MkdirAll(cleanedPath, 0755); err != nil { + return fmt.Errorf("cannot create directory %q: %w", cleanedPath, err) + } + return nil +} +func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileData) error { + var truncate, append bool + var atOffset int64 + if data.Info != nil && data.Info.Opts != nil { + truncate = data.Info.Opts.Truncate + append = data.Info.Opts.Append + } + if data.At != nil { + atOffset = data.At.Offset + } + if truncate && atOffset > 0 { + return fmt.Errorf("cannot specify non-zero offset with truncate option") + } + if append && atOffset > 0 { + return fmt.Errorf("cannot specify non-zero offset with append option") + } + path, err := wavebase.ExpandHomeDir(data.Info.Path) + if err != nil { + return err + } + createMode := os.FileMode(0644) + if data.Info != nil && data.Info.Mode > 0 { + createMode = data.Info.Mode + } + dataSize := base64.StdEncoding.DecodedLen(len(data.Data64)) + dataBytes := make([]byte, dataSize) + n, err := base64.StdEncoding.Decode(dataBytes, []byte(data.Data64)) + if err != nil { + return fmt.Errorf("cannot decode base64 data: %w", err) + } + finfo, err := os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("cannot stat file %q: %w", path, err) + } + fileSize := int64(0) + if finfo != nil { + fileSize = finfo.Size() + } + if atOffset > fileSize { + return fmt.Errorf("cannot write at offset %d, file size is %d", atOffset, fileSize) + } + openFlags := os.O_CREATE | os.O_WRONLY + if truncate { + openFlags |= os.O_TRUNC + } + if append { + openFlags |= os.O_APPEND + } + + file, err := os.OpenFile(path, openFlags, createMode) + if err != nil { + return fmt.Errorf("cannot open file %q: 
%w", path, err) + } + defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path) + if atOffset > 0 && !append { + n, err = file.WriteAt(dataBytes[:n], atOffset) + } else { + n, err = file.Write(dataBytes[:n]) + } + if err != nil { + return fmt.Errorf("cannot write to file %q: %w", path, err) + } + return nil +} + +func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error { + expandedPath, err := wavebase.ExpandHomeDir(data.Path) + if err != nil { + return fmt.Errorf("cannot delete file %q: %w", data.Path, err) + } + cleanedPath := filepath.Clean(expandedPath) + + err = os.Remove(cleanedPath) + if err != nil { + finfo, _ := os.Stat(cleanedPath) + if finfo != nil && finfo.IsDir() { + if !data.Recursive { + return fmt.Errorf(fstype.RecursiveRequiredError) + } + err = os.RemoveAll(cleanedPath) + if err != nil { + return fmt.Errorf("cannot delete directory %q: %w", data.Path, err) + } + } else { + return fmt.Errorf("cannot delete file %q: %w", data.Path, err) + } + } + return nil +} diff --git a/pkg/wshrpc/wshremote/wshremote_job.go b/pkg/wshrpc/wshremote/wshremote_job.go new file mode 100644 index 0000000000..12545c9cc1 --- /dev/null +++ b/pkg/wshrpc/wshremote/wshremote_job.go @@ -0,0 +1,352 @@ +// Copyright 2025, Command Line Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package wshremote + +import ( + "bufio" + "context" + "fmt" + "log" + "net" + "os" + "os/exec" + "strings" + "sync" + "syscall" + "time" + + "github.com/shirou/gopsutil/v4/process" + "github.com/wavetermdev/waveterm/pkg/jobmanager" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" + "github.com/wavetermdev/waveterm/pkg/wshutil" +) + +func isProcessRunning(pid int, pidStartTs int64) (*process.Process, error) { + if pid <= 0 { + return nil, nil + } + proc, err := process.NewProcess(int32(pid)) + if err != nil { + return nil, nil + } + createTime, err := proc.CreateTime() + if err != nil { + return nil, err + } + if createTime != pidStartTs { + return nil, nil + } + return proc, nil +} + +// returns jobRouteId, cleanupFunc, error +func (impl *ServerImpl) connectToJobManager(ctx context.Context, jobId string, mainServerJwtToken string) (string, func(), error) { + socketPath := jobmanager.GetJobSocketPath(jobId) + log.Printf("connectToJobManager: connecting to socket: %s\n", socketPath) + conn, err := net.Dial("unix", socketPath) + if err != nil { + log.Printf("connectToJobManager: error connecting to socket: %v\n", err) + return "", nil, fmt.Errorf("cannot connect to job manager socket: %w", err) + } + log.Printf("connectToJobManager: connected to socket\n") + + proxy := wshutil.MakeRpcProxy("jobmanager") + linkId := impl.Router.RegisterUntrustedLink(proxy) + + var cleanupOnce sync.Once + cleanup := func() { + cleanupOnce.Do(func() { + conn.Close() + impl.Router.UnregisterLink(linkId) + impl.removeJobManagerConnection(jobId) + }) + } + + go func() { + writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn) + if writeErr != nil { + log.Printf("connectToJobManager: error writing to job manager socket: %v\n", writeErr) + } + }() + go func() { + defer func() { + close(proxy.FromRemoteCh) + cleanup() + }() + wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh) + }() + 
+ routeId := wshutil.MakeLinkRouteId(linkId) + authData := wshrpc.CommandAuthenticateToJobData{ + JobAccessToken: mainServerJwtToken, + } + err = wshclient.AuthenticateToJobManagerCommand(impl.RpcClient, authData, &wshrpc.RpcOpts{Route: routeId}) + if err != nil { + cleanup() + return "", nil, fmt.Errorf("authentication to job manager failed: %w", err) + } + + jobRouteId := wshutil.MakeJobRouteId(jobId) + waitCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + err = impl.Router.WaitForRegister(waitCtx, jobRouteId) + if err != nil { + cleanup() + return "", nil, fmt.Errorf("timeout waiting for job route to register: %w", err) + } + + jobConn := &JobManagerConnection{ + JobId: jobId, + Conn: conn, + CleanupFn: cleanup, + } + impl.addJobManagerConnection(jobConn) + + log.Printf("connectToJobManager: successfully connected and authenticated\n") + return jobRouteId, cleanup, nil +} + +func (impl *ServerImpl) addJobManagerConnection(conn *JobManagerConnection) { + impl.Lock.Lock() + defer impl.Lock.Unlock() + impl.JobManagerMap[conn.JobId] = conn + log.Printf("addJobManagerConnection: added job manager connection for jobid=%s\n", conn.JobId) +} + +func (impl *ServerImpl) removeJobManagerConnection(jobId string) { + impl.Lock.Lock() + defer impl.Lock.Unlock() + if _, exists := impl.JobManagerMap[jobId]; exists { + delete(impl.JobManagerMap, jobId) + log.Printf("removeJobManagerConnection: removed job manager connection for jobid=%s\n", jobId) + } +} + +func (impl *ServerImpl) getJobManagerConnection(jobId string) *JobManagerConnection { + impl.Lock.Lock() + defer impl.Lock.Unlock() + return impl.JobManagerMap[jobId] +} + +func (impl *ServerImpl) RemoteStartJobCommand(ctx context.Context, data wshrpc.CommandRemoteStartJobData) (*wshrpc.CommandStartJobRtnData, error) { + log.Printf("RemoteStartJobCommand: starting, jobid=%s, clientid=%s\n", data.JobId, data.ClientId) + if impl.Router == nil { + return nil, fmt.Errorf("cannot start remote job: no 
router available") + } + + wshPath, err := impl.getWshPath() + if err != nil { + return nil, err + } + log.Printf("RemoteStartJobCommand: wshPath=%s\n", wshPath) + + readyPipeRead, readyPipeWrite, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("cannot create ready pipe: %w", err) + } + defer readyPipeRead.Close() + defer readyPipeWrite.Close() + + cmd := exec.Command(wshPath, "jobmanager", "--jobid", data.JobId, "--clientid", data.ClientId) + if data.PublicKeyBase64 != "" { + cmd.Env = append(os.Environ(), "WAVETERM_PUBLICKEY="+data.PublicKeyBase64) + } + cmd.ExtraFiles = []*os.File{readyPipeWrite} + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("cannot create stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("cannot create stdout pipe: %w", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("cannot create stderr pipe: %w", err) + } + log.Printf("RemoteStartJobCommand: created pipes\n") + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("cannot start job manager: %w", err) + } + readyPipeWrite.Close() + log.Printf("RemoteStartJobCommand: job manager process started\n") + + jobAuthTokenLine := fmt.Sprintf("Wave-JobAccessToken:%s\n", data.JobAuthToken) + if _, err := stdin.Write([]byte(jobAuthTokenLine)); err != nil { + cmd.Process.Kill() + return nil, fmt.Errorf("cannot write job auth token: %w", err) + } + stdin.Close() + log.Printf("RemoteStartJobCommand: wrote auth token to stdin\n") + + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + line := scanner.Text() + log.Printf("RemoteStartJobCommand: stderr: %s\n", line) + } + if err := scanner.Err(); err != nil { + log.Printf("RemoteStartJobCommand: error reading stderr: %v\n", err) + } else { + log.Printf("RemoteStartJobCommand: stderr EOF\n") + } + }() + + go func() { + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + line := 
scanner.Text() + log.Printf("RemoteStartJobCommand: stdout: %s\n", line) + } + if err := scanner.Err(); err != nil { + log.Printf("RemoteStartJobCommand: error reading stdout: %v\n", err) + } else { + log.Printf("RemoteStartJobCommand: stdout EOF\n") + } + }() + + startCh := make(chan error, 1) + go func() { + scanner := bufio.NewScanner(readyPipeRead) + for scanner.Scan() { + line := scanner.Text() + log.Printf("RemoteStartJobCommand: ready pipe line: %s\n", line) + if strings.Contains(line, "Wave-JobManagerStart") { + startCh <- nil + return + } + } + if err := scanner.Err(); err != nil { + startCh <- fmt.Errorf("error reading ready pipe: %w", err) + } else { + log.Printf("RemoteStartJobCommand: ready pipe EOF\n") + startCh <- fmt.Errorf("job manager exited without start signal") + } + }() + + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + log.Printf("RemoteStartJobCommand: waiting for start signal\n") + select { + case err := <-startCh: + if err != nil { + cmd.Process.Kill() + log.Printf("RemoteStartJobCommand: error from start signal: %v\n", err) + return nil, err + } + log.Printf("RemoteStartJobCommand: received start signal\n") + case <-timeoutCtx.Done(): + cmd.Process.Kill() + log.Printf("RemoteStartJobCommand: timeout waiting for start signal\n") + return nil, fmt.Errorf("timeout waiting for job manager to start") + } + + go func() { + cmd.Wait() + }() + + jobRouteId, cleanup, err := impl.connectToJobManager(ctx, data.JobId, data.MainServerJwtToken) + if err != nil { + return nil, err + } + + startJobData := wshrpc.CommandStartJobData{ + Cmd: data.Cmd, + Args: data.Args, + Env: data.Env, + TermSize: data.TermSize, + StreamMeta: data.StreamMeta, + } + rtnData, err := wshclient.StartJobCommand(impl.RpcClient, startJobData, &wshrpc.RpcOpts{Route: jobRouteId}) + if err != nil { + cleanup() + return nil, fmt.Errorf("failed to start job: %w", err) + } + + return rtnData, nil +} + +func (impl *ServerImpl) 
RemoteReconnectToJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteReconnectToJobManagerData) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) { + log.Printf("RemoteReconnectToJobManagerCommand: reconnecting, jobid=%s\n", data.JobId) + if impl.Router == nil { + return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{ + Success: false, + Error: "cannot reconnect to job manager: no router available", + }, nil + } + + proc, err := isProcessRunning(data.JobManagerPid, data.JobManagerStartTs) + if err != nil { + return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{ + Success: false, + Error: fmt.Sprintf("error checking job manager process: %v", err), + }, nil + } + if proc == nil { + return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{ + Success: false, + JobManagerGone: true, + Error: fmt.Sprintf("job manager process (pid=%d) is not running", data.JobManagerPid), + }, nil + } + + existingConn := impl.getJobManagerConnection(data.JobId) + if existingConn != nil { + log.Printf("RemoteReconnectToJobManagerCommand: closing existing connection for jobid=%s\n", data.JobId) + if existingConn.CleanupFn != nil { + existingConn.CleanupFn() + } + } + + _, _, err = impl.connectToJobManager(ctx, data.JobId, data.MainServerJwtToken) + if err != nil { + return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{ + Success: false, + Error: err.Error(), + }, nil + } + + log.Printf("RemoteReconnectToJobManagerCommand: successfully reconnected to job manager\n") + return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{ + Success: true, + }, nil +} + +func (impl *ServerImpl) RemoteDisconnectFromJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteDisconnectFromJobManagerData) error { + log.Printf("RemoteDisconnectFromJobManagerCommand: disconnecting, jobid=%s\n", data.JobId) + conn := impl.getJobManagerConnection(data.JobId) + if conn == nil { + log.Printf("RemoteDisconnectFromJobManagerCommand: no connection found for jobid=%s\n", data.JobId) 
+ return nil + } + + if conn.CleanupFn != nil { + conn.CleanupFn() + log.Printf("RemoteDisconnectFromJobManagerCommand: cleanup completed for jobid=%s\n", data.JobId) + } + + return nil +} + +func (impl *ServerImpl) RemoteTerminateJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteTerminateJobManagerData) error { + log.Printf("RemoteTerminateJobManagerCommand: terminating job manager, jobid=%s, pid=%d\n", data.JobId, data.JobManagerPid) + proc, err := isProcessRunning(data.JobManagerPid, data.JobManagerStartTs) + if err != nil { + return fmt.Errorf("error checking job manager process: %w", err) + } + if proc == nil { + log.Printf("RemoteTerminateJobManagerCommand: job manager process not running, jobid=%s\n", data.JobId) + return nil + } + err = proc.SendSignal(syscall.SIGTERM) + if err != nil { + log.Printf("failed to send SIGTERM to job manager: %v", err) + } else { + log.Printf("RemoteTerminateJobManagerCommand: sent SIGTERM to job manager process, jobid=%s, pid=%d\n", data.JobId, data.JobManagerPid) + } + return nil +} diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index f36bd8fa8d..c0d8d1214b 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -22,10 +22,18 @@ type RespOrErrorUnion[T any] struct { Error error } +// Instructions for adding a new RPC call +// * methods must end with Command +// * methods must take context as their first parameter +// * methods may take up to one parameter, and may return either just an error, or one return value plus an error +// * after modifying WshRpcInterface, run `task generate` to regenerate bindings + type WshRpcInterface interface { AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error) AuthenticateTokenCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error) AuthenticateTokenVerifyCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error) // 
(special) validates token without binding, root router only + AuthenticateJobManagerCommand(ctx context.Context, data CommandAuthenticateJobManagerData) error + AuthenticateJobManagerVerifyCommand(ctx context.Context, data CommandAuthenticateJobManagerData) error // (special) validates job auth token without binding, root router only DisposeCommand(ctx context.Context, data CommandDisposeData) error RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router @@ -100,6 +108,10 @@ type WshRpcInterface interface { RemoteStreamCpuDataCommand(ctx context.Context) chan RespOrErrorUnion[TimeSeriesData] RemoteGetInfoCommand(ctx context.Context) (RemoteInfo, error) RemoteInstallRcFilesCommand(ctx context.Context) error + RemoteStartJobCommand(ctx context.Context, data CommandRemoteStartJobData) (*CommandStartJobRtnData, error) + RemoteReconnectToJobManagerCommand(ctx context.Context, data CommandRemoteReconnectToJobManagerData) (*CommandRemoteReconnectToJobManagerRtnData, error) + RemoteDisconnectFromJobManagerCommand(ctx context.Context, data CommandRemoteDisconnectFromJobManagerData) error + RemoteTerminateJobManagerCommand(ctx context.Context, data CommandRemoteTerminateJobManagerData) error // emain WebSelectorCommand(ctx context.Context, data CommandWebSelectorData) ([]string, error) @@ -140,6 +152,7 @@ type WshRpcInterface interface { // terminal TermGetScrollbackLinesCommand(ctx context.Context, data CommandTermGetScrollbackLinesData) (*CommandTermGetScrollbackLinesRtnData, error) + TermUpdateAttachedJobCommand(ctx context.Context, data CommandTermUpdateAttachedJobData) error // file WshRpcFileInterface @@ -154,6 +167,26 @@ type WshRpcInterface interface { // streams StreamDataCommand(ctx context.Context, data CommandStreamData) error StreamDataAckCommand(ctx context.Context, data CommandStreamAckData) error + + // jobs + 
AuthenticateToJobManagerCommand(ctx context.Context, data CommandAuthenticateToJobData) error + StartJobCommand(ctx context.Context, data CommandStartJobData) (*CommandStartJobRtnData, error) + JobPrepareConnectCommand(ctx context.Context, data CommandJobPrepareConnectData) (*CommandJobConnectRtnData, error) + JobStartStreamCommand(ctx context.Context, data CommandJobStartStreamData) error + JobInputCommand(ctx context.Context, data CommandJobInputData) error + JobCmdExitedCommand(ctx context.Context, data CommandJobCmdExitedData) error // this is sent FROM the job manager => main server + + // job controller + JobControllerDeleteJobCommand(ctx context.Context, jobId string) error + JobControllerListCommand(ctx context.Context) ([]*waveobj.Job, error) + JobControllerStartJobCommand(ctx context.Context, data CommandJobControllerStartJobData) (string, error) + JobControllerExitJobCommand(ctx context.Context, jobId string) error + JobControllerDisconnectJobCommand(ctx context.Context, jobId string) error + JobControllerReconnectJobCommand(ctx context.Context, jobId string) error + JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error + JobControllerConnectedJobsCommand(ctx context.Context) ([]string, error) + JobControllerAttachJobCommand(ctx context.Context, data CommandJobControllerAttachJobData) error + JobControllerDetachJobCommand(ctx context.Context, jobId string) error } // for frontend @@ -250,6 +283,13 @@ type CommandBlockInputData struct { TermSize *waveobj.TermSize `json:"termsize,omitempty"` } +type CommandJobInputData struct { + JobId string `json:"jobid"` + InputData64 string `json:"inputdata64,omitempty"` + SigName string `json:"signame,omitempty"` + TermSize *waveobj.TermSize `json:"termsize,omitempty"` +} + type CommandWaitForRouteData struct { RouteId string `json:"routeid"` WaitMs int `json:"waitms"` @@ -614,6 +654,11 @@ type CommandTermGetScrollbackLinesRtnData struct { LastUpdated int64 `json:"lastupdated"` } +type 
CommandTermUpdateAttachedJobData struct { + BlockId string `json:"blockid"` + JobId string `json:"jobid,omitempty"` +} + type CommandElectronEncryptData struct { PlainText string `json:"plaintext"` } @@ -633,7 +678,7 @@ type CommandElectronDecryptRtnData struct { } type CommandStreamData struct { - Id int64 `json:"id"` // streamid + Id string `json:"id"` // streamid Seq int64 `json:"seq"` // start offset (bytes) Data64 string `json:"data64,omitempty"` Eof bool `json:"eof,omitempty"` // can be set with data or without @@ -641,7 +686,7 @@ type CommandStreamData struct { } type CommandStreamAckData struct { - Id int64 `json:"id"` // streamid + Id string `json:"id"` // streamid Seq int64 `json:"seq"` // next expected byte RWnd int64 `json:"rwnd"` // receive window size Fin bool `json:"fin,omitempty"` // observed end-of-stream (eof or error) @@ -651,8 +696,108 @@ type CommandStreamAckData struct { } type StreamMeta struct { - Id int64 `json:"id"` // streamid + Id string `json:"id"` // streamid RWnd int64 `json:"rwnd"` // initial receive window size ReaderRouteId string `json:"readerrouteid"` WriterRouteId string `json:"writerrouteid"` } + +type CommandAuthenticateToJobData struct { + JobAccessToken string `json:"jobaccesstoken"` +} + +type CommandAuthenticateJobManagerData struct { + JobId string `json:"jobid"` + JobAuthToken string `json:"jobauthtoken"` +} + +type CommandStartJobData struct { + Cmd string `json:"cmd"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + TermSize waveobj.TermSize `json:"termsize"` + StreamMeta *StreamMeta `json:"streammeta,omitempty"` +} + +type CommandRemoteStartJobData struct { + Cmd string `json:"cmd"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + TermSize waveobj.TermSize `json:"termsize"` + StreamMeta *StreamMeta `json:"streammeta,omitempty"` + JobAuthToken string `json:"jobauthtoken"` + JobId string `json:"jobid"` + MainServerJwtToken string `json:"mainserverjwttoken"` + ClientId string 
`json:"clientid"` + PublicKeyBase64 string `json:"publickeybase64"` +} + +type CommandRemoteReconnectToJobManagerData struct { + JobId string `json:"jobid"` + JobAuthToken string `json:"jobauthtoken"` + MainServerJwtToken string `json:"mainserverjwttoken"` + JobManagerPid int `json:"jobmanagerpid"` + JobManagerStartTs int64 `json:"jobmanagerstartts"` +} + +type CommandRemoteReconnectToJobManagerRtnData struct { + Success bool `json:"success"` + JobManagerGone bool `json:"jobmanagergone"` + Error string `json:"error,omitempty"` +} + +type CommandRemoteDisconnectFromJobManagerData struct { + JobId string `json:"jobid"` +} + +type CommandRemoteTerminateJobManagerData struct { + JobId string `json:"jobid"` + JobManagerPid int `json:"jobmanagerpid"` + JobManagerStartTs int64 `json:"jobmanagerstartts"` +} + +type CommandStartJobRtnData struct { + CmdPid int `json:"cmdpid"` + CmdStartTs int64 `json:"cmdstartts"` + JobManagerPid int `json:"jobmanagerpid"` + JobManagerStartTs int64 `json:"jobmanagerstartts"` +} + +type CommandJobPrepareConnectData struct { + StreamMeta StreamMeta `json:"streammeta"` + Seq int64 `json:"seq"` +} + +type CommandJobStartStreamData struct { +} + +type CommandJobConnectRtnData struct { + Seq int64 `json:"seq"` + StreamDone bool `json:"streamdone,omitempty"` + StreamError string `json:"streamerror,omitempty"` + HasExited bool `json:"hasexited,omitempty"` + ExitCode *int `json:"exitcode,omitempty"` + ExitSignal string `json:"exitsignal,omitempty"` + ExitErr string `json:"exiterr,omitempty"` +} + +type CommandJobCmdExitedData struct { + JobId string `json:"jobid"` + ExitCode *int `json:"exitcode,omitempty"` + ExitSignal string `json:"exitsignal,omitempty"` + ExitErr string `json:"exiterr,omitempty"` + ExitTs int64 `json:"exitts,omitempty"` +} + +type CommandJobControllerStartJobData struct { + ConnName string `json:"connname"` + Cmd string `json:"cmd"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + TermSize *waveobj.TermSize 
`json:"termsize,omitempty"` +} + +type CommandJobControllerAttachJobData struct { + JobId string `json:"jobid"` + BlockId string `json:"blockid"` +} diff --git a/pkg/wshrpc/wshrpctypes_const.go b/pkg/wshrpc/wshrpctypes_const.go index 5133b40346..51a25f147c 100644 --- a/pkg/wshrpc/wshrpctypes_const.go +++ b/pkg/wshrpc/wshrpctypes_const.go @@ -35,13 +35,16 @@ const ( // we only need consts for special commands handled in the router or // in the RPC code / WPS code directly. other commands go through the clients const ( - Command_Authenticate = "authenticate" // $control - Command_AuthenticateToken = "authenticatetoken" // $control - Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only) - Command_RouteAnnounce = "routeannounce" // $control (for routing) - Command_RouteUnannounce = "routeunannounce" // $control (for routing) - Command_Ping = "ping" // $control - Command_ControllerInput = "controllerinput" - Command_EventRecv = "eventrecv" - Command_Message = "message" + Command_Authenticate = "authenticate" // $control + Command_AuthenticateToken = "authenticatetoken" // $control + Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only) + Command_AuthenticateJobManagerVerify = "authenticatejobmanagerverify" // $control:root (internal, for job auth token validation only) + Command_RouteAnnounce = "routeannounce" // $control (for routing) + Command_RouteUnannounce = "routeunannounce" // $control (for routing) + Command_Ping = "ping" // $control + Command_ControllerInput = "controllerinput" + Command_EventRecv = "eventrecv" + Command_Message = "message" + Command_StreamData = "streamdata" + Command_StreamDataAck = "streamdataack" ) diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 9e447dd5f3..6446b5ed25 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -29,6 +29,7 @@ import ( 
"github.com/wavetermdev/waveterm/pkg/filebackup" "github.com/wavetermdev/waveterm/pkg/filestore" "github.com/wavetermdev/waveterm/pkg/genconn" + "github.com/wavetermdev/waveterm/pkg/jobcontroller" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/remote/awsconn" @@ -294,6 +295,21 @@ func (ws *WshServer) ControllerResyncCommand(ctx context.Context, data wshrpc.Co } func (ws *WshServer) ControllerInputCommand(ctx context.Context, data wshrpc.CommandBlockInputData) error { + block, err := wstore.DBMustGet[*waveobj.Block](ctx, data.BlockId) + if err != nil { + return fmt.Errorf("error getting block: %w", err) + } + + if block.JobId != "" { + jobInputData := wshrpc.CommandJobInputData{ + JobId: block.JobId, + InputData64: data.InputData64, + SigName: data.SigName, + TermSize: data.TermSize, + } + return jobcontroller.SendInput(ctx, jobInputData) + } + inputUnion := &blockcontroller.BlockInputUnion{ SigName: data.SigName, TermSize: data.TermSize, @@ -1430,3 +1446,54 @@ func (ws *WshServer) GetSecretsLinuxStorageBackendCommand(ctx context.Context) ( } return backend, nil } + +func (ws *WshServer) JobCmdExitedCommand(ctx context.Context, data wshrpc.CommandJobCmdExitedData) error { + return jobcontroller.HandleCmdJobExited(ctx, data.JobId, data) +} + +func (ws *WshServer) JobControllerListCommand(ctx context.Context) ([]*waveobj.Job, error) { + return wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job) +} + +func (ws *WshServer) JobControllerDeleteJobCommand(ctx context.Context, jobId string) error { + return jobcontroller.DeleteJob(ctx, jobId) +} + +func (ws *WshServer) JobControllerStartJobCommand(ctx context.Context, data wshrpc.CommandJobControllerStartJobData) (string, error) { + params := jobcontroller.StartJobParams{ + ConnName: data.ConnName, + Cmd: data.Cmd, + Args: data.Args, + Env: data.Env, + TermSize: data.TermSize, + } + return jobcontroller.StartJob(ctx, params) +} 
+ +func (ws *WshServer) JobControllerExitJobCommand(ctx context.Context, jobId string) error { + return jobcontroller.TerminateJobManager(ctx, jobId) +} + +func (ws *WshServer) JobControllerDisconnectJobCommand(ctx context.Context, jobId string) error { + return jobcontroller.DisconnectJob(ctx, jobId) +} + +func (ws *WshServer) JobControllerReconnectJobCommand(ctx context.Context, jobId string) error { + return jobcontroller.ReconnectJob(ctx, jobId) +} + +func (ws *WshServer) JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error { + return jobcontroller.ReconnectJobsForConn(ctx, connName) +} + +func (ws *WshServer) JobControllerConnectedJobsCommand(ctx context.Context) ([]string, error) { + return jobcontroller.GetConnectedJobIds(), nil +} + +func (ws *WshServer) JobControllerAttachJobCommand(ctx context.Context, data wshrpc.CommandJobControllerAttachJobData) error { + return jobcontroller.AttachJobToBlock(ctx, data.JobId, data.BlockId) +} + +func (ws *WshServer) JobControllerDetachJobCommand(ctx context.Context, jobId string) error { + return jobcontroller.DetachJobFromBlock(ctx, jobId, true) +} diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index cbc2f47ab3..32c889756b 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "log" + "strconv" "strings" "sync" "time" @@ -34,6 +35,9 @@ const ( RoutePrefix_Tab = "tab:" RoutePrefix_FeBlock = "feblock:" RoutePrefix_Builder = "builder:" + RoutePrefix_Link = "link:" + RoutePrefix_Job = "job:" + RoutePrefix_Bare = "bare:" ) // this works like a network switch @@ -118,6 +122,14 @@ func MakeBuilderRouteId(builderId string) string { return "builder:" + builderId } +func MakeJobRouteId(jobId string) string { + return "job:" + jobId +} + +func MakeLinkRouteId(linkId baseds.LinkId) string { + return fmt.Sprintf("%s%d", RoutePrefix_Link, linkId) +} + var DefaultRouter *WshRouter func NewWshRouter() *WshRouter { @@ -245,6 +257,13 
@@ func (router *WshRouter) getRouteInfo(rpcId string) *rpcRoutingInfo { // returns true if message was sent, false if failed func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string, commandName string, ingressLinkId baseds.LinkId) bool { + if strings.HasPrefix(routeId, RoutePrefix_Link) { + linkIdStr := strings.TrimPrefix(routeId, RoutePrefix_Link) + linkIdInt, err := strconv.ParseInt(linkIdStr, 10, 32) + if err == nil { + return router.sendMessageToLink(msgBytes, baseds.LinkId(linkIdInt), ingressLinkId) + } + } lm := router.getLinkForRoute(routeId) if lm != nil { lm.client.SendRpcMessage(msgBytes, ingressLinkId, "route") @@ -448,8 +467,10 @@ func (router *WshRouter) runLinkClientRecvLoop(linkId baseds.LinkId, client Abst } else { // non-request messages (responses) if !lm.trusted { - // drop responses from untrusted links - continue + // allow responses to RPCs we initiated + if rpcMsg.ResId == "" || router.getRouteInfo(rpcMsg.ResId) == nil { + continue + } } } router.inputCh <- baseds.RpcInputChType{MsgBytes: msgBytes, IngressLinkId: linkId} @@ -596,7 +617,7 @@ func (router *WshRouter) UnregisterLink(linkId baseds.LinkId) { } func isBindableRouteId(routeId string) bool { - if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) { + if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) || strings.HasPrefix(routeId, RoutePrefix_Link) { return false } return true @@ -676,6 +697,9 @@ func (router *WshRouter) bindRoute(linkId baseds.LinkId, routeId string, isSourc if !strings.HasPrefix(routeId, ControlPrefix) { router.announceUpstream(routeId) } + if router.IsRootRouter() { + router.publishRouteToBroker(routeId) + } return nil } @@ -692,12 +716,19 @@ func (router *WshRouter) getUpstreamClient() AbstractRpcClient { return lm.client } +func (router *WshRouter) publishRouteToBroker(routeId string) { + defer func() { + panichandler.PanicHandler("WshRouter:publishRouteToBroker", recover()) + }() + wps.Broker.Publish(wps.WaveEvent{Event: 
wps.Event_RouteUp, Scopes: []string{routeId}}) +} + func (router *WshRouter) unsubscribeFromBroker(routeId string) { defer func() { - panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover()) + panichandler.PanicHandler("WshRouter:unregisterRoute:routedown", recover()) }() wps.Broker.UnsubscribeAll(routeId) - wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteGone, Scopes: []string{routeId}}) + wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteDown, Scopes: []string{routeId}}) } func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMeta linkMeta) { diff --git a/pkg/wshutil/wshrouter_controlimpl.go b/pkg/wshutil/wshrouter_controlimpl.go index f6f557eabc..0cc29ca2f9 100644 --- a/pkg/wshutil/wshrouter_controlimpl.go +++ b/pkg/wshutil/wshrouter_controlimpl.go @@ -11,7 +11,9 @@ import ( "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wstore" ) type WshRouterControlImpl struct { @@ -102,6 +104,46 @@ func (impl *WshRouterControlImpl) AuthenticateCommand(ctx context.Context, data return rtnData, nil } +func extractTokenData(token string) (wshrpc.CommandAuthenticateRtnData, error) { + entry := shellutil.GetAndRemoveTokenSwapEntry(token) + if entry == nil { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found") + } + _, err := validateRpcContextFromAuth(entry.RpcContext) + if err != nil { + return wshrpc.CommandAuthenticateRtnData{}, err + } + if entry.RpcContext.IsRouter { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token") + } + if entry.RpcContext.RouteId == "" { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid") + } + return wshrpc.CommandAuthenticateRtnData{ + Env: entry.Env, + InitScriptText: entry.ScriptText, + 
RpcContext: entry.RpcContext, + }, nil +} + +func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) { + if !impl.Router.IsRootRouter() { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("authenticatetokenverify can only be called on root router") + } + if data.Token == "" { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message") + } + + rtnData, err := extractTokenData(data.Token) + if err != nil { + log.Printf("wshrouter authenticate-token-verify error: %v", err) + return wshrpc.CommandAuthenticateRtnData{}, err + } + + log.Printf("wshrouter authenticate-token-verify success routeid=%q", rtnData.RpcContext.RouteId) + return rtnData, nil +} + func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) { handler := GetRpcResponseHandlerFromContext(ctx) if handler == nil { @@ -117,29 +159,14 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context, } var rtnData wshrpc.CommandAuthenticateRtnData - var rpcContext *wshrpc.RpcContext + var err error + if impl.Router.IsRootRouter() { - entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token) - if entry == nil { - log.Printf("wshrouter authenticate-token error linkid=%d: no token entry found", linkId) - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found") - } - _, err := validateRpcContextFromAuth(entry.RpcContext) + rtnData, err = extractTokenData(data.Token) if err != nil { + log.Printf("wshrouter authenticate-token error linkid=%d: %v", linkId, err) return wshrpc.CommandAuthenticateRtnData{}, err } - if entry.RpcContext.IsRouter { - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token") - } - if entry.RpcContext.RouteId == "" { - return 
wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid") - } - rpcContext = entry.RpcContext - rtnData = wshrpc.CommandAuthenticateRtnData{ - Env: entry.Env, - InitScriptText: entry.ScriptText, - RpcContext: rpcContext, - } } else { wshRpc := GetWshRpcFromContext(ctx) if wshRpc == nil { @@ -154,51 +181,91 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context, if err != nil { return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("failed to unmarshal response: %w", err) } - rpcContext = rtnData.RpcContext } - if rpcContext == nil { + if rtnData.RpcContext == nil { return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no rpccontext in token response") } - log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rpcContext.RouteId) + log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rtnData.RpcContext.RouteId) impl.Router.trustLink(linkId, LinkKind_Leaf) - impl.Router.bindRoute(linkId, rpcContext.RouteId, true) + impl.Router.bindRoute(linkId, rtnData.RpcContext.RouteId, true) return rtnData, nil } -func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) { +func (impl *WshRouterControlImpl) AuthenticateJobManagerVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateJobManagerData) error { if !impl.Router.IsRootRouter() { - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("authenticatetokenverify can only be called on root router") + return fmt.Errorf("authenticatejobmanagerverify can only be called on root router") } - if data.Token == "" { - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message") + if data.JobId == "" { + return fmt.Errorf("no jobid in authenticatejobmanager message") } - entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token) - if entry == nil { - log.Printf("wshrouter 
authenticate-token-verify error: no token entry found") - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found") + if data.JobAuthToken == "" { + return fmt.Errorf("no jobauthtoken in authenticatejobmanager message") } - _, err := validateRpcContextFromAuth(entry.RpcContext) + + job, err := wstore.DBMustGet[*waveobj.Job](ctx, data.JobId) if err != nil { - return wshrpc.CommandAuthenticateRtnData{}, err + log.Printf("wshrouter authenticate-jobmanager-verify error jobid=%q: failed to get job: %v", data.JobId, err) + return fmt.Errorf("failed to get job: %w", err) } - if entry.RpcContext.IsRouter { - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token") + + if job.JobAuthToken != data.JobAuthToken { + log.Printf("wshrouter authenticate-jobmanager-verify error jobid=%q: invalid jobauthtoken", data.JobId) + return fmt.Errorf("invalid jobauthtoken") } - if entry.RpcContext.RouteId == "" { - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid") + + log.Printf("wshrouter authenticate-jobmanager-verify success jobid=%q", data.JobId) + return nil +} + +func (impl *WshRouterControlImpl) AuthenticateJobManagerCommand(ctx context.Context, data wshrpc.CommandAuthenticateJobManagerData) error { + handler := GetRpcResponseHandlerFromContext(ctx) + if handler == nil { + return fmt.Errorf("no response handler in context") + } + linkId := handler.GetIngressLinkId() + if linkId == baseds.NoLinkId { + return fmt.Errorf("no ingress link found") } - rtnData := wshrpc.CommandAuthenticateRtnData{ - Env: entry.Env, - InitScriptText: entry.ScriptText, - RpcContext: entry.RpcContext, + if data.JobId == "" { + return fmt.Errorf("no jobid in authenticatejobmanager message") + } + if data.JobAuthToken == "" { + return fmt.Errorf("no jobauthtoken in authenticatejobmanager message") } - log.Printf("wshrouter authenticate-token-verify success routeid=%q", entry.RpcContext.RouteId) - return rtnData, nil + if 
impl.Router.IsRootRouter() { + job, err := wstore.DBMustGet[*waveobj.Job](ctx, data.JobId) + if err != nil { + log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: failed to get job: %v", linkId, data.JobId, err) + return fmt.Errorf("failed to get job: %w", err) + } + + if job.JobAuthToken != data.JobAuthToken { + log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: invalid jobauthtoken", linkId, data.JobId) + return fmt.Errorf("invalid jobauthtoken") + } + } else { + wshRpc := GetWshRpcFromContext(ctx) + if wshRpc == nil { + return fmt.Errorf("no wshrpc in context") + } + _, err := wshRpc.SendRpcRequest(wshrpc.Command_AuthenticateJobManagerVerify, data, &wshrpc.RpcOpts{Route: ControlRootRoute}) + if err != nil { + log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: failed to verify job auth token: %v", linkId, data.JobId, err) + return fmt.Errorf("failed to verify job auth token: %w", err) + } + } + + routeId := MakeJobRouteId(data.JobId) + log.Printf("wshrouter authenticate-jobmanager success linkid=%d jobid=%q routeid=%q", linkId, data.JobId, routeId) + impl.Router.trustLink(linkId, LinkKind_Leaf) + impl.Router.bindRoute(linkId, routeId, true) + + return nil } func validateRpcContextFromAuth(newCtx *wshrpc.RpcContext) (string, error) { diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index 7d94777193..eb2903c1f7 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -18,6 +18,7 @@ import ( "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" + "github.com/wavetermdev/waveterm/pkg/streamclient" "github.com/wavetermdev/waveterm/pkg/util/ds" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wps" @@ -56,6 +57,7 @@ type WshRpc struct { ServerImpl ServerImpl EventListener *EventListener ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler + StreamBroker 
*streamclient.Broker Debug bool DebugName string ServerDone bool @@ -226,6 +228,7 @@ func MakeWshRpcWithChannels(inputCh chan baseds.RpcInputChType, outputCh chan [] ResponseHandlerMap: make(map[string]*RpcResponseHandler), } rtn.RpcContext.Store(&rpcCtx) + rtn.StreamBroker = streamclient.NewBroker(AdaptWshRpc(rtn)) go rtn.runServer() return rtn } @@ -286,6 +289,36 @@ func (w *WshRpc) handleEventRecv(req *RpcMessage) { w.EventListener.RecvEvent(&waveEvent) } +func (w *WshRpc) handleStreamData(req *RpcMessage) { + if w.StreamBroker == nil { + return + } + if req.Data == nil { + return + } + var dataPk wshrpc.CommandStreamData + err := utilfn.ReUnmarshal(&dataPk, req.Data) + if err != nil { + return + } + w.StreamBroker.RecvData(dataPk) +} + +func (w *WshRpc) handleStreamAck(req *RpcMessage) { + if w.StreamBroker == nil { + return + } + if req.Data == nil { + return + } + var ackPk wshrpc.CommandStreamAckData + err := utilfn.ReUnmarshal(&ackPk, req.Data) + if err != nil { + return + } + w.StreamBroker.RecvAck(ackPk) +} + func (w *WshRpc) handleRequestInternal(req *RpcMessage, ingressLinkId baseds.LinkId, pprofCtx context.Context) { if req.Command == wshrpc.Command_EventRecv { w.handleEventRecv(req) @@ -381,6 +414,17 @@ outer: continue } if msg.IsRpcRequest() { + // Handle stream commands synchronously since the broker is designed to be non-blocking. + // RecvData/RecvAck just enqueue to work queues, so there's no risk of blocking the main loop. + if msg.Command == wshrpc.Command_StreamData { + w.handleStreamData(&msg) + continue + } + if msg.Command == wshrpc.Command_StreamDataAck { + w.handleStreamAck(&msg) + continue + } + ingressLinkId := inputVal.IngressLinkId go func() { defer func() { diff --git a/pkg/wshutil/wshstreamadapter.go b/pkg/wshutil/wshstreamadapter.go new file mode 100644 index 0000000000..b83d1c727c --- /dev/null +++ b/pkg/wshutil/wshstreamadapter.go @@ -0,0 +1,24 @@ +// Copyright 2025, Command Line Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package wshutil + +import ( + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type WshRpcStreamClientAdapter struct { + rpc *WshRpc +} + +func (a *WshRpcStreamClientAdapter) StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error { + return a.rpc.SendCommand("streamdataack", data, opts) +} + +func (a *WshRpcStreamClientAdapter) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error { + return a.rpc.SendCommand("streamdata", data, opts) +} + +func AdaptWshRpc(rpc *WshRpc) *WshRpcStreamClientAdapter { + return &WshRpcStreamClientAdapter{rpc: rpc} +} diff --git a/pkg/wstore/wstore_dbops.go b/pkg/wstore/wstore_dbops.go index 6b64b3e474..e9a0289ee3 100644 --- a/pkg/wstore/wstore_dbops.go +++ b/pkg/wstore/wstore_dbops.go @@ -317,6 +317,31 @@ func DBUpdate(ctx context.Context, val waveobj.WaveObj) error { }) } +func DBUpdateFn[T waveobj.WaveObj](ctx context.Context, id string, updateFn func(T)) error { + return WithTx(ctx, func(tx *TxWrap) error { + val, err := DBMustGet[T](tx.Context(), id) + if err != nil { + return err + } + updateFn(val) + return DBUpdate(tx.Context(), val) + }) +} + +func DBUpdateFnErr[T waveobj.WaveObj](ctx context.Context, id string, updateFn func(T) error) error { + return WithTx(ctx, func(tx *TxWrap) error { + val, err := DBMustGet[T](tx.Context(), id) + if err != nil { + return err + } + err = updateFn(val) + if err != nil { + return err + } + return DBUpdate(tx.Context(), val) + }) +} + func DBInsert(ctx context.Context, val waveobj.WaveObj) error { oid := waveobj.GetOID(val) if oid == "" {