From ff69e4728c35046db6a8d9a3ecf5874454cf168a Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 5 Jan 2026 22:41:56 +0800 Subject: [PATCH] moved to base class --- lang_agent/graphs/routing.py | 105 +---------------------------------- 1 file changed, 2 insertions(+), 103 deletions(-) diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index c5a9a97..1e5e3d0 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -68,108 +68,6 @@ class RoutingGraph(GraphBase): self.streamable_tags:List[List[str]] = self.tool_node.get_streamable_tags() + [["route_chat_llm"]] - - def _stream_result(self, *nargs, **kwargs): - - def text_iterator(): - for chunk, metadata in self.workflow.stream({"inp": nargs}, - stream_mode="messages", - subgraphs=True, - **kwargs): - if isinstance(metadata, tuple): - chunk, metadata = metadata - - tags = metadata.get("tags") - if not (tags in self.streamable_tags): - continue - - if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): - yield chunk.content - - text_releaser = TextReleaser(*self.tool_node.get_delay_keys()) - for chunk in text_releaser.release(text_iterator()): - print(f"\033[92m{chunk}\033[0m", end="", flush=True) - yield chunk - - - def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): - self._validate_input(*nargs, **kwargs) - - if as_stream: - # Stream messages from the workflow - print("\033[93m====================STREAM OUTPUT=============================\033[0m") - return self._stream_result(*nargs, **kwargs) - else: - state = self.workflow.invoke({"inp": nargs}) - - msg_list = jax.tree.leaves(state) - - for e in msg_list: - if isinstance(e, BaseMessage): - e.pretty_print() - - if as_raw: - return msg_list - - return msg_list[-1].content - - async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): - """Async version of invoke using LangGraph's native async support.""" - self._validate_input(*nargs, **kwargs) - - if as_stream: - # Stream messages from the workflow asynchronously - print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m") - return self._astream_result(*nargs, **kwargs) - else: - state = await self.workflow.ainvoke({"inp": nargs}) - - msg_list = jax.tree.leaves(state) - - for e in msg_list: - if isinstance(e, BaseMessage): - e.pretty_print() - - if as_raw: - return msg_list - - return msg_list[-1].content - - async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]: - """Async streaming using LangGraph's astream method.""" - - async def text_iterator(): - async for chunk, metadata in self.workflow.astream( - {"inp": nargs}, - stream_mode="messages", - subgraphs=True, - **kwargs - ): - if isinstance(metadata, tuple): - chunk, metadata = metadata - - tags = metadata.get("tags") - if not (tags in self.streamable_tags): - continue - - if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): - yield chunk.content - - text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys()) - async for chunk in text_releaser.release(text_iterator()): - print(f"\033[92m{chunk}\033[0m", end="", flush=True) - yield chunk - - def _validate_input(self, *nargs, **kwargs): - print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m") - for e in nargs[0]["messages"]: - if isinstance(e, HumanMessage): - e.pretty_print() - print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m") - print(f"\033[93m model used: {self.config.llm_name}\033[0m") - - assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" - assert len(kwargs) == 0, "due to inp assumptions" def _get_chat_tools(self, man:ToolManager): return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names] @@ -320,7 +218,8 @@ if __name__ == "__main__": },{"configurable": {"thread_id": "3"}} for chunk in route.invoke(*nargs, as_stream=True): - print(f"\033[92m{chunk}\033[0m", end="", flush=True) + # print(f"\033[92m{chunk}\033[0m", end="", flush=True) + continue # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):