ainvoke in tool nodes

This commit is contained in:
2025-12-29 22:39:08 +08:00
parent fdd7dae796
commit 18bb795dd3

View File

@@ -3,6 +3,7 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro import tyro
import os.path as osp import os.path as osp
import time import time
import asyncio
from loguru import logger from loguru import logger
from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.config import InstantiateConfig, KeyConfig
@@ -55,6 +56,18 @@ class ToolNode(ToolNodeBase):
out = self.tool_agent.invoke(*inp) out = self.tool_agent.invoke(*inp)
return {"messages": out["messages"]} return {"messages": out["messages"]}
async def ainvoke(self, state:State):
"""Async version of invoke using LangGraph's native async support."""
inp = {"messages":[
SystemMessage(
self.sys_prompt
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = await self.tool_agent.ainvoke(*inp)
return {"messages": out["messages"]}
def get_streamable_tags(self): def get_streamable_tags(self):
return super().get_streamable_tags() return super().get_streamable_tags()
@@ -135,6 +148,18 @@ class ChattyToolNode(ToolNodeBase):
state_msgs = [] if state.get("messages") is None else state.get("messages") state_msgs = [] if state.get("messages") is None else state.get("messages")
return {"messages": state_msgs + chat_msgs + tool_msgs} return {"messages": state_msgs + chat_msgs + tool_msgs}
async def ainvoke(self, state:State):
"""Async version of invoke using LangGraph's native async support."""
self.tool_done = False
inp = {"inp": state["inp"]}
out = await self.workflow.ainvoke(inp)
chat_msgs = out.get("chatty_messages")["messages"]
tool_msgs = out.get("tool_messages")["messages"]
state_msgs = [] if state.get("messages") is None else state.get("messages")
return {"messages": state_msgs + chat_msgs + tool_msgs}
def _tool_node_call(self, state:ChattyToolState): def _tool_node_call(self, state:ChattyToolState):
# inp = {"messages":[ # inp = {"messages":[
# SystemMessage( # SystemMessage(