diff --git a/src/art/langgraph/llm_wrapper.py b/src/art/langgraph/llm_wrapper.py
index 21d1889f..b5cee898 100644
--- a/src/art/langgraph/llm_wrapper.py
+++ b/src/art/langgraph/llm_wrapper.py
@@ -5,7 +5,7 @@
 import json
 import os
 import uuid
-from typing import Any, Literal
+from typing import Any, Callable, Literal
 
 from langchain_core.messages import HumanMessage, ToolMessage
 from langchain_core.prompt_values import ChatPromptValue
@@ -113,26 +113,54 @@ def init_chat_model(
     model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
     config_prefix: str | None = None,
+    chat_model=None,
+    timeout_seconds: float | None = None,
     **kwargs: Any,
 ):
     config = CURRENT_CONFIG.get()
+    if chat_model is not None:
+        llm = chat_model
+        llm_factory = None
+    else:
+        # Allow overrides passed to init_chat_model, otherwise fall back to the
+        # current context.
+        temperature = kwargs.pop("temperature", 1.0)
+        model_name = model or config["model"]
+        base_url = kwargs.pop("base_url", config["base_url"])
+        api_key = kwargs.pop("api_key", config["api_key"])
+
+        def llm_factory(ctx_config: dict[str, Any]):
+            return ChatOpenAI(
+                base_url=ctx_config.get("base_url", base_url),
+                api_key=ctx_config.get("api_key", api_key),
+                model=model_name,
+                temperature=temperature,
+                **kwargs,
+            )
+
+        llm = llm_factory(config)
+
     return LoggingLLM(
-        ChatOpenAI(
-            base_url=config["base_url"],
-            api_key=config["api_key"],
-            model=config["model"],
-            temperature=1.0,
-        ),
-        config["logger"],
+        llm, config["logger"], timeout_seconds=timeout_seconds, llm_factory=llm_factory
     )
 
 
 class LoggingLLM(Runnable):
-    def __init__(self, llm, logger, structured_output=None, tools=None):
+    def __init__(
+        self,
+        llm,
+        logger,
+        structured_output=None,
+        tools=None,
+        timeout_seconds: float | None = None,
+        llm_factory: Callable[[dict[str, Any]], Any] | None = None,
+    ):
         self.llm = llm
         self.logger = logger
         self.structured_output = structured_output
         self.tools = [convert_to_openai_tool(t) for t in tools] if tools else None
+        self.timeout_seconds = timeout_seconds or 10 * 60
+        self.llm_factory = llm_factory
 
     def _log(self, completion_id, input, output):
         if self.logger:
@@ -167,7 +195,7 @@ async def ainvoke(self, input, config=None, **kwargs):
         async def execute():
             try:
                 result = await asyncio.wait_for(
-                    self.llm.ainvoke(input, config=config), timeout=10 * 60
+                    self.llm.ainvoke(input, config=config), timeout=self.timeout_seconds
                 )
                 self._log(completion_id, input, result)
             except asyncio.TimeoutError as e:
@@ -194,10 +222,18 @@ def with_structured_output(self, tools):
             self.logger,
             structured_output=tools,
             tools=[tools],
+            timeout_seconds=self.timeout_seconds,
+            llm_factory=self.llm_factory,
         )
 
     def bind_tools(self, tools):
-        return LoggingLLM(self.llm.bind_tools(tools), self.logger, tools=tools)
+        return LoggingLLM(
+            self.llm.bind_tools(tools),
+            self.logger,
+            tools=tools,
+            timeout_seconds=self.timeout_seconds,
+            llm_factory=self.llm_factory,
+        )
 
     def with_retry(
         self,
@@ -217,23 +253,9 @@ def with_config(
         art_config = CURRENT_CONFIG.get()
         self.logger = art_config["logger"]
 
-        if hasattr(self.llm, "bound"):
-            setattr(
-                self.llm,
-                "bound",
-                ChatOpenAI(
-                    base_url=art_config["base_url"],
-                    api_key=art_config["api_key"],
-                    model=art_config["model"],
-                    temperature=1.0,
-                ),
-            )
-        else:
-            self.llm = ChatOpenAI(
-                base_url=art_config["base_url"],
-                api_key=art_config["api_key"],
-                model=art_config["model"],
-                temperature=1.0,
-            )
+        if self.llm_factory:
+            self.llm = self.llm_factory(art_config)
+        elif hasattr(self.llm, "with_config"):
+            self.llm = self.llm.with_config(config=config, **kwargs)
 
         return self
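Usage sketch for the new parameters (not part of the diff; it assumes the ART
rollout context has already populated CURRENT_CONFIG before init_chat_model
runs, and the model name and base_url below are placeholders):

    from langchain_openai import ChatOpenAI

    from art.langgraph.llm_wrapper import init_chat_model

    # Default path: the stored factory builds a ChatOpenAI from the ambient
    # config, honoring per-call overrides and a custom timeout.
    llm = init_chat_model("my-model", temperature=0.2, timeout_seconds=120)

    # Injection path: pass a pre-built chat model. No factory is stored in
    # this case, so with_config() will not rebuild the client from a fresh
    # context config.
    custom = ChatOpenAI(
        base_url="http://localhost:8000/v1",
        api_key="placeholder",
        model="my-model",
    )
    llm = init_chat_model(chat_model=custom)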
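Design note: storing llm_factory instead of hard-coding ChatOpenAI lets
with_config() rebuild the inner client whenever the context config changes,
while keeping the caller's original overrides; the diff also threads
timeout_seconds and llm_factory through bind_tools() and
with_structured_output() so re-wrapped instances keep both. A rough sketch of
the rebinding behavior, assuming CURRENT_CONFIG is a ContextVar and that
with_config() accepts no arguments as the Runnable base method does:

    # Suppose the trainer rotates the inference endpoint between rollouts:
    CURRENT_CONFIG.set({**config, "base_url": "http://other-host:8000/v1"})
    llm = llm.with_config()  # inner client is rebuilt against the new base_url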