Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d
github.com/charmbracelet/bubbletea v1.3.4
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
github.com/coder/quartz v0.1.2
github.com/danielgtaylor/huma/v2 v2.32.0
github.com/go-chi/chi/v5 v5.2.2
github.com/go-chi/cors v1.2.1
Expand Down Expand Up @@ -193,7 +194,6 @@ require (
go-simpler.org/sloglint v0.11.1 // indirect
go.augendre.info/arangolint v0.2.0 // indirect
go.augendre.info/fatcontext v0.8.1 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.uber.org/zap v1.27.0 // indirect
Expand Down
85 changes: 50 additions & 35 deletions go.sum

Large diffs are not rendered by default.

26 changes: 19 additions & 7 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
mf "github.com/coder/agentapi/lib/msgfmt"
st "github.com/coder/agentapi/lib/screentracker"
"github.com/coder/agentapi/lib/termexec"
"github.com/coder/quartz"
"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/adapters/humachi"
"github.com/danielgtaylor/huma/v2/sse"
Expand All @@ -46,6 +47,7 @@ type Server struct {
emitter *EventEmitter
chatBasePath string
tempDir string
clock quartz.Clock
}

func (s *Server) NormalizeSchema(schema any) any {
Expand Down Expand Up @@ -102,6 +104,7 @@ type ServerConfig struct {
AllowedHosts []string
AllowedOrigins []string
InitialPrompt string
Clock quartz.Clock
}

// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
Expand Down Expand Up @@ -194,6 +197,10 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {

logger := logctx.From(ctx)

if config.Clock == nil {
config.Clock = quartz.NewReal()
}

allowedHosts, err := parseAllowedHosts(config.AllowedHosts)
if err != nil {
return nil, xerrors.Errorf("failed to parse allowed hosts: %w", err)
Expand Down Expand Up @@ -238,11 +245,9 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
}

conversation := st.NewConversation(ctx, st.ConversationConfig{
AgentType: config.AgentType,
AgentIO: config.Process,
GetTime: func() time.Time {
return time.Now()
},
AgentType: config.AgentType,
AgentIO: config.Process,
Clock: config.Clock,
SnapshotInterval: snapshotInterval,
ScreenStabilityLength: 2 * time.Second,
FormatMessage: formatMessage,
Expand Down Expand Up @@ -270,6 +275,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
emitter: emitter,
chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"),
tempDir: tempDir,
clock: config.Clock,
}

// Register API routes
Expand Down Expand Up @@ -333,12 +339,13 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) {
func (s *Server) StartSnapshotLoop(ctx context.Context) {
s.conversation.StartSnapshotLoop(ctx)
go func() {
ticker := s.clock.NewTicker(snapshotInterval)
defer ticker.Stop()
for {
currentStatus := s.conversation.Status()

// Send initial prompt when agent becomes stable for the first time
if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable {

if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil {
s.logger.Error("Failed to send initial prompt", "error", err)
} else {
Expand All @@ -351,7 +358,12 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) {
s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType)
s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages())
s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen())
time.Sleep(snapshotInterval)

select {
case <-ctx.Done():
return
case <-ticker.C:
}
}
}()
}
Expand Down
40 changes: 27 additions & 13 deletions lib/screentracker/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/coder/agentapi/lib/msgfmt"
"github.com/coder/agentapi/lib/util"
"github.com/coder/quartz"
"github.com/danielgtaylor/huma/v2"
"golang.org/x/xerrors"
)
Expand All @@ -27,8 +28,8 @@ type AgentIO interface {
type ConversationConfig struct {
AgentType msgfmt.AgentType
AgentIO AgentIO
// GetTime returns the current time
GetTime func() time.Time
// Clock provides time operations for the conversation
Clock quartz.Clock
// How often to take a snapshot for the stability check
SnapshotInterval time.Duration
// How long the screen should not change to be considered stable
Expand Down Expand Up @@ -109,6 +110,9 @@ func getStableSnapshotsThreshold(cfg ConversationConfig) int {
}

func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation {
if cfg.Clock == nil {
cfg.Clock = quartz.NewReal()
}
threshold := getStableSnapshotsThreshold(cfg)
c := &Conversation{
cfg: cfg,
Expand All @@ -118,7 +122,7 @@ func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt
{
Message: "",
Role: ConversationRoleAgent,
Time: cfg.GetTime(),
Time: cfg.Clock.Now(),
},
},
InitialPrompt: initialPrompt,
Expand All @@ -130,11 +134,13 @@ func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt

func (c *Conversation) StartSnapshotLoop(ctx context.Context) {
go func() {
ticker := c.cfg.Clock.NewTicker(c.cfg.SnapshotInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-time.After(c.cfg.SnapshotInterval):
case <-ticker.C:
// It's important that we hold the lock while reading the screen.
// There's a race condition that occurs without it:
// 1. The screen is read
Expand Down Expand Up @@ -250,7 +256,7 @@ func (c *Conversation) updateLastAgentMessage(screen string, timestamp time.Time
// assumes the caller holds the lock
func (c *Conversation) addSnapshotInner(screen string) {
snapshot := screenSnapshot{
timestamp: c.cfg.GetTime(),
timestamp: c.cfg.Clock.Now(),
screen: screen,
}
c.snapshotBuffer.Add(snapshot)
Expand Down Expand Up @@ -320,10 +326,13 @@ func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, message
Timeout: 15 * time.Second,
MinInterval: 50 * time.Millisecond,
InitialWait: true,
Clock: c.cfg.Clock,
}, func() (bool, error) {
screen := c.cfg.AgentIO.ReadScreen()
if screen != screenBeforeMessage {
time.Sleep(1 * time.Second)
timer := c.cfg.Clock.NewTimer(1 * time.Second)
defer timer.Stop()
<-timer.C
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Maybe this could be a little helper? Say, sleep(clock, time.Second)? Seems re-usable (at least below) and clarifies the code.

Might be a useful addition to quartz. Having <-clock.After(time.Second) would also be nice.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newScreen := c.cfg.AgentIO.ReadScreen()
return newScreen == screen, nil
}
Expand All @@ -338,17 +347,20 @@ func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, message
if err := util.WaitFor(ctx, util.WaitTimeout{
Timeout: 15 * time.Second,
MinInterval: 25 * time.Millisecond,
Clock: c.cfg.Clock,
}, func() (bool, error) {
// we don't want to spam additional carriage returns because the agent may process them
// (aider does this), but we do want to retry sending one if nothing's
// happening for a while
if time.Since(lastCarriageReturnTime) >= 3*time.Second {
lastCarriageReturnTime = time.Now()
if c.cfg.Clock.Since(lastCarriageReturnTime) >= 3*time.Second {
lastCarriageReturnTime = c.cfg.Clock.Now()
if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil {
return false, xerrors.Errorf("failed to write carriage return: %w", err)
}
}
time.Sleep(25 * time.Millisecond)
timer := c.cfg.Clock.NewTimer(25 * time.Millisecond)
defer timer.Stop()
<-timer.C
screen := c.cfg.AgentIO.ReadScreen()

return screen != screenBeforeCarriageReturn, nil
Expand All @@ -359,9 +371,11 @@ func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, message
return nil
}

var MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace")
var MessageValidationErrorEmpty = xerrors.New("message must not be empty")
var MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input")
var (
MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace")
MessageValidationErrorEmpty = xerrors.New("message must not be empty")
MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input")
)

func (c *Conversation) SendMessage(messageParts ...MessagePart) error {
c.lock.Lock()
Expand All @@ -382,7 +396,7 @@ func (c *Conversation) SendMessage(messageParts ...MessagePart) error {
}

screenBeforeMessage := c.cfg.AgentIO.ReadScreen()
now := c.cfg.GetTime()
now := c.cfg.Clock.Now()
c.updateLastAgentMessage(screenBeforeMessage, now)

if err := c.writeMessageWithConfirmation(context.Background(), messageParts...); err != nil {
Expand Down
40 changes: 25 additions & 15 deletions lib/screentracker/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/coder/agentapi/lib/msgfmt"
"github.com/coder/quartz"
"github.com/stretchr/testify/assert"

st "github.com/coder/agentapi/lib/screentracker"
Expand Down Expand Up @@ -39,8 +40,8 @@ func (a *testAgent) Write(data []byte) (int, error) {
func statusTest(t *testing.T, params statusTestParams) {
ctx := context.Background()
t.Run(fmt.Sprintf("interval-%s,stability_length-%s", params.cfg.SnapshotInterval, params.cfg.ScreenStabilityLength), func(t *testing.T) {
if params.cfg.GetTime == nil {
params.cfg.GetTime = func() time.Time { return time.Now() }
if params.cfg.Clock == nil {
params.cfg.Clock = quartz.NewReal()
}
c := st.NewConversation(ctx, params.cfg, "")
assert.Equal(t, st.ConversationStatusInitializing, c.Status())
Expand Down Expand Up @@ -137,8 +138,10 @@ func TestMessages(t *testing.T) {
return c.SendMessage(st.MessagePartText{Content: msg})
}
newConversation := func(opts ...func(*st.ConversationConfig)) *st.Conversation {
mClock := quartz.NewMock(t)
mClock.Set(now)
cfg := st.ConversationConfig{
GetTime: func() time.Time { return now },
Clock: mClock,
SnapshotInterval: 1 * time.Second,
ScreenStabilityLength: 2 * time.Second,
SkipWritingMessage: true,
Expand Down Expand Up @@ -173,21 +176,18 @@ func TestMessages(t *testing.T) {
})

t.Run("no-change-no-message-update", func(t *testing.T) {
nowWrapper := struct {
time.Time
}{
Time: now,
}
mClock := quartz.NewMock(t)
mClock.Set(now)
c := newConversation(func(cfg *st.ConversationConfig) {
cfg.GetTime = func() time.Time { return nowWrapper.Time }
cfg.Clock = mClock
})

c.AddSnapshot("1")
msgs := c.Messages()
assert.Equal(t, []st.ConversationMessage{
agentMsg(0, "1"),
}, msgs)
nowWrapper.Time = nowWrapper.Add(1 * time.Second)
mClock.Set(now.Add(1 * time.Second))
c.AddSnapshot("1")
assert.Equal(t, msgs, c.Messages())
})
Expand Down Expand Up @@ -411,8 +411,10 @@ func TestInitialPromptReadiness(t *testing.T) {
now := time.Now()

t.Run("agent not ready - status remains changing", func(t *testing.T) {
mClock := quartz.NewMock(t)
mClock.Set(now)
cfg := st.ConversationConfig{
GetTime: func() time.Time { return now },
Clock: mClock,
SnapshotInterval: 1 * time.Second,
ScreenStabilityLength: 0,
AgentIO: &testAgent{screen: "loading..."},
Expand All @@ -432,8 +434,10 @@ func TestInitialPromptReadiness(t *testing.T) {
})

t.Run("agent becomes ready - status changes to stable", func(t *testing.T) {
mClock := quartz.NewMock(t)
mClock.Set(now)
cfg := st.ConversationConfig{
GetTime: func() time.Time { return now },
Clock: mClock,
SnapshotInterval: 1 * time.Second,
ScreenStabilityLength: 0,
AgentIO: &testAgent{screen: "loading..."},
Expand All @@ -455,9 +459,11 @@ func TestInitialPromptReadiness(t *testing.T) {
})

t.Run("ready for initial prompt lifecycle: false -> true -> false", func(t *testing.T) {
mClock := quartz.NewMock(t)
mClock.Set(now)
agent := &testAgent{screen: "loading..."}
cfg := st.ConversationConfig{
GetTime: func() time.Time { return now },
Clock: mClock,
SnapshotInterval: 1 * time.Second,
ScreenStabilityLength: 0,
AgentIO: agent,
Expand Down Expand Up @@ -496,8 +502,10 @@ func TestInitialPromptReadiness(t *testing.T) {
})

t.Run("no initial prompt - normal status logic applies", func(t *testing.T) {
mClock := quartz.NewMock(t)
mClock.Set(now)
cfg := st.ConversationConfig{
GetTime: func() time.Time { return now },
Clock: mClock,
SnapshotInterval: 1 * time.Second,
ScreenStabilityLength: 0,
AgentIO: &testAgent{screen: "loading..."},
Expand All @@ -517,9 +525,11 @@ func TestInitialPromptReadiness(t *testing.T) {
})

t.Run("initial prompt sent - normal status logic applies", func(t *testing.T) {
mClock := quartz.NewMock(t)
mClock.Set(now)
agent := &testAgent{screen: "ready"}
cfg := st.ConversationConfig{
GetTime: func() time.Time { return now },
Clock: mClock,
SnapshotInterval: 1 * time.Second,
ScreenStabilityLength: 0,
AgentIO: agent,
Expand Down
Loading