ainvoke in tool nodes
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
|||||||
import tyro
|
import tyro
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||||
@@ -55,6 +56,18 @@ class ToolNode(ToolNodeBase):
|
|||||||
out = self.tool_agent.invoke(*inp)
|
out = self.tool_agent.invoke(*inp)
|
||||||
return {"messages": out["messages"]}
|
return {"messages": out["messages"]}
|
||||||
|
|
||||||
|
async def ainvoke(self, state:State):
|
||||||
|
"""Async version of invoke using LangGraph's native async support."""
|
||||||
|
inp = {"messages":[
|
||||||
|
SystemMessage(
|
||||||
|
self.sys_prompt
|
||||||
|
),
|
||||||
|
*state["inp"][0]["messages"][1:]
|
||||||
|
]}, state["inp"][1]
|
||||||
|
|
||||||
|
out = await self.tool_agent.ainvoke(*inp)
|
||||||
|
return {"messages": out["messages"]}
|
||||||
|
|
||||||
def get_streamable_tags(self):
|
def get_streamable_tags(self):
|
||||||
return super().get_streamable_tags()
|
return super().get_streamable_tags()
|
||||||
|
|
||||||
@@ -135,6 +148,18 @@ class ChattyToolNode(ToolNodeBase):
|
|||||||
state_msgs = [] if state.get("messages") is None else state.get("messages")
|
state_msgs = [] if state.get("messages") is None else state.get("messages")
|
||||||
return {"messages": state_msgs + chat_msgs + tool_msgs}
|
return {"messages": state_msgs + chat_msgs + tool_msgs}
|
||||||
|
|
||||||
|
async def ainvoke(self, state:State):
|
||||||
|
"""Async version of invoke using LangGraph's native async support."""
|
||||||
|
self.tool_done = False
|
||||||
|
|
||||||
|
inp = {"inp": state["inp"]}
|
||||||
|
out = await self.workflow.ainvoke(inp)
|
||||||
|
chat_msgs = out.get("chatty_messages")["messages"]
|
||||||
|
tool_msgs = out.get("tool_messages")["messages"]
|
||||||
|
|
||||||
|
state_msgs = [] if state.get("messages") is None else state.get("messages")
|
||||||
|
return {"messages": state_msgs + chat_msgs + tool_msgs}
|
||||||
|
|
||||||
def _tool_node_call(self, state:ChattyToolState):
|
def _tool_node_call(self, state:ChattyToolState):
|
||||||
# inp = {"messages":[
|
# inp = {"messages":[
|
||||||
# SystemMessage(
|
# SystemMessage(
|
||||||
|
|||||||
Reference in New Issue
Block a user