From 18bb795dd3d50e746c2a5d1e03515b8cf4248961 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 29 Dec 2025 22:39:08 +0800 Subject: [PATCH] ainvoke in tool nodes --- lang_agent/graphs/tool_nodes.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index e503bc7..aabdc3f 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -3,6 +3,7 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple import tyro import os.path as osp import time +import asyncio from loguru import logger from lang_agent.config import InstantiateConfig, KeyConfig @@ -54,6 +55,18 @@ class ToolNode(ToolNodeBase): out = self.tool_agent.invoke(*inp) return {"messages": out["messages"]} + + async def ainvoke(self, state:State): + """Async version of invoke using LangGraph's native async support.""" + inp = {"messages":[ + SystemMessage( + self.sys_prompt + ), + *state["inp"][0]["messages"][1:] + ]}, state["inp"][1] + + out = await self.tool_agent.ainvoke(*inp) + return {"messages": out["messages"]} def get_streamable_tags(self): return super().get_streamable_tags() @@ -134,6 +147,18 @@ class ChattyToolNode(ToolNodeBase): state_msgs = [] if state.get("messages") is None else state.get("messages") return {"messages": state_msgs + chat_msgs + tool_msgs} + + async def ainvoke(self, state:State): + """Async version of invoke using LangGraph's native async support.""" + self.tool_done = False + + inp = {"inp": state["inp"]} + out = await self.workflow.ainvoke(inp) + chat_msgs = out.get("chatty_messages")["messages"] + tool_msgs = out.get("tool_messages")["messages"] + + state_msgs = [] if state.get("messages") is None else state.get("messages") + return {"messages": state_msgs + chat_msgs + tool_msgs} def _tool_node_call(self, state:ChattyToolState): # inp = {"messages":[