ainvoke in tool nodes

This commit is contained in:
2025-12-29 22:39:08 +08:00
parent fdd7dae796
commit 18bb795dd3

View File

@@ -3,6 +3,7 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro
import os.path as osp
import time
import asyncio
from loguru import logger
from lang_agent.config import InstantiateConfig, KeyConfig
@@ -54,6 +55,18 @@ class ToolNode(ToolNodeBase):
out = self.tool_agent.invoke(*inp)
return {"messages": out["messages"]}
async def ainvoke(self, state:State):
"""Async version of invoke using LangGraph's native async support."""
inp = {"messages":[
SystemMessage(
self.sys_prompt
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = await self.tool_agent.ainvoke(*inp)
return {"messages": out["messages"]}
def get_streamable_tags(self):
return super().get_streamable_tags()
@@ -134,6 +147,18 @@ class ChattyToolNode(ToolNodeBase):
state_msgs = [] if state.get("messages") is None else state.get("messages")
return {"messages": state_msgs + chat_msgs + tool_msgs}
async def ainvoke(self, state:State):
"""Async version of invoke using LangGraph's native async support."""
self.tool_done = False
inp = {"inp": state["inp"]}
out = await self.workflow.ainvoke(inp)
chat_msgs = out.get("chatty_messages")["messages"]
tool_msgs = out.get("tool_messages")["messages"]
state_msgs = [] if state.get("messages") is None else state.get("messages")
return {"messages": state_msgs + chat_msgs + tool_msgs}
def _tool_node_call(self, state:ChattyToolState):
# inp = {"messages":[