Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 91 additions & 86 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.adk.agents;

import static com.google.common.collect.ImmutableList.toImmutableList;

import com.google.adk.Telemetry;
import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
Expand Down Expand Up @@ -57,7 +59,8 @@ public abstract class BaseAgent {

private final List<? extends BaseAgent> subAgents;

protected final CallbackPlugin callbackPlugin;
private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;

/**
* Creates a new BaseAgent.
Expand All @@ -74,53 +77,21 @@ public BaseAgent(
String name,
String description,
List<? extends BaseAgent> subAgents,
@Nullable List<? extends BeforeAgentCallback> beforeAgentCallback,
@Nullable List<? extends AfterAgentCallback> afterAgentCallback) {
this(
name,
description,
subAgents,
createCallbackPlugin(beforeAgentCallback, afterAgentCallback));
}

/**
* Creates a new BaseAgent.
*
* @param name Unique agent name. Cannot be "user" (reserved).
* @param description Agent purpose.
* @param subAgents Agents managed by this agent.
* @param callbackPlugin The callback plugin for this agent.
*/
protected BaseAgent(
String name,
String description,
List<? extends BaseAgent> subAgents,
CallbackPlugin callbackPlugin) {
List<? extends BeforeAgentCallback> beforeAgentCallback,
List<? extends AfterAgentCallback> afterAgentCallback) {
this.name = name;
this.description = description;
this.parentAgent = null;
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
this.callbackPlugin =
callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin;
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);

// Establish parent relationships for all sub-agents if needed.
for (BaseAgent subAgent : this.subAgents) {
subAgent.parentAgent(this);
}
}

/** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */
private static CallbackPlugin createCallbackPlugin(
@Nullable List<? extends BeforeAgentCallback> beforeAgentCallbacks,
@Nullable List<? extends AfterAgentCallback> afterAgentCallbacks) {
CallbackPlugin.Builder builder = CallbackPlugin.builder();
Stream.ofNullable(beforeAgentCallbacks).flatMap(List::stream).forEach(builder::addCallback);
Optional.ofNullable(afterAgentCallbacks).stream()
.flatMap(List::stream)
.forEach(builder::addCallback);
return builder.build();
}

