diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index c94abafb..cc330c6b 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,7 +40,7 @@ type Server struct { srv *http.Server mu sync.RWMutex logger *slog.Logger - conversation *st.Conversation + conversation *st.PTYConversation agentio *termexec.Process agentType mf.AgentType emitter *EventEmitter @@ -237,7 +237,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - conversation := st.NewConversation(ctx, st.ConversationConfig{ + conversation := st.NewPTY(ctx, st.PTYConversationConfig{ AgentType: config.AgentType, AgentIO: config.Process, GetTime: func() time.Time { @@ -331,7 +331,7 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { } func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.StartSnapshotLoop(ctx) + s.conversation.Start(ctx) go func() { for { currentStatus := s.conversation.Status() @@ -339,7 +339,7 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { // Send initial prompt when agent becomes stable for the first time if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { s.logger.Error("Failed to send initial prompt", "error", err) } else { s.conversation.InitialPromptSent = true @@ -350,7 +350,7 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { } s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) + s.emitter.UpdateScreenAndEmitChanges(s.conversation.String()) time.Sleep(snapshotInterval) } }() @@ -449,7 +449,7 @@ func (s *Server) 
createMessage(ctx context.Context, input *MessageRequest) (*Mes switch input.Body.Type { case MessageTypeUser: - if err := s.conversation.SendMessage(FormatMessage(s.agentType, input.Body.Content)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, input.Body.Content)...); err != nil { return nil, xerrors.Errorf("failed to send message: %w", err) } case MessageTypeRaw: diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 97a74722..db8d82d1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,55 +2,27 @@ package screentracker import ( "context" - "fmt" - "log/slog" - "strings" - "sync" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/util" "github.com/danielgtaylor/huma/v2" "golang.org/x/xerrors" ) -type screenSnapshot struct { - timestamp time.Time - screen string -} - -type AgentIO interface { - Write(data []byte) (int, error) - ReadScreen() string -} +type ConversationStatus string -type ConversationConfig struct { - AgentType msgfmt.AgentType - AgentIO AgentIO - // GetTime returns the current time - GetTime func() time.Time - // How often to take a snapshot for the stability check - SnapshotInterval time.Duration - // How long the screen should not change to be considered stable - ScreenStabilityLength time.Duration - // Function to format the messages received from the agent - // userInput is the last user message - FormatMessage func(message string, userInput string) string - // SkipWritingMessage skips the writing of a message to the agent. - // This is used in tests - SkipWritingMessage bool - // SkipSendMessageStatusCheck skips the check for whether the message can be sent. 
- // This is used in tests - SkipSendMessageStatusCheck bool - // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt - ReadyForInitialPrompt func(message string) bool - // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls - FormatToolCall func(message string) (string, []string) - Logger *slog.Logger -} +const ( + ConversationStatusChanging ConversationStatus = "changing" + ConversationStatusStable ConversationStatus = "stable" + ConversationStatusInitializing ConversationStatus = "initializing" +) type ConversationRole string +func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { + return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) +} + const ( ConversationRoleUser ConversationRole = "user" ConversationRoleAgent ConversationRole = "agent" @@ -61,207 +33,15 @@ var ConversationRoleValues = []ConversationRole{ ConversationRoleAgent, } -func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { - return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) -} - -type ConversationMessage struct { - Id int - Message string - Role ConversationRole - Time time.Time -} - -type Conversation struct { - cfg ConversationConfig - // How many stable snapshots are required to consider the screen stable - stableSnapshotsThreshold int - snapshotBuffer *RingBuffer[screenSnapshot] - messages []ConversationMessage - screenBeforeLastUserMessage string - lock sync.Mutex - // InitialPrompt is the initial prompt passed to the agent - InitialPrompt string - // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents - InitialPromptSent bool - // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt - ReadyForInitialPrompt bool - // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message - 
toolCallMessageSet map[string]bool -} - -type ConversationStatus string - -const ( - ConversationStatusChanging ConversationStatus = "changing" - ConversationStatusStable ConversationStatus = "stable" - ConversationStatusInitializing ConversationStatus = "initializing" +var ( + MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") + MessageValidationErrorEmpty = xerrors.New("message must not be empty") + MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") ) -func getStableSnapshotsThreshold(cfg ConversationConfig) int { - length := cfg.ScreenStabilityLength.Milliseconds() - interval := cfg.SnapshotInterval.Milliseconds() - threshold := int(length / interval) - if length%interval != 0 { - threshold++ - } - return threshold + 1 -} - -func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation { - threshold := getStableSnapshotsThreshold(cfg) - c := &Conversation{ - cfg: cfg, - stableSnapshotsThreshold: threshold, - snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), - messages: []ConversationMessage{ - { - Message: "", - Role: ConversationRoleAgent, - Time: cfg.GetTime(), - }, - }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, - toolCallMessageSet: make(map[string]bool), - } - return c -} - -func (c *Conversation) StartSnapshotLoop(ctx context.Context) { - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(c.cfg.SnapshotInterval): - // It's important that we hold the lock while reading the screen. - // There's a race condition that occurs without it: - // 1. The screen is read - // 2. Independently, SendMessage is called and takes the lock. - // 3. AddSnapshot is called and waits on the lock. - // 4. SendMessage modifies the terminal state, releases the lock - // 5. 
AddSnapshot adds a snapshot from a stale screen - c.lock.Lock() - screen := c.cfg.AgentIO.ReadScreen() - c.addSnapshotInner(screen) - c.lock.Unlock() - } - } - }() -} - -func FindNewMessage(oldScreen, newScreen string, agentType msgfmt.AgentType) string { - oldLines := strings.Split(oldScreen, "\n") - newLines := strings.Split(newScreen, "\n") - oldLinesMap := make(map[string]bool) - - // -1 indicates no header - dynamicHeaderEnd := -1 - - // Skip header lines for Opencode agent type to avoid false positives - // The header contains dynamic content (token count, context percentage, cost) - // that changes between screens, causing line comparison mismatches: - // - // ┃ # Getting Started with Claude CLI ┃ - // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ - if len(newLines) >= 2 && agentType == msgfmt.AgentTypeOpencode { - dynamicHeaderEnd = 2 - } - - for _, line := range oldLines { - oldLinesMap[line] = true - } - firstNonMatchingLine := len(newLines) - for i, line := range newLines[dynamicHeaderEnd+1:] { - if !oldLinesMap[line] { - firstNonMatchingLine = i - break - } - } - newSectionLines := newLines[firstNonMatchingLine:] - - // remove leading and trailing lines which are empty or have only whitespace - startLine := 0 - endLine := len(newSectionLines) - 1 - for i := 0; i < len(newSectionLines); i++ { - if strings.TrimSpace(newSectionLines[i]) != "" { - startLine = i - break - } - } - for i := len(newSectionLines) - 1; i >= 0; i-- { - if strings.TrimSpace(newSectionLines[i]) != "" { - endLine = i - break - } - } - return strings.Join(newSectionLines[startLine:endLine+1], "\n") -} - -func (c *Conversation) lastMessage(role ConversationRole) ConversationMessage { - for i := len(c.messages) - 1; i >= 0; i-- { - if c.messages[i].Role == role { - return c.messages[i] - } - } - return ConversationMessage{} -} - -// This function assumes that the caller holds the lock -func (c *Conversation) updateLastAgentMessage(screen string, timestamp time.Time) { - 
agentMessage := FindNewMessage(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) - lastUserMessage := c.lastMessage(ConversationRoleUser) - var toolCalls []string - if c.cfg.FormatMessage != nil { - agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) - } - if c.cfg.FormatToolCall != nil { - agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) - } - for _, toolCall := range toolCalls { - if c.toolCallMessageSet[toolCall] == false { - c.toolCallMessageSet[toolCall] = true - c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) - } - } - shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser - lastAgentMessage := c.lastMessage(ConversationRoleAgent) - if lastAgentMessage.Message == agentMessage { - return - } - conversationMessage := ConversationMessage{ - Message: agentMessage, - Role: ConversationRoleAgent, - Time: timestamp, - } - if shouldCreateNewMessage { - c.messages = append(c.messages, conversationMessage) - - // Cleanup - c.toolCallMessageSet = make(map[string]bool) - - } else { - c.messages[len(c.messages)-1] = conversationMessage - } - c.messages[len(c.messages)-1].Id = len(c.messages) - 1 -} - -// assumes the caller holds the lock -func (c *Conversation) addSnapshotInner(screen string) { - snapshot := screenSnapshot{ - timestamp: c.cfg.GetTime(), - screen: screen, - } - c.snapshotBuffer.Add(snapshot) - c.updateLastAgentMessage(screen, snapshot.timestamp) -} - -func (c *Conversation) AddSnapshot(screen string) { - c.lock.Lock() - defer c.lock.Unlock() - - c.addSnapshotInner(screen) +type AgentIO interface { + Write(data []byte) (int, error) + ReadScreen() string } type MessagePart interface { @@ -269,198 +49,18 @@ type MessagePart interface { String() string } -type MessagePartText struct { - Content string - Alias string - Hidden bool -} - -func (p MessagePartText) Do(writer AgentIO) error { - _, err := writer.Write([]byte(p.Content)) - return err -} - -func 
(p MessagePartText) String() string { - if p.Hidden { - return "" - } - if p.Alias != "" { - return p.Alias - } - return p.Content -} - -func PartsToString(parts ...MessagePart) string { - var sb strings.Builder - for _, part := range parts { - sb.WriteString(part.String()) - } - return sb.String() -} - -func ExecuteParts(writer AgentIO, parts ...MessagePart) error { - for _, part := range parts { - if err := part.Do(writer); err != nil { - return xerrors.Errorf("failed to write message part: %w", err) - } - } - return nil -} - -func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, messageParts ...MessagePart) error { - if c.cfg.SkipWritingMessage { - return nil - } - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - if err := ExecuteParts(c.cfg.AgentIO, messageParts...); err != nil { - return xerrors.Errorf("failed to write message: %w", err) - } - // wait for the screen to stabilize after the message is written - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 50 * time.Millisecond, - InitialWait: true, - }, func() (bool, error) { - screen := c.cfg.AgentIO.ReadScreen() - if screen != screenBeforeMessage { - time.Sleep(1 * time.Second) - newScreen := c.cfg.AgentIO.ReadScreen() - return newScreen == screen, nil - } - return false, nil - }); err != nil { - return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) - } - - // wait for the screen to change after the carriage return is written - screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() - lastCarriageReturnTime := time.Time{} - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 25 * time.Millisecond, - }, func() (bool, error) { - // we don't want to spam additional carriage returns because the agent may process them - // (aider does this), but we do want to retry sending one if nothing's - // happening for a while - if time.Since(lastCarriageReturnTime) >= 3*time.Second { - 
lastCarriageReturnTime = time.Now() - if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { - return false, xerrors.Errorf("failed to write carriage return: %w", err) - } - } - time.Sleep(25 * time.Millisecond) - screen := c.cfg.AgentIO.ReadScreen() - - return screen != screenBeforeCarriageReturn, nil - }); err != nil { - return xerrors.Errorf("failed to wait for processing to start: %w", err) - } - - return nil -} - -var MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") -var MessageValidationErrorEmpty = xerrors.New("message must not be empty") -var MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") - -func (c *Conversation) SendMessage(messageParts ...MessagePart) error { - c.lock.Lock() - defer c.lock.Unlock() - - if !c.cfg.SkipSendMessageStatusCheck && c.statusInner() != ConversationStatusStable { - return MessageValidationErrorChanging - } - - message := PartsToString(messageParts...) 
- if message != msgfmt.TrimWhitespace(message) { - // msgfmt formatting functions assume this - return MessageValidationErrorWhitespace - } - if message == "" { - // writeMessageWithConfirmation requires a non-empty message - return MessageValidationErrorEmpty - } - - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - now := c.cfg.GetTime() - c.updateLastAgentMessage(screenBeforeMessage, now) - - if err := c.writeMessageWithConfirmation(context.Background(), messageParts...); err != nil { - return xerrors.Errorf("failed to send message: %w", err) - } - - c.screenBeforeLastUserMessage = screenBeforeMessage - c.messages = append(c.messages, ConversationMessage{ - Id: len(c.messages), - Message: message, - Role: ConversationRoleUser, - Time: now, - }) - return nil -} - -// Assumes that the caller holds the lock -func (c *Conversation) statusInner() ConversationStatus { - // sanity checks - if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { - panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) - } - if c.stableSnapshotsThreshold == 0 { - panic("stable snapshots threshold is 0. 
can't check stability") - } - - snapshots := c.snapshotBuffer.GetAll() - if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { - // if the last message is a user message then the snapshot loop hasn't - // been triggered since the last user message, and we should assume - // the screen is changing - return ConversationStatusChanging - } - - if len(snapshots) != c.stableSnapshotsThreshold { - return ConversationStatusInitializing - } - - for i := 1; i < len(snapshots); i++ { - if snapshots[0].screen != snapshots[i].screen { - return ConversationStatusChanging - } - } - - if !c.InitialPromptSent && !c.ReadyForInitialPrompt { - if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { - c.ReadyForInitialPrompt = true - return ConversationStatusStable - } - return ConversationStatusChanging - } - - return ConversationStatusStable -} - -func (c *Conversation) Status() ConversationStatus { - c.lock.Lock() - defer c.lock.Unlock() - - return c.statusInner() -} - -func (c *Conversation) Messages() []ConversationMessage { - c.lock.Lock() - defer c.lock.Unlock() - - result := make([]ConversationMessage, len(c.messages)) - copy(result, c.messages) - return result +// Conversation allows tracking of a conversation between a user and an agent. 
+type Conversation interface { + Messages() []ConversationMessage + Snapshot(string) + Start(context.Context) + Status() ConversationStatus + String() string } -func (c *Conversation) Screen() string { - c.lock.Lock() - defer c.lock.Unlock() - - snapshots := c.snapshotBuffer.GetAll() - if len(snapshots) == 0 { - return "" - } - return snapshots[len(snapshots)-1].screen +type ConversationMessage struct { + Id int + Message string + Role ConversationRole + Time time.Time } diff --git a/lib/screentracker/diff.go b/lib/screentracker/diff.go new file mode 100644 index 00000000..47c5b78c --- /dev/null +++ b/lib/screentracker/diff.go @@ -0,0 +1,56 @@ +package screentracker + +import ( + "strings" + + "github.com/coder/agentapi/lib/msgfmt" +) + +// screenDiff compares two screen states and attempts to find latest message of the given agent type. +func screenDiff(oldScreen, newScreen string, agentType msgfmt.AgentType) string { + oldLines := strings.Split(oldScreen, "\n") + newLines := strings.Split(newScreen, "\n") + oldLinesMap := make(map[string]bool) + + // -1 indicates no header + dynamicHeaderEnd := -1 + + // Skip header lines for Opencode agent type to avoid false positives + // The header contains dynamic content (token count, context percentage, cost) + // that changes between screens, causing line comparison mismatches: + // + // ┃ # Getting Started with Claude CLI ┃ + // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ + if len(newLines) >= 3 && agentType == msgfmt.AgentTypeOpencode { + dynamicHeaderEnd = 2 + } + + for _, line := range oldLines { + oldLinesMap[line] = true + } + firstNonMatchingLine := len(newLines) + for i, line := range newLines[dynamicHeaderEnd+1:] { + if !oldLinesMap[line] { + firstNonMatchingLine = i + dynamicHeaderEnd + 1 + break + } + } + newSectionLines := newLines[firstNonMatchingLine:] + + // remove leading and trailing lines which are empty or have only whitespace + startLine := 0 + endLine := len(newSectionLines) - 1 + for i := range newSectionLines { 
+ if strings.TrimSpace(newSectionLines[i]) != "" { + startLine = i + break + } + } + for i := len(newSectionLines) - 1; i >= 0; i-- { + if strings.TrimSpace(newSectionLines[i]) != "" { + endLine = i + break + } + } + return strings.Join(newSectionLines[startLine:endLine+1], "\n") +} diff --git a/lib/screentracker/diff_internal_test.go b/lib/screentracker/diff_internal_test.go new file mode 100644 index 00000000..d68bc36c --- /dev/null +++ b/lib/screentracker/diff_internal_test.go @@ -0,0 +1,39 @@ +package screentracker + +import ( + "embed" + "path" + "testing" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/stretchr/testify/assert" +) + +//go:embed testdata +var testdataDir embed.FS + +func TestScreenDiff(t *testing.T) { + t.Run("simple", func(t *testing.T) { + assert.Equal(t, "", screenDiff("123456", "123456", msgfmt.AgentTypeCustom)) + assert.Equal(t, "1234567", screenDiff("123456", "1234567", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) + assert.Equal(t, "12342", screenDiff("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("89", "42", msgfmt.AgentTypeCustom)) + }) + + dir := "testdata/diff" + cases, err := testdataDir.ReadDir(dir) + assert.NoError(t, err) + for _, c := range cases { + t.Run(c.Name(), func(t *testing.T) { + before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) + assert.NoError(t, err) + after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) + assert.NoError(t, err) + expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) + assert.NoError(t, err) + assert.Equal(t, string(expected), screenDiff(string(before), string(after), msgfmt.AgentTypeCustom)) + }) + } +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go new file mode 100644 index 
00000000..91b956a1 --- /dev/null +++ b/lib/screentracker/pty_conversation.go @@ -0,0 +1,371 @@ +package screentracker + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/coder/agentapi/lib/util" + "golang.org/x/xerrors" +) + +// A screenSnapshot represents a snapshot of the PTY at a specific time. +type screenSnapshot struct { + timestamp time.Time + screen string +} + +type MessagePartText struct { + Content string + Alias string + Hidden bool +} + +var _ MessagePart = &MessagePartText{} + +func (p MessagePartText) Do(writer AgentIO) error { + _, err := writer.Write([]byte(p.Content)) + return err +} + +func (p MessagePartText) String() string { + if p.Hidden { + return "" + } + if p.Alias != "" { + return p.Alias + } + return p.Content +} + +// PTYConversationConfig is the configuration for a PTYConversation. +type PTYConversationConfig struct { + AgentType msgfmt.AgentType + AgentIO AgentIO + // GetTime returns the current time + GetTime func() time.Time + // How often to take a snapshot for the stability check + SnapshotInterval time.Duration + // How long the screen should not change to be considered stable + ScreenStabilityLength time.Duration + // Function to format the messages received from the agent + // userInput is the last user message + FormatMessage func(message string, userInput string) string + // SkipWritingMessage skips the writing of a message to the agent. + // This is used in tests + SkipWritingMessage bool + // SkipSendMessageStatusCheck skips the check for whether the message can be sent. 
+ // This is used in tests + SkipSendMessageStatusCheck bool + // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt + ReadyForInitialPrompt func(message string) bool + // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls + FormatToolCall func(message string) (string, []string) + Logger *slog.Logger +} + +func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { + length := cfg.ScreenStabilityLength.Milliseconds() + interval := cfg.SnapshotInterval.Milliseconds() + threshold := int(length / interval) + if length%interval != 0 { + threshold++ + } + return threshold + 1 +} + +// PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. +// It uses a combination of polling and diffs to detect changes in the screen. +type PTYConversation struct { + cfg PTYConversationConfig + // How many stable snapshots are required to consider the screen stable + stableSnapshotsThreshold int + snapshotBuffer *RingBuffer[screenSnapshot] + messages []ConversationMessage + screenBeforeLastUserMessage string + lock sync.Mutex + + // InitialPrompt is the initial prompt passed to the agent + InitialPrompt string + // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents + InitialPromptSent bool + // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt + ReadyForInitialPrompt bool + // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message + toolCallMessageSet map[string]bool +} + +var _ Conversation = &PTYConversation{} + +func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string) *PTYConversation { + threshold := cfg.getStableSnapshotsThreshold() + c := &PTYConversation{ + cfg: cfg, + stableSnapshotsThreshold: threshold, + snapshotBuffer: 
NewRingBuffer[screenSnapshot](threshold), + messages: []ConversationMessage{ + { + Message: "", + Role: ConversationRoleAgent, + Time: cfg.GetTime(), + }, + }, + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, + toolCallMessageSet: make(map[string]bool), + } + return c +} + +func (c *PTYConversation) Start(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(c.cfg.SnapshotInterval): + // It's important that we hold the lock while reading the screen. + // There's a race condition that occurs without it: + // 1. The screen is read + // 2. Independently, Send is called and takes the lock. + // 3. Snapshot is called and waits on the lock. + // 4. Send modifies the terminal state, releases the lock + // 5. Snapshot adds a snapshot from a stale screen + c.lock.Lock() + screen := c.cfg.AgentIO.ReadScreen() + c.snapshotLocked(screen) + c.lock.Unlock() + } + } + }() +} + +func (c *PTYConversation) lastMessage(role ConversationRole) ConversationMessage { + for i := len(c.messages) - 1; i >= 0; i-- { + if c.messages[i].Role == role { + return c.messages[i] + } + } + return ConversationMessage{} +} + +// caller MUST hold c.lock +func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp time.Time) { + agentMessage := screenDiff(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) + lastUserMessage := c.lastMessage(ConversationRoleUser) + var toolCalls []string + if c.cfg.FormatMessage != nil { + agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) + } + if c.cfg.FormatToolCall != nil { + agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) + } + for _, toolCall := range toolCalls { + if !c.toolCallMessageSet[toolCall] { + c.toolCallMessageSet[toolCall] = true + c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) + } + } + shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == 
ConversationRoleUser + lastAgentMessage := c.lastMessage(ConversationRoleAgent) + if lastAgentMessage.Message == agentMessage { + return + } + conversationMessage := ConversationMessage{ + Message: agentMessage, + Role: ConversationRoleAgent, + Time: timestamp, + } + if shouldCreateNewMessage { + c.messages = append(c.messages, conversationMessage) + + // Cleanup + c.toolCallMessageSet = make(map[string]bool) + + } else { + c.messages[len(c.messages)-1] = conversationMessage + } + c.messages[len(c.messages)-1].Id = len(c.messages) - 1 +} + +func (c *PTYConversation) Snapshot(screen string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.snapshotLocked(screen) +} + +// caller MUST hold c.lock +func (c *PTYConversation) snapshotLocked(screen string) { + snapshot := screenSnapshot{ + timestamp: c.cfg.GetTime(), + screen: screen, + } + c.snapshotBuffer.Add(snapshot) + c.updateLastAgentMessageLocked(screen, snapshot.timestamp) +} + +func (c *PTYConversation) Send(messageParts ...MessagePart) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { + return MessageValidationErrorChanging + } + + var sb strings.Builder + for _, part := range messageParts { + sb.WriteString(part.String()) + } + message := sb.String() + if message != msgfmt.TrimWhitespace(message) { + // msgfmt formatting functions assume this + return MessageValidationErrorWhitespace + } + if message == "" { + // writeStabilize requires a non-empty message + return MessageValidationErrorEmpty + } + + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + now := c.cfg.GetTime() + c.updateLastAgentMessageLocked(screenBeforeMessage, now) + + if err := c.writeStabilize(context.Background(), messageParts...); err != nil { + return xerrors.Errorf("failed to send message: %w", err) + } + + c.screenBeforeLastUserMessage = screenBeforeMessage + c.messages = append(c.messages, ConversationMessage{ + Id: len(c.messages), 
Message: message, + Role: ConversationRoleUser, + Time: now, + }) + return nil +} + +// writeStabilize writes messageParts to the screen and waits for the screen to stabilize after the message is written. +func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...MessagePart) error { + if c.cfg.SkipWritingMessage { + return nil + } + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + for _, part := range messageParts { + if err := part.Do(c.cfg.AgentIO); err != nil { + return xerrors.Errorf("failed to write message part: %w", err) + } + } + // wait for the screen to stabilize after the message is written + if err := util.WaitFor(ctx, util.WaitTimeout{ + Timeout: 15 * time.Second, + MinInterval: 50 * time.Millisecond, + InitialWait: true, + }, func() (bool, error) { + screen := c.cfg.AgentIO.ReadScreen() + if screen != screenBeforeMessage { + time.Sleep(1 * time.Second) + newScreen := c.cfg.AgentIO.ReadScreen() + return newScreen == screen, nil + } + return false, nil + }); err != nil { + return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) + } + + // wait for the screen to change after the carriage return is written + screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() + lastCarriageReturnTime := time.Time{} + if err := util.WaitFor(ctx, util.WaitTimeout{ + Timeout: 15 * time.Second, + MinInterval: 25 * time.Millisecond, + }, func() (bool, error) { + // we don't want to spam additional carriage returns because the agent may process them + // (aider does this), but we do want to retry sending one if nothing's + // happening for a while + if time.Since(lastCarriageReturnTime) >= 3*time.Second { + lastCarriageReturnTime = time.Now() + if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { + return false, xerrors.Errorf("failed to write carriage return: %w", err) + } + } + time.Sleep(25 * time.Millisecond) + screen := c.cfg.AgentIO.ReadScreen() + + return screen != screenBeforeCarriageReturn, nil + }); err != nil { + 
return xerrors.Errorf("failed to wait for processing to start: %w", err) + } + + return nil +} + +func (c *PTYConversation) Status() ConversationStatus { + c.lock.Lock() + defer c.lock.Unlock() + + return c.statusLocked() +} + +// caller MUST hold c.lock +func (c *PTYConversation) statusLocked() ConversationStatus { + // sanity checks + if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { + panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) + } + if c.stableSnapshotsThreshold == 0 { + panic("stable snapshots threshold is 0. can't check stability") + } + + snapshots := c.snapshotBuffer.GetAll() + if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { + // if the last message is a user message then the snapshot loop hasn't + // been triggered since the last user message, and we should assume + // the screen is changing + return ConversationStatusChanging + } + + if len(snapshots) != c.stableSnapshotsThreshold { + return ConversationStatusInitializing + } + + for i := 1; i < len(snapshots); i++ { + if snapshots[0].screen != snapshots[i].screen { + return ConversationStatusChanging + } + } + + if !c.InitialPromptSent && !c.ReadyForInitialPrompt { + if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { + c.ReadyForInitialPrompt = true + return ConversationStatusStable + } + return ConversationStatusChanging + } + + return ConversationStatusStable +} + +func (c *PTYConversation) Messages() []ConversationMessage { + c.lock.Lock() + defer c.lock.Unlock() + + result := make([]ConversationMessage, len(c.messages)) + copy(result, c.messages) + return result +} + +func (c *PTYConversation) String() string { + c.lock.Lock() + defer c.lock.Unlock() + + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) == 0 { + return "" + } + return snapshots[len(snapshots)-1].screen +} diff 
--git a/lib/screentracker/conversation_test.go b/lib/screentracker/pty_conversation_test.go similarity index 75% rename from lib/screentracker/conversation_test.go rename to lib/screentracker/pty_conversation_test.go index 9b888813..6798de4d 100644 --- a/lib/screentracker/conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,13 +2,10 @@ package screentracker_test import ( "context" - "embed" "fmt" - "path" "testing" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/stretchr/testify/assert" st "github.com/coder/agentapi/lib/screentracker" @@ -19,7 +16,7 @@ type statusTestStep struct { status st.ConversationStatus } type statusTestParams struct { - cfg st.ConversationConfig + cfg st.PTYConversationConfig steps []statusTestStep } @@ -42,11 +39,11 @@ func statusTest(t *testing.T, params statusTestParams) { if params.cfg.GetTime == nil { params.cfg.GetTime = func() time.Time { return time.Now() } } - c := st.NewConversation(ctx, params.cfg, "") + c := st.NewPTY(ctx, params.cfg, "") assert.Equal(t, st.ConversationStatusInitializing, c.Status()) for i, step := range params.steps { - c.AddSnapshot(step.snapshot) + c.Snapshot(step.snapshot) assert.Equal(t, step.status, c.Status(), "step %d", i) } }) @@ -58,7 +55,7 @@ func TestConversation(t *testing.T) { initializing := st.ConversationStatusInitializing statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, // stability threshold: 3 @@ -76,7 +73,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 2 * time.Second, ScreenStabilityLength: 3 * time.Second, // stability threshold: 3 @@ -95,7 +92,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 6 * time.Second, 
ScreenStabilityLength: 14 * time.Second, // stability threshold: 4 @@ -133,11 +130,11 @@ func TestMessages(t *testing.T) { Time: now, } } - sendMsg := func(c *st.Conversation, msg string) error { - return c.SendMessage(st.MessagePartText{Content: msg}) + sendMsg := func(c *st.PTYConversation, msg string) error { + return c.Send(st.MessagePartText{Content: msg}) } - newConversation := func(opts ...func(*st.ConversationConfig)) *st.Conversation { - cfg := st.ConversationConfig{ + newConversation := func(opts ...func(*st.PTYConversationConfig)) *st.PTYConversation { + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, @@ -147,7 +144,7 @@ func TestMessages(t *testing.T) { for _, opt := range opts { opt(&cfg) } - return st.NewConversation(context.Background(), cfg, "") + return st.NewPTY(context.Background(), cfg, "") } t.Run("messages are copied", func(t *testing.T) { @@ -167,7 +164,7 @@ func TestMessages(t *testing.T) { t.Run("whitespace-padding", func(t *testing.T) { c := newConversation() for _, msg := range []string{"123 ", " 123", "123\t\t", "\n123", "123\n\t", " \t123\n\t"} { - err := c.SendMessage(st.MessagePartText{Content: msg}) + err := c.Send(st.MessagePartText{Content: msg}) assert.Error(t, err, st.MessageValidationErrorWhitespace) } }) @@ -178,33 +175,33 @@ func TestMessages(t *testing.T) { }{ Time: now, } - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.GetTime = func() time.Time { return nowWrapper.Time } }) - c.AddSnapshot("1") + c.Snapshot("1") msgs := c.Messages() assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, msgs) nowWrapper.Time = nowWrapper.Add(1 * time.Second) - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, msgs, c.Messages()) }) t.Run("tracking messages", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) 
{ + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // agent message is recorded when the first snapshot is added - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, c.Messages()) // agent message is updated when the screen changes - c.AddSnapshot("2") + c.Snapshot("2") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), }, c.Messages()) @@ -218,7 +215,7 @@ func TestMessages(t *testing.T) { }, c.Messages()) // agent message is added after a user message - c.AddSnapshot("4") + c.Snapshot("4") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), userMsg(1, "3"), @@ -236,9 +233,9 @@ func TestMessages(t *testing.T) { }, c.Messages()) // conversation status is changing right after a user message - c.AddSnapshot("7") - c.AddSnapshot("7") - c.AddSnapshot("7") + c.Snapshot("7") + c.Snapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) agent.screen = "7" assert.NoError(t, sendMsg(c, "8")) @@ -254,21 +251,21 @@ func TestMessages(t *testing.T) { // conversation status is back to stable after a snapshot that // doesn't change the screen - c.AddSnapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) }) t.Run("tracking messages overlap", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // common overlap between screens is removed after a user message - c.AddSnapshot("1") + c.Snapshot("1") agent.screen = "1" assert.NoError(t, sendMsg(c, "2")) - c.AddSnapshot("1\n3") + c.Snapshot("1\n3") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -277,7 +274,7 @@ func TestMessages(t *testing.T) { agent.screen = "1\n3x" assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("1\n3x\n5") + c.Snapshot("1\n3x\n5") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), 
userMsg(1, "2"), @@ -289,7 +286,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return message + " " + userInput @@ -302,7 +299,7 @@ func TestMessages(t *testing.T) { userMsg(1, "2"), }, c.Messages()) agent.screen = "x" - c.AddSnapshot("x") + c.Snapshot("x") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1 "), userMsg(1, "2"), @@ -312,7 +309,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return "formatted" @@ -329,7 +326,7 @@ func TestMessages(t *testing.T) { }) t.Run("send-message-status-check", func(t *testing.T) { - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.SkipSendMessageStatusCheck = false cfg.SnapshotInterval = 1 * time.Second cfg.ScreenStabilityLength = 2 * time.Second @@ -337,10 +334,10 @@ func TestMessages(t *testing.T) { }) assert.Error(t, sendMsg(c, "1"), st.MessageValidationErrorChanging) for range 3 { - c.AddSnapshot("1") + c.Snapshot("1") } assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("2") + c.Snapshot("2") assert.Error(t, sendMsg(c, "5"), st.MessageValidationErrorChanging) }) @@ -350,68 +347,11 @@ func TestMessages(t *testing.T) { }) } -//go:embed testdata -var testdataDir embed.FS - -func TestFindNewMessage(t *testing.T) { - assert.Equal(t, "", st.FindNewMessage("123456", "123456", msgfmt.AgentTypeCustom)) - assert.Equal(t, "1234567", st.FindNewMessage("123456", "1234567", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", 
st.FindNewMessage("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) - assert.Equal(t, "12342", st.FindNewMessage("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("89", "42", msgfmt.AgentTypeCustom)) - - dir := "testdata/diff" - cases, err := testdataDir.ReadDir(dir) - assert.NoError(t, err) - for _, c := range cases { - t.Run(c.Name(), func(t *testing.T) { - before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) - assert.NoError(t, err) - after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) - assert.NoError(t, err) - expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) - assert.NoError(t, err) - assert.Equal(t, string(expected), st.FindNewMessage(string(before), string(after), msgfmt.AgentTypeCustom)) - }) - } -} - -func TestPartsToString(t *testing.T) { - assert.Equal(t, "123", st.PartsToString(st.MessagePartText{Content: "123"})) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - ), - ) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "x", Hidden: true}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - st.MessagePartText{Content: "y", Hidden: true}, - ), - ) - assert.Equal(t, - "ab", - st.PartsToString( - st.MessagePartText{Content: "1", Alias: "a"}, - st.MessagePartText{Content: "2", Alias: "b"}, - st.MessagePartText{Content: "3", Alias: "c", Hidden: true}, - ), - ) -} - func TestInitialPromptReadiness(t *testing.T) { now := time.Now() t.Run("agent not ready - status remains changing", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, 
ScreenStabilityLength: 0, @@ -420,10 +360,10 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Fill buffer with stable snapshots, but agent is not ready - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Even though screen is stable, status should be changing because agent is not ready assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -432,7 +372,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("agent becomes ready - status changes to stable", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -441,14 +381,14 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Agent not ready initially - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) @@ -456,7 +396,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("ready for initial prompt lifecycle: false -> true -> false", func(t *testing.T) { agent := &testAgent{screen: "loading..."} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -467,23 +407,23 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, 
"initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Initial state: ReadyForInitialPrompt should be false - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.False(t, c.ReadyForInitialPrompt, "should start as false") assert.False(t, c.InitialPromptSent) assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready: ReadyForInitialPrompt should become true agent.screen = "ready" - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt, "should become true when ready") assert.False(t, c.InitialPromptSent) // Send the initial prompt - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // After sending initial prompt: ReadyForInitialPrompt should be set back to false // (simulating what happens in the actual server code) @@ -496,7 +436,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("no initial prompt - normal status logic applies", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -506,9 +446,9 @@ func TestInitialPromptReadiness(t *testing.T) { }, } // Empty initial prompt means no need to wait for readiness - c := st.NewConversation(context.Background(), cfg, "") + c := st.NewPTY(context.Background(), cfg, "") - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Status should be stable because no initial prompt to wait for assert.Equal(t, st.ConversationStatusStable, c.Status()) @@ -518,7 +458,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("initial prompt sent - normal status logic applies", func(t *testing.T) { agent := &testAgent{screen: "ready"} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: 
func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -529,24 +469,24 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // First, agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) // Send the initial prompt agent.screen = "processing..." - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // Mark initial prompt as sent (simulating what the server does) c.InitialPromptSent = true c.ReadyForInitialPrompt = false // Now test that status logic works normally after initial prompt is sent - c.AddSnapshot("processing...") + c.Snapshot("processing...") // Status should be stable because initial prompt was already sent // and the readiness check is bypassed