From 8349aaf5a665e09bf7ddf267675e7625f8ef10ac Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 16 Jan 2026 16:20:16 -0800 Subject: [PATCH] refactor: Introducing a CallbackPlugin to wrap the old style Callbacks The goal is to unify the processing of Plugins and Callbacks. We should consider depercating and removing the old Callbacks. There are a bunch of cyclical dependencies caused by requests back to the agent to get specific Callbacks. The next step will be to augmet the InvocationContext's PluginManager with the appropriate agent specific callbacks PiperOrigin-RevId: 857342459 --- .../java/com/google/adk/agents/BaseAgent.java | 177 ++++--- .../com/google/adk/agents/CallbackPlugin.java | 333 ------------ .../java/com/google/adk/agents/LlmAgent.java | 170 +++++- .../java/com/google/adk/agents/LoopAgent.java | 21 +- .../com/google/adk/agents/ParallelAgent.java | 17 +- .../google/adk/agents/ReadonlyContext.java | 5 - .../google/adk/agents/SequentialAgent.java | 19 +- .../com/google/adk/agents/BaseAgentTest.java | 221 +------- .../google/adk/agents/CallbackPluginTest.java | 499 ------------------ .../com/google/adk/testing/TestCallback.java | 164 ------ 10 files changed, 297 insertions(+), 1329 deletions(-) delete mode 100644 core/src/main/java/com/google/adk/agents/CallbackPlugin.java delete mode 100644 core/src/test/java/com/google/adk/agents/CallbackPluginTest.java delete mode 100644 core/src/test/java/com/google/adk/testing/TestCallback.java diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index f472cba66..53a978974 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -16,6 +16,8 @@ package com.google.adk.agents; +import static com.google.common.collect.ImmutableList.toImmutableList; + import com.google.adk.Telemetry; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -57,7 +59,8 @@ public abstract class BaseAgent { private final List subAgents; - protected final CallbackPlugin callbackPlugin; + private final Optional> beforeAgentCallback; + private final Optional> afterAgentCallback; /** * Creates a new BaseAgent. @@ -74,34 +77,14 @@ public BaseAgent( String name, String description, List subAgents, - @Nullable List beforeAgentCallback, - @Nullable List afterAgentCallback) { - this( - name, - description, - subAgents, - createCallbackPlugin(beforeAgentCallback, afterAgentCallback)); - } - - /** - * Creates a new BaseAgent. - * - * @param name Unique agent name. Cannot be "user" (reserved). - * @param description Agent purpose. - * @param subAgents Agents managed by this agent. - * @param callbackPlugin The callback plugin for this agent. - */ - protected BaseAgent( - String name, - String description, - List subAgents, - CallbackPlugin callbackPlugin) { + List beforeAgentCallback, + List afterAgentCallback) { this.name = name; this.description = description; this.parentAgent = null; this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.callbackPlugin = - callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin; + this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); + this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -109,18 +92,6 @@ protected BaseAgent( } } - /** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */ - private static CallbackPlugin createCallbackPlugin( - @Nullable List beforeAgentCallbacks, - @Nullable List afterAgentCallbacks) { - CallbackPlugin.Builder builder = CallbackPlugin.builder(); - Stream.ofNullable(beforeAgentCallbacks).flatMap(List::stream).forEach(builder::addCallback); - Optional.ofNullable(afterAgentCallbacks).stream() - .flatMap(List::stream) - .forEach(builder::addCallback); - return builder.build(); - } - /** * Gets the agent's unique name. * @@ -201,15 +172,11 @@ public List subAgents() { } public Optional> beforeAgentCallback() { - return Optional.of(callbackPlugin.getBeforeAgentCallback()); + return beforeAgentCallback; } public Optional> afterAgentCallback() { - return Optional.of(callbackPlugin.getAfterAgentCallback()); - } - - public Plugin getPlugin() { - return callbackPlugin; + return afterAgentCallback; } /** @@ -252,11 +219,11 @@ public Flowable runAsync(InvocationContext parentContext) { spanContext, span, () -> - processAgentCallbackResult( - ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx), + callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), + beforeAgentCallback.orElse(ImmutableList.of())), invocationContext) - .map(Optional::of) - .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher( beforeEventOpt -> { if (invocationContext.endInvocation()) { @@ -269,14 +236,11 @@ public Flowable runAsync(InvocationContext parentContext) { Flowable afterEvents = Flowable.defer( () -> - processAgentCallbackResult( - ctx -> - invocationContext - .combinedPlugin() - .afterAgentCallback(this, ctx), + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), + afterAgentCallback.orElse(ImmutableList.of())), invocationContext) - .map(Optional::of) - .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher(Flowable::fromOptional)); return Flowable.concat(beforeEvents, mainEvents, afterEvents); @@ -285,32 +249,76 @@ public Flowable runAsync(InvocationContext parentContext) { } /** - * Processes the result of an agent callback, creating an {@link Event} if necessary. + * Converts before-agent callbacks to functions. + * + * @param callbacks Before-agent callbacks. + * @return callback functions. + */ + private ImmutableList>> beforeCallbacksToFunctions( + Plugin pluginManager, List callbacks) { + return Stream.concat( + Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), + callbacks.stream() + .map(callback -> (Function>) callback::call)) + .collect(toImmutableList()); + } + + /** + * Converts after-agent callbacks to functions. + * + * @param callbacks After-agent callbacks. + * @return callback functions. + */ + private ImmutableList>> afterCallbacksToFunctions( + Plugin pluginManager, List callbacks) { + return Stream.concat( + Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), + callbacks.stream() + .map(callback -> (Function>) callback::call)) + .collect(toImmutableList()); + } + + /** + * Calls agent callbacks and returns the first produced event, if any. * - * @param agentCallback The callback function. - * @param invocationContext The current invocation context. - * @return A {@link Maybe} emitting an {@link Event} if one is produced, or empty otherwise. + * @param agentCallbacks Callback functions. + * @param invocationContext Current invocation context. + * @return single emitting first event, or empty if none. */ - private Maybe processAgentCallbackResult( - Function> agentCallback, + private Single> callCallback( + List>> agentCallbacks, InvocationContext invocationContext) { - var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null); - return agentCallback - .apply(callbackContext) - .map( - content -> { - invocationContext.setEndInvocation(true); - return Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build(); + if (agentCallbacks == null || agentCallbacks.isEmpty()) { + return Single.just(Optional.empty()); + } + + CallbackContext callbackContext = + new CallbackContext(invocationContext, /* eventActions= */ null); + + return Flowable.fromIterable(agentCallbacks) + .concatMap( + callback -> { + Maybe maybeContent = callback.apply(callbackContext); + + return maybeContent + .map( + content -> { + invocationContext.setEndInvocation(true); + return Optional.of( + Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch()) + .actions(callbackContext.eventActions()) + .content(content) + .build()); + }) + .toFlowable(); }) + .firstElement() .switchIfEmpty( - Maybe.defer( + Single.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -321,9 +329,9 @@ private Maybe processAgentCallbackResult( .branch(invocationContext.branch()) .actions(callbackContext.eventActions()); - return Maybe.just(eventBuilder.build()); + return Single.just(Optional.of(eventBuilder.build())); } else { - return Maybe.empty(); + return Single.just(Optional.empty()); } })); } @@ -391,11 +399,8 @@ public abstract static class Builder> { protected String name; protected String description; protected ImmutableList subAgents; - protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder(); - - protected CallbackPlugin.Builder callbackPluginBuilder() { - return callbackPluginBuilder; - } + protected ImmutableList beforeAgentCallback; + protected ImmutableList afterAgentCallback; /** This is a safe cast to the concrete builder type. */ @SuppressWarnings("unchecked") @@ -429,25 +434,25 @@ public B subAgents(BaseAgent... subAgents) { @CanIgnoreReturnValue public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { - callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback); + this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - beforeAgentCallback.forEach(callbackPluginBuilder::addCallback); + this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { - callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback); + this.afterAgentCallback = ImmutableList.of(afterAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - afterAgentCallback.forEach(callbackPluginBuilder::addCallback); + this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java deleted file mode 100644 index 791e9455c..000000000 --- a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java +++ /dev/null @@ -1,333 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.agents; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackBase; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackBase; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackBase; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackBase; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackBase; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackBase; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.plugins.BasePlugin; -import com.google.adk.plugins.PluginManager; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A plugin that wraps callbacks and exposes them as a plugin. */ -public class CallbackPlugin extends PluginManager { - - private static final Logger logger = LoggerFactory.getLogger(CallbackPlugin.class); - - private final ImmutableListMultimap, Object> callbacks; - - private CallbackPlugin( - ImmutableList plugins, - ImmutableListMultimap, Object> callbacks) { - super(plugins); - this.callbacks = callbacks; - } - - @Override - public String getName() { - return "CallbackPlugin"; - } - - @SuppressWarnings("unchecked") // The builder ensures that the type is correct. - private ImmutableList getCallbacks(Class type) { - return (ImmutableList) callbacks.get(type); - } - - public ImmutableList getBeforeAgentCallback() { - return getCallbacks(Callbacks.BeforeAgentCallback.class); - } - - public ImmutableList getAfterAgentCallback() { - return getCallbacks(Callbacks.AfterAgentCallback.class); - } - - public ImmutableList getBeforeModelCallback() { - return getCallbacks(Callbacks.BeforeModelCallback.class); - } - - public ImmutableList getAfterModelCallback() { - return getCallbacks(Callbacks.AfterModelCallback.class); - } - - public ImmutableList getBeforeToolCallback() { - return getCallbacks(Callbacks.BeforeToolCallback.class); - } - - public ImmutableList getAfterToolCallback() { - return getCallbacks(Callbacks.AfterToolCallback.class); - } - - public static Builder builder() { - return new Builder(); - } - - /** Builder for {@link CallbackPlugin}. */ - public static class Builder { - // Ensures a unique name for each callback. - private static final AtomicInteger callbackId = new AtomicInteger(0); - - private final ImmutableList.Builder plugins = ImmutableList.builder(); - private final ListMultimap, Object> callbacks = ArrayListMultimap.create(); - - Builder() {} - - @CanIgnoreReturnValue - public Builder addBeforeAgentCallback(Callbacks.BeforeAgentCallback callback) { - callbacks.put(Callbacks.BeforeAgentCallback.class, callback); - plugins.add( - new BasePlugin("BeforeAgentCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe beforeAgentCallback( - BaseAgent agent, CallbackContext callbackContext) { - return callback.call(callbackContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeAgentCallbackSync(Callbacks.BeforeAgentCallbackSync callback) { - return addBeforeAgentCallback( - callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); - } - - @CanIgnoreReturnValue - public Builder addAfterAgentCallback(Callbacks.AfterAgentCallback callback) { - callbacks.put(Callbacks.AfterAgentCallback.class, callback); - plugins.add( - new BasePlugin("AfterAgentCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe afterAgentCallback( - BaseAgent agent, CallbackContext callbackContext) { - return callback.call(callbackContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterAgentCallbackSync(Callbacks.AfterAgentCallbackSync callback) { - return addAfterAgentCallback( - callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); - } - - @CanIgnoreReturnValue - public Builder addBeforeModelCallback(Callbacks.BeforeModelCallback callback) { - callbacks.put(Callbacks.BeforeModelCallback.class, callback); - plugins.add( - new BasePlugin("BeforeModelCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return callback.call(callbackContext, llmRequest); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeModelCallbackSync(Callbacks.BeforeModelCallbackSync callback) { - return addBeforeModelCallback( - (callbackContext, llmRequest) -> - Maybe.fromOptional(callback.call(callbackContext, llmRequest))); - } - - @CanIgnoreReturnValue - public Builder addAfterModelCallback(Callbacks.AfterModelCallback callback) { - callbacks.put(Callbacks.AfterModelCallback.class, callback); - plugins.add( - new BasePlugin("AfterModelCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe afterModelCallback( - CallbackContext callbackContext, LlmResponse llmResponse) { - return callback.call(callbackContext, llmResponse); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterModelCallbackSync(Callbacks.AfterModelCallbackSync callback) { - return addAfterModelCallback( - (callbackContext, llmResponse) -> - Maybe.fromOptional(callback.call(callbackContext, llmResponse))); - } - - @CanIgnoreReturnValue - public Builder addBeforeToolCallback(Callbacks.BeforeToolCallback callback) { - callbacks.put(Callbacks.BeforeToolCallback.class, callback); - plugins.add( - new BasePlugin("BeforeToolCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - return callback.call(toolContext.invocationContext(), tool, toolArgs, toolContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeToolCallbackSync(Callbacks.BeforeToolCallbackSync callback) { - return addBeforeToolCallback( - (invocationContext, tool, toolArgs, toolContext) -> - Maybe.fromOptional(callback.call(invocationContext, tool, toolArgs, toolContext))); - } - - @CanIgnoreReturnValue - public Builder addAfterToolCallback(Callbacks.AfterToolCallback callback) { - callbacks.put(Callbacks.AfterToolCallback.class, callback); - plugins.add( - new BasePlugin("AfterToolCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe> afterToolCallback( - BaseTool tool, - Map toolArgs, - ToolContext toolContext, - Map result) { - return callback.call( - toolContext.invocationContext(), tool, toolArgs, toolContext, result); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterToolCallbackSync(Callbacks.AfterToolCallbackSync callback) { - return addAfterToolCallback( - (invocationContext, tool, toolArgs, toolContext, result) -> - Maybe.fromOptional( - callback.call(invocationContext, tool, toolArgs, toolContext, result))); - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeAgentCallbackBase callback) { - if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { - addBeforeAgentCallback(beforeAgentCallbackInstance); - } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { - addBeforeAgentCallbackSync(beforeAgentCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeAgentCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterAgentCallbackBase callback) { - if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { - addAfterAgentCallback(afterAgentCallbackInstance); - } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { - addAfterAgentCallbackSync(afterAgentCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterAgentCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeModelCallbackBase callback) { - if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { - addBeforeModelCallback(beforeModelCallbackInstance); - } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { - addBeforeModelCallbackSync(beforeModelCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterModelCallbackBase callback) { - if (callback instanceof AfterModelCallback afterModelCallbackInstance) { - addAfterModelCallback(afterModelCallbackInstance); - } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { - addAfterModelCallbackSync(afterModelCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeToolCallbackBase callback) { - if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { - addBeforeToolCallback(beforeToolCallbackInstance); - } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { - addBeforeToolCallbackSync(beforeToolCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterToolCallbackBase callback) { - if (callback instanceof AfterToolCallback afterToolCallbackInstance) { - addAfterToolCallback(afterToolCallbackInstance); - } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { - addAfterToolCallbackSync(afterToolCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - public CallbackPlugin build() { - return new CallbackPlugin(plugins.build(), ImmutableListMultimap.copyOf(callbacks)); - } - } -} diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 3913b7468..a273012fc 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -56,6 +56,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; @@ -94,6 +95,10 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; + private final Optional> beforeModelCallback; + private final Optional> afterModelCallback; + private final Optional> beforeToolCallback; + private final Optional> afterToolCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -108,7 +113,8 @@ protected LlmAgent(Builder builder) { builder.name, builder.description, builder.subAgents, - builder.callbackPluginBuilder.build()); + builder.beforeAgentCallback, + builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); this.instruction = builder.instruction == null ? new Instruction.Static("") : builder.instruction; @@ -122,6 +128,10 @@ protected LlmAgent(Builder builder) { this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; + this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); + this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); + this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); + this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); @@ -163,6 +173,10 @@ public static class Builder extends BaseAgent.Builder { private Integer maxSteps; private Boolean disallowTransferToParent; private Boolean disallowTransferToPeers; + private ImmutableList beforeModelCallback; + private ImmutableList afterModelCallback; + private ImmutableList beforeToolCallback; + private ImmutableList afterToolCallback; private Schema inputSchema; private Schema outputSchema; private Executor executor; @@ -276,86 +290,200 @@ public Builder disallowTransferToPeers(boolean disallowTransferToPeers) { @CanIgnoreReturnValue public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) { - callbackPluginBuilder.addBeforeModelCallback(beforeModelCallback); + this.beforeModelCallback = ImmutableList.of(beforeModelCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallback(List beforeModelCallback) { - beforeModelCallback.forEach(callbackPluginBuilder::addCallback); + if (beforeModelCallback == null) { + this.beforeModelCallback = null; + } else if (beforeModelCallback.isEmpty()) { + this.beforeModelCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeModelCallbackBase callback : beforeModelCallback) { + if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { + builder.add(beforeModelCallbackInstance); + } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { + builder.add( + (BeforeModelCallback) + (callbackContext, llmRequestBuilder) -> + Maybe.fromOptional( + beforeModelCallbackSyncInstance.call( + callbackContext, llmRequestBuilder))); + } else { + logger.warn( + "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.beforeModelCallback = builder.build(); + } + return this; } @CanIgnoreReturnValue public Builder beforeModelCallbackSync(BeforeModelCallbackSync beforeModelCallbackSync) { - callbackPluginBuilder.addBeforeModelCallbackSync(beforeModelCallbackSync); + this.beforeModelCallback = + ImmutableList.of( + (callbackContext, llmRequestBuilder) -> + Maybe.fromOptional( + beforeModelCallbackSync.call(callbackContext, llmRequestBuilder))); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(AfterModelCallback afterModelCallback) { - callbackPluginBuilder.addAfterModelCallback(afterModelCallback); + this.afterModelCallback = ImmutableList.of(afterModelCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(List afterModelCallback) { - afterModelCallback.forEach(callbackPluginBuilder::addCallback); + if (afterModelCallback == null) { + this.afterModelCallback = null; + } else if (afterModelCallback.isEmpty()) { + this.afterModelCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (AfterModelCallbackBase callback : afterModelCallback) { + if (callback instanceof AfterModelCallback afterModelCallbackInstance) { + builder.add(afterModelCallbackInstance); + } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { + builder.add( + (AfterModelCallback) + (callbackContext, llmResponse) -> + Maybe.fromOptional( + afterModelCallbackSyncInstance.call(callbackContext, llmResponse))); + } else { + logger.warn( + "Invalid afterModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.afterModelCallback = builder.build(); + } + return this; } @CanIgnoreReturnValue public Builder afterModelCallbackSync(AfterModelCallbackSync afterModelCallbackSync) { - callbackPluginBuilder.addAfterModelCallbackSync(afterModelCallbackSync); + this.afterModelCallback = + ImmutableList.of( + (callbackContext, llmResponse) -> + Maybe.fromOptional(afterModelCallbackSync.call(callbackContext, llmResponse))); return this; } @CanIgnoreReturnValue public Builder beforeAgentCallbackSync(BeforeAgentCallbackSync beforeAgentCallbackSync) { - callbackPluginBuilder.addBeforeAgentCallbackSync(beforeAgentCallbackSync); + this.beforeAgentCallback = + ImmutableList.of( + (callbackContext) -> + Maybe.fromOptional(beforeAgentCallbackSync.call(callbackContext))); return this; } @CanIgnoreReturnValue public Builder afterAgentCallbackSync(AfterAgentCallbackSync afterAgentCallbackSync) { - callbackPluginBuilder.addAfterAgentCallbackSync(afterAgentCallbackSync); + this.afterAgentCallback = + ImmutableList.of( + (callbackContext) -> + Maybe.fromOptional(afterAgentCallbackSync.call(callbackContext))); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback(BeforeToolCallback beforeToolCallback) { - callbackPluginBuilder.addBeforeToolCallback(beforeToolCallback); + this.beforeToolCallback = ImmutableList.of(beforeToolCallback); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback( @Nullable List beforeToolCallbacks) { - beforeToolCallbacks.forEach(callbackPluginBuilder::addCallback); + if (beforeToolCallbacks == null) { + this.beforeToolCallback = null; + } else if (beforeToolCallbacks.isEmpty()) { + this.beforeToolCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeToolCallbackBase callback : beforeToolCallbacks) { + if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { + builder.add(beforeToolCallbackInstance); + } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { + builder.add( + (invocationContext, baseTool, input, toolContext) -> + Maybe.fromOptional( + beforeToolCallbackSyncInstance.call( + invocationContext, baseTool, input, toolContext))); + } else { + logger.warn( + "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.beforeToolCallback = builder.build(); + } return this; } @CanIgnoreReturnValue public Builder beforeToolCallbackSync(BeforeToolCallbackSync beforeToolCallbackSync) { - callbackPluginBuilder.addBeforeToolCallbackSync(beforeToolCallbackSync); + this.beforeToolCallback = + ImmutableList.of( + (invocationContext, baseTool, input, toolContext) -> + Maybe.fromOptional( + beforeToolCallbackSync.call( + invocationContext, baseTool, input, toolContext))); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(AfterToolCallback afterToolCallback) { - callbackPluginBuilder.addAfterToolCallback(afterToolCallback); + this.afterToolCallback = ImmutableList.of(afterToolCallback); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(@Nullable List afterToolCallbacks) { - afterToolCallbacks.forEach(callbackPluginBuilder::addCallback); + if (afterToolCallbacks == null) { + this.afterToolCallback = null; + } else if (afterToolCallbacks.isEmpty()) { + this.afterToolCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (AfterToolCallbackBase callback : afterToolCallbacks) { + if (callback instanceof AfterToolCallback afterToolCallbackInstance) { + builder.add(afterToolCallbackInstance); + } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { + builder.add( + (invocationContext, baseTool, input, toolContext, response) -> + Maybe.fromOptional( + afterToolCallbackSyncInstance.call( + invocationContext, baseTool, input, toolContext, response))); + } else { + logger.warn( + "Invalid afterToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.afterToolCallback = builder.build(); + } return this; } @CanIgnoreReturnValue public Builder afterToolCallbackSync(AfterToolCallbackSync afterToolCallbackSync) { - callbackPluginBuilder.addAfterToolCallbackSync(afterToolCallbackSync); + this.afterToolCallback = + ImmutableList.of( + (invocationContext, baseTool, input, toolContext, response) -> + Maybe.fromOptional( + afterToolCallbackSync.call( + invocationContext, baseTool, input, toolContext, response))); return this; } @@ -629,19 +757,19 @@ public boolean disallowTransferToPeers() { } public Optional> beforeModelCallback() { - return Optional.of(callbackPlugin.getBeforeModelCallback()); + return beforeModelCallback; } public Optional> afterModelCallback() { - return Optional.of(callbackPlugin.getAfterModelCallback()); + return afterModelCallback; } public Optional> beforeToolCallback() { - return Optional.of(callbackPlugin.getBeforeToolCallback()); + return beforeToolCallback; } public Optional> afterToolCallback() { - return Optional.of(callbackPlugin.getAfterToolCallback()); + return afterToolCallback; } public Optional inputSchema() { @@ -702,8 +830,8 @@ private Model resolveModelInternal() { } BaseAgent current = this.parentAgent(); while (current != null) { - if (current instanceof LlmAgent llmAgent) { - return llmAgent.resolvedModel(); + if (current instanceof LlmAgent) { + return ((LlmAgent) current).resolvedModel(); } current = current.parentAgent(); } diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index 921ef3689..d9d049f80 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -46,13 +46,16 @@ public class LoopAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private LoopAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); - this.maxIterations = builder.maxIterations; + private LoopAgent( + String name, + String description, + List subAgents, + Optional maxIterations, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.maxIterations = maxIterations; } /** Builder for {@link LoopAgent}. */ @@ -73,7 +76,9 @@ public Builder maxIterations(Optional maxIterations) { @Override public LoopAgent build() { - return new LoopAgent(this); + // TODO(b/410859954): Add validation for required fields like name. + return new LoopAgent( + name, description, subAgents, maxIterations, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index 583bfffcb..f30d951aa 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -45,12 +45,14 @@ public class ParallelAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private ParallelAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); + private ParallelAgent( + String name, + String description, + List subAgents, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); } /** Builder for {@link ParallelAgent}. */ @@ -58,7 +60,8 @@ public static class Builder extends BaseAgent.Builder { @Override public ParallelAgent build() { - return new ParallelAgent(this); + return new ParallelAgent( + name, description, subAgents, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java index dc7480f58..7d3a5acb9 100644 --- a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java +++ b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java @@ -34,11 +34,6 @@ public ReadonlyContext(InvocationContext invocationContext) { this.invocationContext = invocationContext; } - /** Returns the invocation context. */ - public InvocationContext invocationContext() { - return invocationContext; - } - /** Returns the user content that initiated this invocation. */ public Optional userContent() { return invocationContext.userContent(); diff --git a/core/src/main/java/com/google/adk/agents/SequentialAgent.java b/core/src/main/java/com/google/adk/agents/SequentialAgent.java index aa4b76fb6..b0b45a0ec 100644 --- a/core/src/main/java/com/google/adk/agents/SequentialAgent.java +++ b/core/src/main/java/com/google/adk/agents/SequentialAgent.java @@ -18,6 +18,7 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,12 +36,14 @@ public class SequentialAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private SequentialAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); + private SequentialAgent( + String name, + String description, + List subAgents, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); } /** Builder for {@link SequentialAgent}. */ @@ -48,7 +51,9 @@ public static class Builder extends BaseAgent.Builder { @Override public SequentialAgent build() { - return new SequentialAgent(this); + // TODO(b/410859954): Add validation for required fields like name. + return new SequentialAgent( + name, description, subAgents, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 8258d32d0..6e06a34ab 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -18,17 +18,16 @@ import static com.google.common.truth.Truth.assertThat; -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.testing.TestBaseAgent; -import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,213 +51,37 @@ public void constructor_setsNameAndDescription() { @Test public void runAsync_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunAsyncImplAndAfterCallback() { - var runAsyncImpl = TestCallback.returningEmpty(); + AtomicBoolean runAsyncImplCalled = new AtomicBoolean(false); + AtomicBoolean afterAgentCallbackCalled = new AtomicBoolean(false); Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); - var beforeCallback = TestCallback.returning(callbackContent); - var afterCallback = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier("main_output")); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(1); - assertThat(results.get(0).content()).hasValue(callbackContent); - assertThat(runAsyncImpl.wasCalled()).isFalse(); - assertThat(beforeCallback.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isFalse(); - } - - @Test - public void runAsync_firstBeforeCallbackReturnsContent_skipsSecondBeforeCallback() { - Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); - var beforeCallback1 = TestCallback.returning(callbackContent); - var beforeCallback2 = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of( - beforeCallback1.asBeforeAgentCallback(), beforeCallback2.asBeforeAgentCallback()), - ImmutableList.of(), - TestCallback.returningEmpty().asRunAsyncImplSupplier("main_output")); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - var unused = agent.runAsync(invocationContext).toList().blockingGet(); - assertThat(beforeCallback1.wasCalled()).isTrue(); - assertThat(beforeCallback2.wasCalled()).isFalse(); - } - - @Test - public void runAsync_noCallbacks_invokesRunAsyncImpl() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - /* beforeAgentCallbacks= */ ImmutableList.of(), - /* afterAgentCallbacks= */ ImmutableList.of(), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(1); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunAsyncImplAndAfterCallbacks() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - var beforeCallback = TestCallback.returningEmpty(); - var afterCallback = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(1); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(beforeCallback.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_afterCallbackReturnsContent_invokesRunAsyncImplAndAfterCallbacksAndReturnsAllContent() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); - var beforeCallback = TestCallback.returningEmpty(); - var afterCallback = TestCallback.returning(afterCallbackContent); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(2); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(results.get(1).content()).hasValue(afterCallbackContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(beforeCallback.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_beforeCallbackMutatesStateAndReturnsEmpty_invokesRunAsyncImplAndReturnsStateEvent() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - BeforeAgentCallback beforeCallback = - new BeforeAgentCallback() { - @Override - public Maybe call(CallbackContext context) { - context.state().put("key", "value"); - return Maybe.empty(); - } + Callbacks.BeforeAgentCallback beforeCallback = (callbackContext) -> Maybe.just(callbackContent); + Callbacks.AfterAgentCallback afterCallback = + (callbackContext) -> { + afterAgentCallbackCalled.set(true); + return Maybe.empty(); }; - var afterCallback = TestCallback.returningEmpty(); TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, ImmutableList.of(beforeCallback), - ImmutableList.of(afterCallback.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(2); - // State event from before callback - assertThat(results.get(0).content()).isEmpty(); - assertThat(results.get(0).actions().stateDelta()).containsEntry("key", "value"); - // Content event from runAsyncImpl - assertThat(results.get(1).content()).hasValue(runAsyncImplContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(afterCallback.wasCalled()).isTrue(); - } - - @Test - public void - runAsync_afterCallbackMutatesStateAndReturnsEmpty_invokesRunAsyncImplAndReturnsStateEvent() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - var beforeCallback = TestCallback.returningEmpty(); - AfterAgentCallback afterCallback = - new AfterAgentCallback() { - @Override - public Maybe call(CallbackContext context) { - context.state().put("key", "value"); - return Maybe.empty(); - } - }; - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback.asBeforeAgentCallback()), ImmutableList.of(afterCallback), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + () -> + Flowable.defer( + () -> { + runAsyncImplCalled.set(true); + return Flowable.just( + Event.builder() + .content(Content.fromParts(Part.fromText("main_output"))) + .build()); + })); InvocationContext invocationContext = TestUtils.createInvocationContext(agent); List results = agent.runAsync(invocationContext).toList().blockingGet(); - assertThat(results).hasSize(2); - // Content event from runAsyncImpl - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - // State event from after callback - assertThat(results.get(1).content()).isEmpty(); - assertThat(results.get(1).actions().stateDelta()).containsEntry("key", "value"); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(beforeCallback.wasCalled()).isTrue(); - } - - @Test - public void runAsync_firstAfterCallbackReturnsContent_skipsSecondAfterCallback() { - var runAsyncImpl = TestCallback.returningEmpty(); - Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); - Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); - var afterCallback1 = TestCallback.returning(afterCallbackContent); - var afterCallback2 = TestCallback.returningEmpty(); - TestBaseAgent agent = - new TestBaseAgent( - TEST_AGENT_NAME, - TEST_AGENT_DESCRIPTION, - ImmutableList.of(), - ImmutableList.of( - afterCallback1.asAfterAgentCallback(), afterCallback2.asAfterAgentCallback()), - runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); - InvocationContext invocationContext = TestUtils.createInvocationContext(agent); - - List results = agent.runAsync(invocationContext).toList().blockingGet(); - - assertThat(results).hasSize(2); - assertThat(results.get(0).content()).hasValue(runAsyncImplContent); - assertThat(results.get(1).content()).hasValue(afterCallbackContent); - assertThat(runAsyncImpl.wasCalled()).isTrue(); - assertThat(afterCallback1.wasCalled()).isTrue(); - assertThat(afterCallback2.wasCalled()).isFalse(); + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(callbackContent); + assertThat(runAsyncImplCalled.get()).isFalse(); + assertThat(afterAgentCallbackCalled.get()).isFalse(); } } diff --git a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java deleted file mode 100644 index 361c86193..000000000 --- a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java +++ /dev/null @@ -1,499 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.agents; - -import static com.google.adk.testing.TestUtils.createInvocationContext; -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.events.EventActions; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.testing.TestCallback; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableMap; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -@RunWith(JUnit4.class) -public final class CallbackPluginTest { - - @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - @Mock private BaseAgent agent; - @Mock private BaseTool tool; - @Mock private ToolContext toolContext; - private InvocationContext invocationContext; - private CallbackContext callbackContext; - - @Before - public void setUp() { - invocationContext = createInvocationContext(agent); - callbackContext = - new CallbackContext( - invocationContext, - EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()); - } - - @Test - public void build_empty_successful() { - CallbackPlugin plugin = CallbackPlugin.builder().build(); - assertThat(plugin.getName()).isEqualTo("CallbackPlugin"); - assertThat(plugin.getBeforeAgentCallback()).isEmpty(); - assertThat(plugin.getAfterAgentCallback()).isEmpty(); - assertThat(plugin.getBeforeModelCallback()).isEmpty(); - assertThat(plugin.getAfterModelCallback()).isEmpty(); - assertThat(plugin.getBeforeToolCallback()).isEmpty(); - assertThat(plugin.getAfterToolCallback()).isEmpty(); - } - - @Test - public void addBeforeAgentCallback_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - BeforeAgentCallback callback = testCallback.asBeforeAgentCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeAgentCallback(callback).build(); - - assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); - - Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addBeforeAgentCallbackSync_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeAgentCallbackSync(testCallback.asBeforeAgentCallbackSync()) - .build(); - - assertThat(plugin.getBeforeAgentCallback()).hasSize(1); - - Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addAfterAgentCallback_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - AfterAgentCallback callback = testCallback.asAfterAgentCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterAgentCallback(callback).build(); - - assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); - - Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addAfterAgentCallbackSync_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterAgentCallbackSync(testCallback.asAfterAgentCallbackSync()) - .build(); - - assertThat(plugin.getAfterAgentCallback()).hasSize(1); - - Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addBeforeModelCallback_isReturnedAndInvoked() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback = TestCallback.returning(expectedResponse); - BeforeModelCallback callback = testCallback.asBeforeModelCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeModelCallback(callback).build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addBeforeModelCallbackSync_isReturnedAndInvoked() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback = TestCallback.returning(expectedResponse); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallbackSync(testCallback.asBeforeModelCallbackSync()) - .build(); - - assertThat(plugin.getBeforeModelCallback()).hasSize(1); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addAfterModelCallback_isReturnedAndInvoked() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); - var testCallback = TestCallback.returning(expectedResponse); - AfterModelCallback callback = testCallback.asAfterModelCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallback(callback).build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addAfterModelCallbackSync_isReturnedAndInvoked() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); - var testCallback = TestCallback.returning(expectedResponse); - AfterModelCallbackSync callback = testCallback.asAfterModelCallbackSync(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallbackSync(callback).build(); - - assertThat(plugin.getAfterModelCallback()).hasSize(1); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addBeforeToolCallback_isReturnedAndInvoked() { - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - BeforeToolCallback callback = testCallback.asBeforeToolCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeToolCallback(callback).build(); - - assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); - - Map result = - plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addBeforeToolCallbackSync_isReturnedAndInvoked() { - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeToolCallbackSync(testCallback.asBeforeToolCallbackSync()) - .build(); - - assertThat(plugin.getBeforeToolCallback()).hasSize(1); - - Map result = - plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addAfterToolCallback_isReturnedAndInvoked() { - ImmutableMap initialResult = ImmutableMap.of("initial", "result"); - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - AfterToolCallback callback = testCallback.asAfterToolCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallback(callback).build(); - - assertThat(plugin.getAfterToolCallback()).containsExactly(callback); - - Map result = - plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addAfterToolCallbackSync_isReturnedAndInvoked() { - ImmutableMap initialResult = ImmutableMap.of("initial", "result"); - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - AfterToolCallbackSync callback = testCallback.asAfterToolCallbackSync(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallbackSync(callback).build(); - - assertThat(plugin.getAfterToolCallback()).hasSize(1); - - Map result = - plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addCallback_beforeAgentCallback() { - BeforeAgentCallback callback = ctx -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeAgentCallbackSync() { - BeforeAgentCallbackSync callback = ctx -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeAgentCallback()).hasSize(1); - } - - @Test - public void addCallback_afterAgentCallback() { - AfterAgentCallback callback = ctx -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterAgentCallbackSync() { - AfterAgentCallbackSync callback = ctx -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterAgentCallback()).hasSize(1); - } - - @Test - public void addCallback_beforeModelCallback() { - BeforeModelCallback callback = (ctx, req) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeModelCallbackSync() { - BeforeModelCallbackSync callback = (ctx, req) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeModelCallback()).hasSize(1); - } - - @Test - public void addCallback_afterModelCallback() { - AfterModelCallback callback = (ctx, res) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterModelCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterModelCallbackSync() { - AfterModelCallbackSync callback = (ctx, res) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterModelCallback()).hasSize(1); - } - - @Test - public void addCallback_beforeToolCallback() { - BeforeToolCallback callback = (invCtx, tool, toolArgs, toolCtx) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeToolCallbackSync() { - BeforeToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeToolCallback()).hasSize(1); - } - - @Test - public void addCallback_afterToolCallback() { - AfterToolCallback callback = (invCtx, tool, toolArgs, toolCtx, res) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterToolCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterToolCallbackSync() { - AfterToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx, res) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterToolCallback()).hasSize(1); - } - - @Test - public void addMultipleBeforeModelCallbacks_invokedInOrder() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returning(expectedResponse); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleBeforeModelCallbacks_shortCircuit() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returning(expectedResponse); - var testCallback2 = TestCallback.returningEmpty(); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isFalse(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleAfterModelCallbacks_shortCircuit() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("response"))).build(); - var testCallback1 = TestCallback.returning(expectedResponse); - var testCallback2 = TestCallback.returningEmpty(); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isFalse(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleAfterModelCallbacks_invokedInOrder() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("second"))).build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returning(expectedResponse); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleBeforeModelCallbacks_bothEmpty_returnsEmpty() { - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returningEmpty(); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isNull(); - } - - @Test - public void addMultipleAfterModelCallbacks_bothEmpty_returnsEmpty() { - LlmResponse initialResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returningEmpty(); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isNull(); - } -} diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java deleted file mode 100644 index 04f83ed9b..000000000 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.testing; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.events.Event; -import com.google.adk.models.LlmResponse; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; - -/** - * A test helper that wraps an {@link AtomicBoolean} and provides factory methods for creating - * callbacks that update the boolean when called. - * - * @param The type of the result returned by the callback. - */ -public final class TestCallback { - private final AtomicBoolean called = new AtomicBoolean(false); - private final Optional result; - - private TestCallback(Optional result) { - this.result = result; - } - - /** Creates a {@link TestCallback} that returns the given result. */ - public static TestCallback returning(T result) { - return new TestCallback<>(Optional.of(result)); - } - - /** Creates a {@link TestCallback} that returns an empty result. */ - public static TestCallback returningEmpty() { - return new TestCallback<>(Optional.empty()); - } - - /** Returns true if the callback was called. */ - public boolean wasCalled() { - return called.get(); - } - - /** Marks the callback as called. */ - public void markAsCalled() { - called.set(true); - } - - private Maybe callMaybe() { - called.set(true); - return result.map(Maybe::just).orElseGet(Maybe::empty); - } - - private Optional callOptional() { - called.set(true); - return result; - } - - /** - * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} - * with an event containing the given content. - */ - public Supplier> asRunAsyncImplSupplier(Content content) { - return () -> - Flowable.defer( - () -> { - markAsCalled(); - return Flowable.just(Event.builder().content(content).build()); - }); - } - - /** - * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} - */ - public Supplier> asRunAsyncImplSupplier(String contentText) { - return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText))); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public BeforeAgentCallback asBeforeAgentCallback() { - return ctx -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public BeforeAgentCallbackSync asBeforeAgentCallbackSync() { - return ctx -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public AfterAgentCallback asAfterAgentCallback() { - return ctx -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Content. - public AfterAgentCallbackSync asAfterAgentCallbackSync() { - return ctx -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public BeforeModelCallback asBeforeModelCallback() { - return (ctx, req) -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public BeforeModelCallbackSync asBeforeModelCallbackSync() { - return (ctx, req) -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public AfterModelCallback asAfterModelCallback() { - return (ctx, res) -> (Maybe) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. - public AfterModelCallbackSync asAfterModelCallbackSync() { - return (ctx, res) -> (Optional) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public BeforeToolCallback asBeforeToolCallback() { - return (invCtx, tool, toolArgs, toolCtx) -> (Maybe>) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public BeforeToolCallbackSync asBeforeToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx) -> (Optional>) callOptional(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public AfterToolCallback asAfterToolCallback() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe>) callMaybe(); - } - - @SuppressWarnings("unchecked") // This cast is safe if T is Map. - public AfterToolCallbackSync asAfterToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional>) callOptional(); - } -}