/**
* Gets the agent's unique name.
*
Expand Down Expand Up @@ -201,15 +172,11 @@ public List<? extends BaseAgent> subAgents() {
}

public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
return Optional.of(callbackPlugin.getBeforeAgentCallback());
return beforeAgentCallback;
}

public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
return Optional.of(callbackPlugin.getAfterAgentCallback());
}

public Plugin getPlugin() {
return callbackPlugin;
return afterAgentCallback;
}

/**
Expand Down Expand Up @@ -252,11 +219,11 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
spanContext,
span,
() ->
processAgentCallbackResult(
ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx),
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(),
beforeAgentCallback.orElse(ImmutableList.of())),
invocationContext)
.map(Optional::of)
.switchIfEmpty(Single.just(Optional.empty()))
.flatMapPublisher(
beforeEventOpt -> {
if (invocationContext.endInvocation()) {
Expand All @@ -269,14 +236,11 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
Flowable<Event> afterEvents =
Flowable.defer(
() ->
processAgentCallbackResult(
ctx ->
invocationContext
.combinedPlugin()
.afterAgentCallback(this, ctx),
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
afterAgentCallback.orElse(ImmutableList.of())),
invocationContext)
.map(Optional::of)
.switchIfEmpty(Single.just(Optional.empty()))
.flatMapPublisher(Flowable::fromOptional));

return Flowable.concat(beforeEvents, mainEvents, afterEvents);
Expand All @@ -285,32 +249,76 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
}

/**
* Processes the result of an agent callback, creating an {@link Event} if necessary.
* Converts before-agent callbacks to functions.
*
* @param callbacks Before-agent callbacks.
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
Plugin pluginManager, List<? extends BeforeAgentCallback> callbacks) {
return Stream.concat(
Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)),
callbacks.stream()
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
.collect(toImmutableList());
}

/**
* Converts after-agent callbacks to functions.
*
* @param callbacks After-agent callbacks.
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
Plugin pluginManager, List<? extends AfterAgentCallback> callbacks) {
return Stream.concat(
Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)),
callbacks.stream()
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
.collect(toImmutableList());
}

/**
* Calls agent callbacks and returns the first produced event, if any.
*
* @param agentCallback The callback function.
* @param invocationContext The current invocation context.
* @return A {@link Maybe} emitting an {@link Event} if one is produced, or empty otherwise.
* @param agentCallbacks Callback functions.
* @param invocationContext Current invocation context.
* @return single emitting first event, or empty if none.
*/
private Maybe<Event> processAgentCallbackResult(
Function<CallbackContext, Maybe<Content>> agentCallback,
private Single<Optional<Event>> callCallback(
List<Function<CallbackContext, Maybe<Content>>> agentCallbacks,
InvocationContext invocationContext) {
var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null);
return agentCallback
.apply(callbackContext)
.map(
content -> {
invocationContext.setEndInvocation(true);
return Event.builder()
.id(Event.generateEventId())
.invocationId(invocationContext.invocationId())
.author(name())
.branch(invocationContext.branch())
.actions(callbackContext.eventActions())
.content(content)
.build();
if (agentCallbacks == null || agentCallbacks.isEmpty()) {
return Single.just(Optional.empty());
}

CallbackContext callbackContext =
new CallbackContext(invocationContext, /* eventActions= */ null);

return Flowable.fromIterable(agentCallbacks)
.concatMap(
callback -> {
Maybe<Content> maybeContent = callback.apply(callbackContext);

return maybeContent
.map(
content -> {
invocationContext.setEndInvocation(true);
return Optional.of(
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationContext.invocationId())
.author(name())
.branch(invocationContext.branch())
.actions(callbackContext.eventActions())
.content(content)
.build());
})
.toFlowable();
})
.firstElement()
.switchIfEmpty(
Maybe.defer(
Single.defer(
() -> {
if (callbackContext.state().hasDelta()) {
Event.Builder eventBuilder =
Expand All @@ -321,9 +329,9 @@ private Maybe<Event> processAgentCallbackResult(
.branch(invocationContext.branch())
.actions(callbackContext.eventActions());

return Maybe.just(eventBuilder.build());
return Single.just(Optional.of(eventBuilder.build()));
} else {
return Maybe.empty();
return Single.just(Optional.empty());
}
}));
}
Expand Down Expand Up @@ -391,11 +399,8 @@ public abstract static class Builder<B extends Builder<B>> {
protected String name;
protected String description;
protected ImmutableList<BaseAgent> subAgents;
protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder();

protected CallbackPlugin.Builder callbackPluginBuilder() {
return callbackPluginBuilder;
}
protected ImmutableList<BeforeAgentCallback> beforeAgentCallback;
protected ImmutableList<AfterAgentCallback> afterAgentCallback;

/** This is a safe cast to the concrete builder type. */
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -429,25 +434,25 @@ public B subAgents(BaseAgent... subAgents) {

@CanIgnoreReturnValue
public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) {
callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback);
this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback);
return self();
}

@CanIgnoreReturnValue
public B beforeAgentCallback(List<Callbacks.BeforeAgentCallbackBase> beforeAgentCallback) {
beforeAgentCallback.forEach(callbackPluginBuilder::addCallback);
this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback);
return self();
}

@CanIgnoreReturnValue
public B afterAgentCallback(AfterAgentCallback afterAgentCallback) {
callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback);
this.afterAgentCallback = ImmutableList.of(afterAgentCallback);
return self();
}

@CanIgnoreReturnValue
public B afterAgentCallback(List<Callbacks.AfterAgentCallbackBase> afterAgentCallback) {
afterAgentCallback.forEach(callbackPluginBuilder::addCallback);
this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback);
return self();
}

Expand Down
Loading
Loading