support async
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
|
||||
import tyro
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
@@ -14,7 +14,7 @@ from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||
from lang_agent.base import GraphBase, ToolNodeBase
|
||||
from lang_agent.graphs.graph_states import State
|
||||
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
|
||||
from lang_agent.components.text_releaser import TextReleaser
|
||||
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||
@@ -110,7 +110,55 @@ class RoutingGraph(GraphBase):
|
||||
if as_raw:
|
||||
return msg_list
|
||||
|
||||
return msg_list[-1].content
|
||||
return msg_list[-1].content
|
||||
|
||||
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||
"""Async version of invoke using LangGraph's native async support."""
|
||||
self._validate_input(*nargs, **kwargs)
|
||||
|
||||
if as_stream:
|
||||
# Stream messages from the workflow asynchronously
|
||||
print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
|
||||
return self._astream_result(*nargs, **kwargs)
|
||||
else:
|
||||
state = await self.workflow.ainvoke({"inp": nargs})
|
||||
|
||||
msg_list = jax.tree.leaves(state)
|
||||
|
||||
for e in msg_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
e.pretty_print()
|
||||
|
||||
if as_raw:
|
||||
return msg_list
|
||||
|
||||
return msg_list[-1].content
|
||||
|
||||
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
|
||||
"""Async streaming using LangGraph's astream method."""
|
||||
streamable_tags = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
|
||||
|
||||
async def text_iterator():
|
||||
async for chunk, metadata in self.workflow.astream(
|
||||
{"inp": nargs},
|
||||
stream_mode="messages",
|
||||
subgraphs=True,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(metadata, tuple):
|
||||
chunk, metadata = metadata
|
||||
|
||||
tags = metadata.get("tags")
|
||||
if not (tags in streamable_tags):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||
yield chunk.content
|
||||
|
||||
text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
|
||||
async for chunk in text_releaser.release(text_iterator()):
|
||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||
yield chunk
|
||||
|
||||
def _validate_input(self, *nargs, **kwargs):
|
||||
print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
|
||||
|
||||
Reference in New Issue
Block a user