diff --git a/lang_agent/base.py b/lang_agent/base.py
index edb6944..101a425 100644
--- a/lang_agent/base.py
+++ b/lang_agent/base.py
@@ -4,9 +4,13 @@
 from PIL import Image
 from io import BytesIO
 import matplotlib.pyplot as plt
 from loguru import logger
+import jax
 from langgraph.graph.state import CompiledStateGraph
-from langchain_core.messages import BaseMessage
+from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.messages.base import BaseMessageChunk
+
+from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
 
 
 class LangToolBase(ABC):
@@ -21,14 +25,111 @@
 
 class GraphBase(ABC):
     workflow: CompiledStateGraph
+    streamable_tags: List[List[str]]
 
-    @abstractmethod
-    def invoke(self, *nargs, **kwargs):
-        pass
+    def _stream_result(self, *nargs, **kwargs):
 
-    async def ainvoke(self, *nargs, **kwargs):
-        """Async version of invoke. Subclasses should override for true async support."""
-        raise NotImplementedError("Subclass should implement ainvoke for async support")
+        def text_iterator():
+            for chunk, metadata in self.workflow.stream({"inp": nargs},
+                                                        stream_mode="messages",
+                                                        subgraphs=True,
+                                                        **kwargs):
+                if isinstance(metadata, tuple):
+                    chunk, metadata = metadata
+
+                tags = metadata.get("tags")
+                if not (tags in self.streamable_tags):
+                    continue
+
+                if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+                    yield chunk.content
+
+        text_releaser = TextReleaser(*self.tool_node.get_delay_keys())
+        logger.info("streaming output")
+        for chunk in text_releaser.release(text_iterator()):
+            print(f"\033[92m{chunk}\033[0m", end="", flush=True)
+            yield chunk
+
+    # NOTE: DEFAULT IMPLEMENTATION; Override to support your class
+    def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
+        self._validate_input(*nargs, **kwargs)
+
+        if as_stream:
+            # Stream messages from the workflow
+            print("\033[93m====================STREAM OUTPUT=============================\033[0m")
+            return self._stream_result(*nargs, **kwargs)
+        else:
+            state = self.workflow.invoke({"inp": nargs})
+
+            msg_list = jax.tree.leaves(state)
+
+            for e in msg_list:
+                if isinstance(e, BaseMessage):
+                    e.pretty_print()
+
+            if as_raw:
+                return msg_list
+
+            return msg_list[-1].content
+
+    # NOTE: DEFAULT IMPLEMENTATION; Override to support your class
+    async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
+        """Async version of invoke using LangGraph's native async support."""
+        self._validate_input(*nargs, **kwargs)
+
+        if as_stream:
+            # Stream messages from the workflow asynchronously
+            print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
+            return self._astream_result(*nargs, **kwargs)
+        else:
+            state = await self.workflow.ainvoke({"inp": nargs})
+
+            msg_list = jax.tree.leaves(state)
+
+            for e in msg_list:
+                if isinstance(e, BaseMessage):
+                    e.pretty_print()
+
+            if as_raw:
+                return msg_list
+
+            return msg_list[-1].content
+
+    async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
+        """Async streaming using LangGraph's astream method."""
+
+        async def text_iterator():
+            async for chunk, metadata in self.workflow.astream(
+                {"inp": nargs},
+                stream_mode="messages",
+                subgraphs=True,
+                **kwargs
+            ):
+                if isinstance(metadata, tuple):
+                    chunk, metadata = metadata
+
+                tags = metadata.get("tags")
+                if not (tags in self.streamable_tags):
+                    continue
+
+                if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+                    yield chunk.content
+
+        text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
+        async for chunk in text_releaser.release(text_iterator()):
+            print(f"\033[92m{chunk}\033[0m", end="", flush=True)
+            yield chunk
+
+    def _validate_input(self, *nargs, **kwargs):
+        print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
+        for e in nargs[0]["messages"]:
+            if isinstance(e, HumanMessage):
+                e.pretty_print()
+        print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
+        print(f"\033[93m model used: {self.config.llm_name}\033[0m")
+
+        assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
+        assert len(kwargs) == 0, "due to inp assumptions"
 
     def show_graph(self, ret_img:bool=False):
         #NOTE: just a useful tool for debugging; has zero useful functionality
diff --git a/lang_agent/components/text_releaser.py b/lang_agent/components/text_releaser.py
index 025623b..7c5f0ed 100644
--- a/lang_agent/components/text_releaser.py
+++ b/lang_agent/components/text_releaser.py
@@ -37,7 +37,7 @@ class TextReleaser:
     """
     KEY_START_CHAR = "["
 
-    def __init__(self, start_key: str = None, end_key: str = None):
+    def __init__(self, start_key: str = None, end_key: str = None, wait_time:float=0.15):
         """
         Initialize the TextReleaser.
 
@@ -47,7 +47,7 @@
         """
         self.start_key = start_key
         self.end_key = end_key
-        self.WAIT_TIME = 0.15 # sec/word in chinese
+        self.WAIT_TIME = wait_time # sec/word in chinese
 
         # Internal state for producer-consumer pattern
         self._buffer: deque = deque() # stores (chunk, chunk_start_pos, chunk_end_pos)
@@ -273,7 +273,7 @@ class AsyncTextReleaser:
     """
     KEY_START_CHAR = "["
 
-    def __init__(self, start_key: str = None, end_key: str = None):
+    def __init__(self, start_key: str = None, end_key: str = None, wait_time:float = 0.15):
        """
         Initialize the AsyncTextReleaser.
 
@@ -283,7 +283,7 @@
         """
         self.start_key = start_key
         self.end_key = end_key
-        self.WAIT_TIME = 0.15 # sec/word in chinese
+        self.WAIT_TIME = wait_time # sec/word in chinese
         self._accumulated_text = ""
 
     def _is_prefix_of_key(self, text: str) -> bool: