80 changes: 51 additions & 29 deletions src/art/langgraph/llm_wrapper.py
@@ -5,7 +5,7 @@
 import json
 import os
 import uuid
-from typing import Any, Literal
+from typing import Any, Callable, Literal

 from langchain_core.messages import HumanMessage, ToolMessage
 from langchain_core.prompt_values import ChatPromptValue
@@ -113,26 +113,54 @@ def init_chat_model(
     model_provider: str | None = None,
     configurable_fields: Literal[None] = None,
     config_prefix: str | None = None,
+    chat_model=None,
+    timeout_seconds: float | None = None,
     **kwargs: Any,
 ):
     config = CURRENT_CONFIG.get()
+    if chat_model is not None:
+        llm = chat_model
+        llm_factory = None
+    else:
+        # Allow overrides passed to init_chat_model, otherwise fall back to the
+        # current context.
+        temperature = kwargs.pop("temperature", 1.0)
+        model_name = model or config["model"]
+        base_url = kwargs.pop("base_url", config["base_url"])
+        api_key = kwargs.pop("api_key", config["api_key"])
+
+        def llm_factory(ctx_config: dict[str, Any]):
+            return ChatOpenAI(
+                base_url=ctx_config.get("base_url", base_url),
+                api_key=ctx_config.get("api_key", api_key),
+                model=model_name,
+                temperature=temperature,
+                **kwargs,
+            )
+
+        llm = llm_factory(config)
+
     return LoggingLLM(
-        ChatOpenAI(
-            base_url=config["base_url"],
-            api_key=config["api_key"],
-            model=config["model"],
-            temperature=1.0,
-        ),
-        config["logger"],
+        llm, config["logger"], timeout_seconds=timeout_seconds, llm_factory=llm_factory
     )


 class LoggingLLM(Runnable):
-    def __init__(self, llm, logger, structured_output=None, tools=None):
+    def __init__(
+        self,
+        llm,
+        logger,
+        structured_output=None,
+        tools=None,
+        timeout_seconds: float | None = None,
+        llm_factory: Callable[[dict[str, Any]], Any] | None = None,
+    ):
         self.llm = llm
         self.logger = logger
         self.structured_output = structured_output
         self.tools = [convert_to_openai_tool(t) for t in tools] if tools else None
+        self.timeout_seconds = timeout_seconds or 10 * 60
+        self.llm_factory = llm_factory

     def _log(self, completion_id, input, output):
         if self.logger:
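
A minimal usage sketch of the reworked entry point, assuming the module is importable as art.langgraph.llm_wrapper and that CURRENT_CONFIG has been populated by the surrounding ART setup (the import path, the model name, and my_chat_model are illustrative assumptions, not part of this diff):

    from art.langgraph.llm_wrapper import init_chat_model  # import path assumed

    # Defaults: builds a ChatOpenAI from CURRENT_CONFIG with temperature 1.0
    # and the 10-minute fallback timeout, as before.
    llm = init_chat_model("gpt-4o-mini")

    # Overrides now flow through kwargs instead of being hardcoded:
    llm = init_chat_model(
        "gpt-4o-mini",
        temperature=0.2,        # previously pinned to 1.0
        timeout_seconds=120.0,  # previously a fixed 10 * 60 in ainvoke
    )

    # Bring-your-own model: llm_factory stays None, so with_config() later
    # delegates to the model itself rather than rebuilding a ChatOpenAI.
    llm = init_chat_model(chat_model=my_chat_model)  # my_chat_model: hypothetical
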
@@ -167,7 +195,7 @@ async def ainvoke(self, input, config=None, **kwargs):
         async def execute():
             try:
                 result = await asyncio.wait_for(
-                    self.llm.ainvoke(input, config=config), timeout=10 * 60
+                    self.llm.ainvoke(input, config=config), timeout=self.timeout_seconds
                 )
                 self._log(completion_id, input, result)
             except asyncio.TimeoutError as e:
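
Sketch of the behavioral change here, under the same assumptions as above: the per-call cap is now whatever the wrapper was constructed with, falling back to the old 10-minute constant only when nothing was passed.

    llm = init_chat_model(timeout_seconds=30.0)
    # await llm.ainvoke(messages) now times out after ~30 s;
    # before this change the limit was always 10 * 60 seconds.
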
@@ -194,10 +222,18 @@ def with_structured_output(self, tools):
             self.logger,
             structured_output=tools,
             tools=[tools],
+            timeout_seconds=self.timeout_seconds,
+            llm_factory=self.llm_factory,
         )

     def bind_tools(self, tools):
-        return LoggingLLM(self.llm.bind_tools(tools), self.logger, tools=tools)
+        return LoggingLLM(
+            self.llm.bind_tools(tools),
+            self.logger,
+            tools=tools,
+            timeout_seconds=self.timeout_seconds,
+            llm_factory=self.llm_factory,
+        )

     def with_retry(
         self,
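
Derived wrappers now carry the new state forward instead of silently resetting it; a sketch (my_tool is a hypothetical tool definition):

    tool_llm = llm.bind_tools([my_tool])
    assert tool_llm.timeout_seconds == llm.timeout_seconds
    assert tool_llm.llm_factory is llm.llm_factory
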
@@ -217,23 +253,9 @@ def with_config(
         art_config = CURRENT_CONFIG.get()
         self.logger = art_config["logger"]

-        if hasattr(self.llm, "bound"):
-            setattr(
-                self.llm,
-                "bound",
-                ChatOpenAI(
-                    base_url=art_config["base_url"],
-                    api_key=art_config["api_key"],
-                    model=art_config["model"],
-                    temperature=1.0,
-                ),
-            )
-        else:
-            self.llm = ChatOpenAI(
-                base_url=art_config["base_url"],
-                api_key=art_config["api_key"],
-                model=art_config["model"],
-                temperature=1.0,
-            )
+        if self.llm_factory:
+            self.llm = self.llm_factory(art_config)
+        elif hasattr(self.llm, "with_config"):
+            self.llm = self.llm.with_config(config=config, **kwargs)

         return self
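
A sketch of the new with_config behavior, under the same setup assumptions: when a factory was captured in init_chat_model, the wrapped client is rebuilt through it, so caller overrides such as temperature and extra kwargs survive a config swap; the hasattr branch covers user-supplied chat_model instances, which expose their own with_config instead.

    llm.with_config(configurable={"thread_id": "t-1"})  # config keys hypothetical
    # llm_factory set  -> self.llm is rebuilt from the fresh CURRENT_CONFIG values
    # llm_factory None -> delegates to the wrapped model's own with_config, if any
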