diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index f472cba6..53a97897 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -16,6 +16,8 @@ package com.google.adk.agents; +import static com.google.common.collect.ImmutableList.toImmutableList; + import com.google.adk.Telemetry; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -57,7 +59,8 @@ public abstract class BaseAgent { private final List subAgents; - protected final CallbackPlugin callbackPlugin; + private final Optional> beforeAgentCallback; + private final Optional> afterAgentCallback; /** * Creates a new BaseAgent. @@ -74,34 +77,14 @@ public BaseAgent( String name, String description, List subAgents, - @Nullable List beforeAgentCallback, - @Nullable List afterAgentCallback) { - this( - name, - description, - subAgents, - createCallbackPlugin(beforeAgentCallback, afterAgentCallback)); - } - - /** - * Creates a new BaseAgent. - * - * @param name Unique agent name. Cannot be "user" (reserved). - * @param description Agent purpose. - * @param subAgents Agents managed by this agent. - * @param callbackPlugin The callback plugin for this agent. - */ - protected BaseAgent( - String name, - String description, - List subAgents, - CallbackPlugin callbackPlugin) { + List beforeAgentCallback, + List afterAgentCallback) { this.name = name; this.description = description; this.parentAgent = null; this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.callbackPlugin = - callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin; + this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); + this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -109,18 +92,6 @@ protected BaseAgent( } } - /** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */ - private static CallbackPlugin createCallbackPlugin( - @Nullable List beforeAgentCallbacks, - @Nullable List afterAgentCallbacks) { - CallbackPlugin.Builder builder = CallbackPlugin.builder(); - Stream.ofNullable(beforeAgentCallbacks).flatMap(List::stream).forEach(builder::addCallback); - Optional.ofNullable(afterAgentCallbacks).stream() - .flatMap(List::stream) - .forEach(builder::addCallback); - return builder.build(); - } - /** * Gets the agent's unique name. * @@ -201,15 +172,11 @@ public List subAgents() { } public Optional> beforeAgentCallback() { - return Optional.of(callbackPlugin.getBeforeAgentCallback()); + return beforeAgentCallback; } public Optional> afterAgentCallback() { - return Optional.of(callbackPlugin.getAfterAgentCallback()); - } - - public Plugin getPlugin() { - return callbackPlugin; + return afterAgentCallback; } /** @@ -252,11 +219,11 @@ public Flowable runAsync(InvocationContext parentContext) { spanContext, span, () -> - processAgentCallbackResult( - ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx), + callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), + beforeAgentCallback.orElse(ImmutableList.of())), invocationContext) - .map(Optional::of) - .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher( beforeEventOpt -> { if (invocationContext.endInvocation()) { @@ -269,14 +236,11 @@ public Flowable runAsync(InvocationContext parentContext) { Flowable afterEvents = Flowable.defer( () -> - processAgentCallbackResult( - ctx -> - invocationContext - .combinedPlugin() - .afterAgentCallback(this, ctx), + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), + afterAgentCallback.orElse(ImmutableList.of())), invocationContext) - .map(Optional::of) - .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher(Flowable::fromOptional)); return Flowable.concat(beforeEvents, mainEvents, afterEvents); @@ -285,32 +249,76 @@ public Flowable runAsync(InvocationContext parentContext) { } /** - * Processes the result of an agent callback, creating an {@link Event} if necessary. + * Converts before-agent callbacks to functions. + * + * @param callbacks Before-agent callbacks. + * @return callback functions. + */ + private ImmutableList>> beforeCallbacksToFunctions( + Plugin pluginManager, List callbacks) { + return Stream.concat( + Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), + callbacks.stream() + .map(callback -> (Function>) callback::call)) + .collect(toImmutableList()); + } + + /** + * Converts after-agent callbacks to functions. + * + * @param callbacks After-agent callbacks. + * @return callback functions. + */ + private ImmutableList>> afterCallbacksToFunctions( + Plugin pluginManager, List callbacks) { + return Stream.concat( + Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), + callbacks.stream() + .map(callback -> (Function>) callback::call)) + .collect(toImmutableList()); + } + + /** + * Calls agent callbacks and returns the first produced event, if any. * - * @param agentCallback The callback function. - * @param invocationContext The current invocation context. - * @return A {@link Maybe} emitting an {@link Event} if one is produced, or empty otherwise. + * @param agentCallbacks Callback functions. + * @param invocationContext Current invocation context. + * @return single emitting first event, or empty if none. */ - private Maybe processAgentCallbackResult( - Function> agentCallback, + private Single> callCallback( + List>> agentCallbacks, InvocationContext invocationContext) { - var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null); - return agentCallback - .apply(callbackContext) - .map( - content -> { - invocationContext.setEndInvocation(true); - return Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build(); + if (agentCallbacks == null || agentCallbacks.isEmpty()) { + return Single.just(Optional.empty()); + } + + CallbackContext callbackContext = + new CallbackContext(invocationContext, /* eventActions= */ null); + + return Flowable.fromIterable(agentCallbacks) + .concatMap( + callback -> { + Maybe maybeContent = callback.apply(callbackContext); + + return maybeContent + .map( + content -> { + invocationContext.setEndInvocation(true); + return Optional.of( + Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch()) + .actions(callbackContext.eventActions()) + .content(content) + .build()); + }) + .toFlowable(); }) + .firstElement() .switchIfEmpty( - Maybe.defer( + Single.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -321,9 +329,9 @@ private Maybe processAgentCallbackResult( .branch(invocationContext.branch()) .actions(callbackContext.eventActions()); - return Maybe.just(eventBuilder.build()); + return Single.just(Optional.of(eventBuilder.build())); } else { - return Maybe.empty(); + return Single.just(Optional.empty()); } })); } @@ -391,11 +399,8 @@ public abstract static class Builder> { protected String name; protected String description; protected ImmutableList subAgents; - protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder(); - - protected CallbackPlugin.Builder callbackPluginBuilder() { - return callbackPluginBuilder; - } + protected ImmutableList beforeAgentCallback; + protected ImmutableList afterAgentCallback; /** This is a safe cast to the concrete builder type. */ @SuppressWarnings("unchecked") @@ -429,25 +434,25 @@ public B subAgents(BaseAgent... subAgents) { @CanIgnoreReturnValue public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { - callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback); + this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - beforeAgentCallback.forEach(callbackPluginBuilder::addCallback); + this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { - callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback); + this.afterAgentCallback = ImmutableList.of(afterAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - afterAgentCallback.forEach(callbackPluginBuilder::addCallback); + this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java deleted file mode 100644 index 791e9455..00000000 --- a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java +++ /dev/null @@ -1,333 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.agents; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackBase; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackBase; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackBase; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackBase; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackBase; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackBase; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.plugins.BasePlugin; -import com.google.adk.plugins.PluginManager; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A plugin that wraps callbacks and exposes them as a plugin. */ -public class CallbackPlugin extends PluginManager { - - private static final Logger logger = LoggerFactory.getLogger(CallbackPlugin.class); - - private final ImmutableListMultimap, Object> callbacks; - - private CallbackPlugin( - ImmutableList plugins, - ImmutableListMultimap, Object> callbacks) { - super(plugins); - this.callbacks = callbacks; - } - - @Override - public String getName() { - return "CallbackPlugin"; - } - - @SuppressWarnings("unchecked") // The builder ensures that the type is correct. - private ImmutableList getCallbacks(Class type) { - return (ImmutableList) callbacks.get(type); - } - - public ImmutableList getBeforeAgentCallback() { - return getCallbacks(Callbacks.BeforeAgentCallback.class); - } - - public ImmutableList getAfterAgentCallback() { - return getCallbacks(Callbacks.AfterAgentCallback.class); - } - - public ImmutableList getBeforeModelCallback() { - return getCallbacks(Callbacks.BeforeModelCallback.class); - } - - public ImmutableList getAfterModelCallback() { - return getCallbacks(Callbacks.AfterModelCallback.class); - } - - public ImmutableList getBeforeToolCallback() { - return getCallbacks(Callbacks.BeforeToolCallback.class); - } - - public ImmutableList getAfterToolCallback() { - return getCallbacks(Callbacks.AfterToolCallback.class); - } - - public static Builder builder() { - return new Builder(); - } - - /** Builder for {@link CallbackPlugin}. */ - public static class Builder { - // Ensures a unique name for each callback. - private static final AtomicInteger callbackId = new AtomicInteger(0); - - private final ImmutableList.Builder plugins = ImmutableList.builder(); - private final ListMultimap, Object> callbacks = ArrayListMultimap.create(); - - Builder() {} - - @CanIgnoreReturnValue - public Builder addBeforeAgentCallback(Callbacks.BeforeAgentCallback callback) { - callbacks.put(Callbacks.BeforeAgentCallback.class, callback); - plugins.add( - new BasePlugin("BeforeAgentCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe beforeAgentCallback( - BaseAgent agent, CallbackContext callbackContext) { - return callback.call(callbackContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeAgentCallbackSync(Callbacks.BeforeAgentCallbackSync callback) { - return addBeforeAgentCallback( - callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); - } - - @CanIgnoreReturnValue - public Builder addAfterAgentCallback(Callbacks.AfterAgentCallback callback) { - callbacks.put(Callbacks.AfterAgentCallback.class, callback); - plugins.add( - new BasePlugin("AfterAgentCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe afterAgentCallback( - BaseAgent agent, CallbackContext callbackContext) { - return callback.call(callbackContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterAgentCallbackSync(Callbacks.AfterAgentCallbackSync callback) { - return addAfterAgentCallback( - callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); - } - - @CanIgnoreReturnValue - public Builder addBeforeModelCallback(Callbacks.BeforeModelCallback callback) { - callbacks.put(Callbacks.BeforeModelCallback.class, callback); - plugins.add( - new BasePlugin("BeforeModelCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return callback.call(callbackContext, llmRequest); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeModelCallbackSync(Callbacks.BeforeModelCallbackSync callback) { - return addBeforeModelCallback( - (callbackContext, llmRequest) -> - Maybe.fromOptional(callback.call(callbackContext, llmRequest))); - } - - @CanIgnoreReturnValue - public Builder addAfterModelCallback(Callbacks.AfterModelCallback callback) { - callbacks.put(Callbacks.AfterModelCallback.class, callback); - plugins.add( - new BasePlugin("AfterModelCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe afterModelCallback( - CallbackContext callbackContext, LlmResponse llmResponse) { - return callback.call(callbackContext, llmResponse); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterModelCallbackSync(Callbacks.AfterModelCallbackSync callback) { - return addAfterModelCallback( - (callbackContext, llmResponse) -> - Maybe.fromOptional(callback.call(callbackContext, llmResponse))); - } - - @CanIgnoreReturnValue - public Builder addBeforeToolCallback(Callbacks.BeforeToolCallback callback) { - callbacks.put(Callbacks.BeforeToolCallback.class, callback); - plugins.add( - new BasePlugin("BeforeToolCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - return callback.call(toolContext.invocationContext(), tool, toolArgs, toolContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeToolCallbackSync(Callbacks.BeforeToolCallbackSync callback) { - return addBeforeToolCallback( - (invocationContext, tool, toolArgs, toolContext) -> - Maybe.fromOptional(callback.call(invocationContext, tool, toolArgs, toolContext))); - } - - @CanIgnoreReturnValue - public Builder addAfterToolCallback(Callbacks.AfterToolCallback callback) { - callbacks.put(Callbacks.AfterToolCallback.class, callback); - plugins.add( - new BasePlugin("AfterToolCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe> afterToolCallback( - BaseTool tool, - Map toolArgs, - ToolContext toolContext, - Map result) { - return callback.call( - toolContext.invocationContext(), tool, toolArgs, toolContext, result); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterToolCallbackSync(Callbacks.AfterToolCallbackSync callback) { - return addAfterToolCallback( - (invocationContext, tool, toolArgs, toolContext, result) -> - Maybe.fromOptional( - callback.call(invocationContext, tool, toolArgs, toolContext, result))); - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeAgentCallbackBase callback) { - if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { - addBeforeAgentCallback(beforeAgentCallbackInstance); - } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { - addBeforeAgentCallbackSync(beforeAgentCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeAgentCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterAgentCallbackBase callback) { - if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { - addAfterAgentCallback(afterAgentCallbackInstance); - } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { - addAfterAgentCallbackSync(afterAgentCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterAgentCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeModelCallbackBase callback) { - if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { - addBeforeModelCallback(beforeModelCallbackInstance); - } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { - addBeforeModelCallbackSync(beforeModelCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterModelCallbackBase callback) { - if (callback instanceof AfterModelCallback afterModelCallbackInstance) { - addAfterModelCallback(afterModelCallbackInstance); - } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { - addAfterModelCallbackSync(afterModelCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeToolCallbackBase callback) { - if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { - addBeforeToolCallback(beforeToolCallbackInstance); - } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { - addBeforeToolCallbackSync(beforeToolCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterToolCallbackBase callback) { - if (callback instanceof AfterToolCallback afterToolCallbackInstance) { - addAfterToolCallback(afterToolCallbackInstance); - } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { - addAfterToolCallbackSync(afterToolCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - public CallbackPlugin build() { - return new CallbackPlugin(plugins.build(), ImmutableListMultimap.copyOf(callbacks)); - } - } -} diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 3913b746..a273012f 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -56,6 +56,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; @@ -94,6 +95,10 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; + private final Optional> beforeModelCallback; + private final Optional> afterModelCallback; + private final Optional> beforeToolCallback; + private final Optional> afterToolCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -108,7 +113,8 @@ protected LlmAgent(Builder builder) { builder.name, builder.description, builder.subAgents, - builder.callbackPluginBuilder.build()); + builder.beforeAgentCallback, + builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); this.instruction = builder.instruction == null ? new Instruction.Static("") : builder.instruction; @@ -122,6 +128,10 @@ protected LlmAgent(Builder builder) { this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; + this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); + this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); + this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); + this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); @@ -163,6 +173,10 @@ public static class Builder extends BaseAgent.Builder { private Integer maxSteps; private Boolean disallowTransferToParent; private Boolean disallowTransferToPeers; + private ImmutableList beforeModelCallback; + private ImmutableList afterModelCallback; + private ImmutableList beforeToolCallback; + private ImmutableList afterToolCallback; private Schema inputSchema; private Schema outputSchema; private Executor executor; @@ -276,86 +290,200 @@ public Builder disallowTransferToPeers(boolean disallowTransferToPeers) { @CanIgnoreReturnValue public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) { - callbackPluginBuilder.addBeforeModelCallback(beforeModelCallback); + this.beforeModelCallback = ImmutableList.of(beforeModelCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallback(List beforeModelCallback) { - beforeModelCallback.forEach(callbackPluginBuilder::addCallback); + if (beforeModelCallback == null) { + this.beforeModelCallback = null; + } else if (beforeModelCallback.isEmpty()) { + this.beforeModelCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeModelCallbackBase callback : beforeModelCallback) { + if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { + builder.add(beforeModelCallbackInstance); + } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { + builder.add( + (BeforeModelCallback) + (callbackContext, llmRequestBuilder) -> + Maybe.fromOptional( + beforeModelCallbackSyncInstance.call( + callbackContext, llmRequestBuilder))); + } else { + logger.warn( + "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.beforeModelCallback = builder.build(); + } + return this; } @CanIgnoreReturnValue public Builder beforeModelCallbackSync(BeforeModelCallbackSync beforeModelCallbackSync) { - callbackPluginBuilder.addBeforeModelCallbackSync(beforeModelCallbackSync); + this.beforeModelCallback = + ImmutableList.of( + (callbackContext, llmRequestBuilder) -> + Maybe.fromOptional( + beforeModelCallbackSync.call(callbackContext, llmRequestBuilder))); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(AfterModelCallback afterModelCallback) { - callbackPluginBuilder.addAfterModelCallback(afterModelCallback); + this.afterModelCallback = ImmutableList.of(afterModelCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(List afterModelCallback) { - afterModelCallback.forEach(callbackPluginBuilder::addCallback); + if (afterModelCallback == null) { + this.afterModelCallback = null; + } else if (afterModelCallback.isEmpty()) { + this.afterModelCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (AfterModelCallbackBase callback : afterModelCallback) { + if (callback instanceof AfterModelCallback afterModelCallbackInstance) { + builder.add(afterModelCallbackInstance); + } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { + builder.add( + (AfterModelCallback) + (callbackContext, llmResponse) -> + Maybe.fromOptional( + afterModelCallbackSyncInstance.call(callbackContext, llmResponse))); + } else { + logger.warn( + "Invalid afterModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.afterModelCallback = builder.build(); + } + return this; } @CanIgnoreReturnValue public Builder afterModelCallbackSync(AfterModelCallbackSync afterModelCallbackSync) { - callbackPluginBuilder.addAfterModelCallbackSync(afterModelCallbackSync); + this.afterModelCallback = + ImmutableList.of( + (callbackContext, llmResponse) -> + Maybe.fromOptional(afterModelCallbackSync.call(callbackContext, llmResponse))); return this; } @CanIgnoreReturnValue public Builder beforeAgentCallbackSync(BeforeAgentCallbackSync beforeAgentCallbackSync) { - callbackPluginBuilder.addBeforeAgentCallbackSync(beforeAgentCallbackSync); + this.beforeAgentCallback = + ImmutableList.of( + (callbackContext) -> + Maybe.fromOptional(beforeAgentCallbackSync.call(callbackContext))); return this; } @CanIgnoreReturnValue public Builder afterAgentCallbackSync(AfterAgentCallbackSync afterAgentCallbackSync) { - callbackPluginBuilder.addAfterAgentCallbackSync(afterAgentCallbackSync); + this.afterAgentCallback = + ImmutableList.of( + (callbackContext) -> + Maybe.fromOptional(afterAgentCallbackSync.call(callbackContext))); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback(BeforeToolCallback beforeToolCallback) { - callbackPluginBuilder.addBeforeToolCallback(beforeToolCallback); + this.beforeToolCallback = ImmutableList.of(beforeToolCallback); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback( @Nullable List beforeToolCallbacks) { - beforeToolCallbacks.forEach(callbackPluginBuilder::addCallback); + if (beforeToolCallbacks == null) { + this.beforeToolCallback = null; + } else if (beforeToolCallbacks.isEmpty()) { + this.beforeToolCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeToolCallbackBase callback : beforeToolCallbacks) { + if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { + builder.add(beforeToolCallbackInstance); + } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { + builder.add( + (invocationContext, baseTool, input, toolContext) -> + Maybe.fromOptional( + beforeToolCallbackSyncInstance.call( + invocationContext, baseTool, input, toolContext))); + } else { + logger.warn( + "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.beforeToolCallback = builder.build(); + } return this; } @CanIgnoreReturnValue public Builder beforeToolCallbackSync(BeforeToolCallbackSync beforeToolCallbackSync) { - callbackPluginBuilder.addBeforeToolCallbackSync(beforeToolCallbackSync); + this.beforeToolCallback = + ImmutableList.of( + (invocationContext, baseTool, input, toolContext) -> + Maybe.fromOptional( + beforeToolCallbackSync.call( + invocationContext, baseTool, input, toolContext))); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(AfterToolCallback afterToolCallback) { - callbackPluginBuilder.addAfterToolCallback(afterToolCallback); + this.afterToolCallback = ImmutableList.of(afterToolCallback); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(@Nullable List afterToolCallbacks) { - afterToolCallbacks.forEach(callbackPluginBuilder::addCallback); + if (afterToolCallbacks == null) { + this.afterToolCallback = null; + } else if (afterToolCallbacks.isEmpty()) { + this.afterToolCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (AfterToolCallbackBase callback : afterToolCallbacks) { + if (callback instanceof AfterToolCallback afterToolCallbackInstance) { + builder.add(afterToolCallbackInstance); + } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { + builder.add( + (invocationContext, baseTool, input, toolContext, response) -> + Maybe.fromOptional( + afterToolCallbackSyncInstance.call( + invocationContext, baseTool, input, toolContext, response))); + } else { + logger.warn( + "Invalid afterToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.afterToolCallback = builder.build(); + } return this; } @CanIgnoreReturnValue public Builder afterToolCallbackSync(AfterToolCallbackSync afterToolCallbackSync) { - callbackPluginBuilder.addAfterToolCallbackSync(afterToolCallbackSync); + this.afterToolCallback = + ImmutableList.of( + (invocationContext, baseTool, input, toolContext, response) -> + Maybe.fromOptional( + afterToolCallbackSync.call( + invocationContext, baseTool, input, toolContext, response))); return this; } @@ -629,19 +757,19 @@ public boolean disallowTransferToPeers() { } public Optional> beforeModelCallback() { - return Optional.of(callbackPlugin.getBeforeModelCallback()); + return beforeModelCallback; } public Optional> afterModelCallback() { - return Optional.of(callbackPlugin.getAfterModelCallback()); + return afterModelCallback; } public Optional> beforeToolCallback() { - return Optional.of(callbackPlugin.getBeforeToolCallback()); + return beforeToolCallback; } public Optional> afterToolCallback() { - return Optional.of(callbackPlugin.getAfterToolCallback()); + return afterToolCallback; } public Optional inputSchema() { @@ -702,8 +830,8 @@ private Model resolveModelInternal() { } BaseAgent current = this.parentAgent(); while (current != null) { - if (current instanceof LlmAgent llmAgent) { - return llmAgent.resolvedModel(); + if (current instanceof LlmAgent) { + return ((LlmAgent) current).resolvedModel(); } current = current.parentAgent(); } diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index 921ef368..d9d049f8 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -46,13 +46,16 @@ public class LoopAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private LoopAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); - this.maxIterations = builder.maxIterations; + private LoopAgent( + String name, + String description, + List subAgents, + Optional maxIterations, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.maxIterations = maxIterations; } /** Builder for {@link LoopAgent}. */ @@ -73,7 +76,9 @@ public Builder maxIterations(Optional maxIterations) { @Override public LoopAgent build() { - return new LoopAgent(this); + // TODO(b/410859954): Add validation for required fields like name. + return new LoopAgent( + name, description, subAgents, maxIterations, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index 583bfffc..f30d951a 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -45,12 +45,14 @@ public class ParallelAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private ParallelAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); + private ParallelAgent( + String name, + String description, + List subAgents, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); } /** Builder for {@link ParallelAgent}. */ @@ -58,7 +60,8 @@ public static class Builder extends BaseAgent.Builder { @Override public ParallelAgent build() { - return new ParallelAgent(this); + return new ParallelAgent( + name, description, subAgents, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java index dc7480f5..7d3a5acb 100644 --- a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java +++ b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java @@ -34,11 +34,6 @@ public ReadonlyContext(InvocationContext invocationContext) { this.invocationContext = invocationContext; } - /** Returns the invocation context. */ - public InvocationContext invocationContext() { - return invocationContext; - } - /** Returns the user content that initiated this invocation. */ public Optional userContent() { return invocationContext.userContent(); diff --git a/core/src/main/java/com/google/adk/agents/SequentialAgent.java b/core/src/main/java/com/google/adk/agents/SequentialAgent.java index aa4b76fb..b0b45a0e 100644 --- a/core/src/main/java/com/google/adk/agents/SequentialAgent.java +++ b/core/src/main/java/com/google/adk/agents/SequentialAgent.java @@ -18,6 +18,7 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,12 +36,14 @@ public class SequentialAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private SequentialAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); + private SequentialAgent( + String name, + String description, + List subAgents, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); } /** Builder for {@link SequentialAgent}. */ @@ -48,7 +51,9 @@ public static class Builder extends BaseAgent.Builder { @Override public SequentialAgent build() { - return new SequentialAgent(this); + // TODO(b/410859954): Add validation for required fields like name. + return new SequentialAgent( + name, description, subAgents, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 8258d32d..6e06a34a 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -18,17 +18,16 @@ import static com.google.common.truth.Truth.assertThat; -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.testing.TestBaseAgent; -import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,213 +51,37 @@ public void constructor_setsNameAndDescription() { @Test public void runAsync_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunAsyncImplAndAfterCallback() { - var runAsyncImpl = TestCallback.returningEmpty(); + AtomicBoolean runAsyncImplCalled = new AtomicBoolean(false); + AtomicBoolean afterAgentCallbackCalled = new AtomicBoolean(false); Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); - var beforeCallback = TestCallback.returning(callbackContent); - var afterCallback = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier("main_output")); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(1); - assertThat(results.get(0).content()).hasValue(callbackContent); - assertThat(runAsyncImpl.wasCalled()).isFalse(); - assertThat(beforeCallback.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isFalse(); - } - - @Test - public void runAsync_firstBeforeCallbackReturnsContent_skipsSecondBeforeCallback() { - Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); - var beforeCallback1 = TestCallback.returning(callbackContent); - var beforeCallback2 = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of( - beforeCallback1.asBeforeAgentCallback(), beforeCallback2.asBeforeAgentCallback()), - ImmutableList.of(), - TestCallback.returningEmpty().asRunAsyncImplSupplier("main_output")); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - var unused = agent.runAsync(invocationContext).toList().blockingGet(); - assertThat(beforeCallback1.wasCalled()).isTrue(); - assertThat(beforeCallback2.wasCalled()).isFalse(); - } - - @Test - public void runAsync_noCallbacks_invokesRunAsyncImpl() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - /* beforeAgentCallbacks= */ ImmutableList.of(), - /* afterAgentCallbacks= */ ImmutableList.of(), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(1); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunAsyncImplAndAfterCallbacks() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - var beforeCallback = TestCallback.returningEmpty(); - var afterCallback = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(1); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(beforeCallback.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_afterCallbackReturnsContent_invokesRunAsyncImplAndAfterCallbacksAndReturnsAllContent() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); - var beforeCallback = TestCallback.returningEmpty(); - var afterCallback = TestCallback.returning(afterCallbackContent); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(2); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(results.get(1).content()).hasValue(afterCallbackContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(beforeCallback.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_beforeCallbackMutatesStateAndReturnsEmpty_invokesRunAsyncImplAndReturnsStateEvent() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - BeforeAgentCallback beforeCallback = - new BeforeAgentCallback() { - @Override - public Maybe call(CallbackContext context) { - context.state().put("key", "value"); - return Maybe.empty(); - } + Callbacks.BeforeAgentCallback beforeCallback = (callbackContext) -> Maybe.just(callbackContent); + Callbacks.AfterAgentCallback afterCallback = + (callbackContext) -> { + afterAgentCallbackCalled.set(true); + return Maybe.empty(); }; - var afterCallback = TestCallback.returningEmpty(); TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, ImmutableList.of(beforeCallback), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(2); - // State event from before callback - assertThat(results.get(0).content()).isEmpty(); - assertThat(results.get(0).actions().stateDelta()).containsEntry("key", "value"); - // Content event from runAsyncImpl - assertThat(results.get(1).content()).hasValue(runAsyncImplContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_afterCallbackMutatesStateAndReturnsEmpty_invokesRunAsyncImplAndReturnsStateEvent() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - var beforeCallback = TestCallback.returningEmpty(); - AfterAgentCallback afterCallback = - new AfterAgentCallback() { - @Override - public Maybe call(CallbackContext context) { - context.state().put("key", "value"); - return Maybe.empty(); - } - }; - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), ImmutableList.of(afterCallback), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + () -> + Flowable.defer( + () -> { + runAsyncImplCalled.set(true); + return Flowable.just( + Event.builder() + .content(Content.fromParts(Part.fromText("main_output"))) + .build()); + })); InvocationContext invocationContext = TestUtils.createInvocationContext(agent); List results = agent.runAsync(invocationContext).toList().blockingGet(); - assertThat(results).hasSize(2); - // Content event from runAsyncImpl - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - // State event from after callback - assertThat(results.get(1).content()).isEmpty(); - assertThat(results.get(1).actions().stateDelta()).containsEntry("key", "value"); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(beforeCallback.wasCalled()).isTrue(); - } - - @Test - public void runAsync_firstAfterCallbackReturnsContent_skipsSecondAfterCallback() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); - var afterCallback1 = TestCallback.returning(afterCallbackContent); - var afterCallback2 = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(), - ImmutableList.of( - afterCallback1.asAfterAgentCallback(), afterCallback2.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(2); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(results.get(1).content()).hasValue(afterCallbackContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(afterCallback1.wasCalled()).isTrue(); - assertThat(afterCallback2.wasCalled()).isFalse(); + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(callbackContent); + assertThat(runAsyncImplCalled.get()).isFalse(); + assertThat(afterAgentCallbackCalled.get()).isFalse(); } } diff --git a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java deleted file mode 100644 index 361c8619..00000000 --- a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java +++ /dev/null @@ -1,499 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.agents; - -import static com.google.adk.testing.TestUtils.createInvocationContext; -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.events.EventActions; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.testing.TestCallback; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableMap; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -@RunWith(JUnit4.class) -public final class CallbackPluginTest { - - @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - @Mock private BaseAgent agent; - @Mock private BaseTool tool; - @Mock private ToolContext toolContext; - private InvocationContext invocationContext; - private CallbackContext callbackContext; - - @Before - public void setUp() { - invocationContext = createInvocationContext(agent); - callbackContext = - new CallbackContext( - invocationContext, - EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()); - } - - @Test - public void build_empty_successful() { - CallbackPlugin plugin = CallbackPlugin.builder().build(); - assertThat(plugin.getName()).isEqualTo("CallbackPlugin"); - assertThat(plugin.getBeforeAgentCallback()).isEmpty(); - assertThat(plugin.getAfterAgentCallback()).isEmpty(); - assertThat(plugin.getBeforeModelCallback()).isEmpty(); - assertThat(plugin.getAfterModelCallback()).isEmpty(); - assertThat(plugin.getBeforeToolCallback()).isEmpty(); - assertThat(plugin.getAfterToolCallback()).isEmpty(); - } - - @Test - public void addBeforeAgentCallback_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - BeforeAgentCallback callback = testCallback.asBeforeAgentCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeAgentCallback(callback).build(); - - assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); - - Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addBeforeAgentCallbackSync_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeAgentCallbackSync(testCallback.asBeforeAgentCallbackSync()) - .build(); - - assertThat(plugin.getBeforeAgentCallback()).hasSize(1); - - Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addAfterAgentCallback_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - AfterAgentCallback callback = testCallback.asAfterAgentCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterAgentCallback(callback).build(); - - assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); - - Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addAfterAgentCallbackSync_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterAgentCallbackSync(testCallback.asAfterAgentCallbackSync()) - .build(); - - assertThat(plugin.getAfterAgentCallback()).hasSize(1); - - Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addBeforeModelCallback_isReturnedAndInvoked() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback = TestCallback.returning(expectedResponse); - BeforeModelCallback callback = testCallback.asBeforeModelCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeModelCallback(callback).build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addBeforeModelCallbackSync_isReturnedAndInvoked() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback = TestCallback.returning(expectedResponse); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallbackSync(testCallback.asBeforeModelCallbackSync()) - .build(); - - assertThat(plugin.getBeforeModelCallback()).hasSize(1); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addAfterModelCallback_isReturnedAndInvoked() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); - var testCallback = TestCallback.returning(expectedResponse); - AfterModelCallback callback = testCallback.asAfterModelCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallback(callback).build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addAfterModelCallbackSync_isReturnedAndInvoked() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); - var testCallback = TestCallback.returning(expectedResponse); - AfterModelCallbackSync callback = testCallback.asAfterModelCallbackSync(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallbackSync(callback).build(); - - assertThat(plugin.getAfterModelCallback()).hasSize(1); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addBeforeToolCallback_isReturnedAndInvoked() { - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - BeforeToolCallback callback = testCallback.asBeforeToolCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeToolCallback(callback).build(); - - assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); - - Map result = - plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addBeforeToolCallbackSync_isReturnedAndInvoked() { - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeToolCallbackSync(testCallback.asBeforeToolCallbackSync()) - .build(); - - assertThat(plugin.getBeforeToolCallback()).hasSize(1); - - Map result = - plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addAfterToolCallback_isReturnedAndInvoked() { - ImmutableMap initialResult = ImmutableMap.of("initial", "result"); - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - AfterToolCallback callback = testCallback.asAfterToolCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallback(callback).build(); - - assertThat(plugin.getAfterToolCallback()).containsExactly(callback); - - Map result = - plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addAfterToolCallbackSync_isReturnedAndInvoked() { - ImmutableMap initialResult = ImmutableMap.of("initial", "result"); - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - AfterToolCallbackSync callback = testCallback.asAfterToolCallbackSync(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallbackSync(callback).build(); - - assertThat(plugin.getAfterToolCallback()).hasSize(1); - - Map result = - plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addCallback_beforeAgentCallback() { - BeforeAgentCallback callback = ctx -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeAgentCallbackSync() { - BeforeAgentCallbackSync callback = ctx -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeAgentCallback()).hasSize(1); - } - - @Test - public void addCallback_afterAgentCallback() { - AfterAgentCallback callback = ctx -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterAgentCallbackSync() { - AfterAgentCallbackSync callback = ctx -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterAgentCallback()).hasSize(1); - } - - @Test - public void addCallback_beforeModelCallback() { - BeforeModelCallback callback = (ctx, req) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeModelCallbackSync() { - BeforeModelCallbackSync callback = (ctx, req) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeModelCallback()).hasSize(1); - } - - @Test - public void addCallback_afterModelCallback() { - AfterModelCallback callback = (ctx, res) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterModelCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterModelCallbackSync() { - AfterModelCallbackSync callback = (ctx, res) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterModelCallback()).hasSize(1); - } - - @Test - public void addCallback_beforeToolCallback() { - BeforeToolCallback callback = (invCtx, tool, toolArgs, toolCtx) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeToolCallbackSync() { - BeforeToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeToolCallback()).hasSize(1); - } - - @Test - public void addCallback_afterToolCallback() { - AfterToolCallback callback = (invCtx, tool, toolArgs, toolCtx, res) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterToolCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterToolCallbackSync() { - AfterToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx, res) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterToolCallback()).hasSize(1); - } - - @Test - public void addMultipleBeforeModelCallbacks_invokedInOrder() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returning(expectedResponse); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleBeforeModelCallbacks_shortCircuit() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returning(expectedResponse); - var testCallback2 = TestCallback.returningEmpty(); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isFalse(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleAfterModelCallbacks_shortCircuit() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("response"))).build(); - var testCallback1 = TestCallback.returning(expectedResponse); - var testCallback2 = TestCallback.returningEmpty(); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isFalse(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleAfterModelCallbacks_invokedInOrder() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("second"))).build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returning(expectedResponse); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleBeforeModelCallbacks_bothEmpty_returnsEmpty() { - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returningEmpty(); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isNull(); - } - - @Test - public void addMultipleAfterModelCallbacks_bothEmpty_returnsEmpty() { - LlmResponse initialResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returningEmpty(); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isNull(); - } -} diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java deleted file mode 100644 index 04f83ed9..00000000 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.testing; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.events.Event; -import com.google.adk.models.LlmResponse; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; - -/** - * A test helper that wraps an {@link AtomicBoolean} and provides factory methods for creating - * callbacks that update the boolean when called. - * - * @param The type of the result returned by the callback. - */ -public final class TestCallback { - private final AtomicBoolean called = new AtomicBoolean(false); - private final Optional result; - - private TestCallback(Optional result) { - this.result = result; - } - - /** Creates a {@link TestCallback} that returns the given result. */ - public static TestCallback returning(T result) { - return new TestCallback<>(Optional.of(result)); - } - - /** Creates a {@link TestCallback} that returns an empty result. */ - public static TestCallback returningEmpty() { - return new TestCallback<>(Optional.empty()); - } - - /** Returns true if the callback was called. */ - public boolean wasCalled() { - return called.get(); - } - - /** Marks the callback as called. */ - public void markAsCalled() { - called.set(true); - } - - private Maybe callMaybe() { - called.set(true); - return result.map(Maybe::just).orElseGet(Maybe::empty); - } - - private Optional callOptional() { - called.set(true); - return result; - } - - /** - * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} - * with an event containing the given content. - */ - public Supplier> asRunAsyncImplSupplier(Content content) { - return () -> - Flowable.defer( - () -> { - markAsCalled(); - return Flowable.just(Event.builder().content(content).build()); - }); - } - - /** - * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} - */ - public Supplier> asRunAsyncImplSupplier(String contentText) { - return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText))); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public BeforeAgentCallback asBeforeAgentCallback() { - return ctx -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public BeforeAgentCallbackSync asBeforeAgentCallbackSync() { - return ctx -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public AfterAgentCallback asAfterAgentCallback() { - return ctx -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public AfterAgentCallbackSync asAfterAgentCallbackSync() { - return ctx -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public BeforeModelCallback asBeforeModelCallback() { - return (ctx, req) -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public BeforeModelCallbackSync asBeforeModelCallbackSync() { - return (ctx, req) -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public AfterModelCallback asAfterModelCallback() { - return (ctx, res) -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public AfterModelCallbackSync asAfterModelCallbackSync() { - return (ctx, res) -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public BeforeToolCallback asBeforeToolCallback() { - return (invCtx, tool, toolArgs, toolCtx) -> (Maybe>) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public BeforeToolCallbackSync asBeforeToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx) -> (Optional>) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public AfterToolCallback asAfterToolCallback() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe>) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public AfterToolCallbackSync asAfterToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional>) callOptional(); - } -}