moved to base class
This commit is contained in:
@@ -69,108 +69,6 @@ class RoutingGraph(GraphBase):
|
|||||||
self.streamable_tags:List[List[str]] = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
|
self.streamable_tags:List[List[str]] = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
|
||||||
|
|
||||||
|
|
||||||
def _stream_result(self, *nargs, **kwargs):
|
|
||||||
|
|
||||||
def text_iterator():
|
|
||||||
for chunk, metadata in self.workflow.stream({"inp": nargs},
|
|
||||||
stream_mode="messages",
|
|
||||||
subgraphs=True,
|
|
||||||
**kwargs):
|
|
||||||
if isinstance(metadata, tuple):
|
|
||||||
chunk, metadata = metadata
|
|
||||||
|
|
||||||
tags = metadata.get("tags")
|
|
||||||
if not (tags in self.streamable_tags):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
|
||||||
yield chunk.content
|
|
||||||
|
|
||||||
text_releaser = TextReleaser(*self.tool_node.get_delay_keys())
|
|
||||||
for chunk in text_releaser.release(text_iterator()):
|
|
||||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
|
||||||
self._validate_input(*nargs, **kwargs)
|
|
||||||
|
|
||||||
if as_stream:
|
|
||||||
# Stream messages from the workflow
|
|
||||||
print("\033[93m====================STREAM OUTPUT=============================\033[0m")
|
|
||||||
return self._stream_result(*nargs, **kwargs)
|
|
||||||
else:
|
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
|
||||||
|
|
||||||
for e in msg_list:
|
|
||||||
if isinstance(e, BaseMessage):
|
|
||||||
e.pretty_print()
|
|
||||||
|
|
||||||
if as_raw:
|
|
||||||
return msg_list
|
|
||||||
|
|
||||||
return msg_list[-1].content
|
|
||||||
|
|
||||||
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
|
||||||
"""Async version of invoke using LangGraph's native async support."""
|
|
||||||
self._validate_input(*nargs, **kwargs)
|
|
||||||
|
|
||||||
if as_stream:
|
|
||||||
# Stream messages from the workflow asynchronously
|
|
||||||
print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
|
|
||||||
return self._astream_result(*nargs, **kwargs)
|
|
||||||
else:
|
|
||||||
state = await self.workflow.ainvoke({"inp": nargs})
|
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
|
||||||
|
|
||||||
for e in msg_list:
|
|
||||||
if isinstance(e, BaseMessage):
|
|
||||||
e.pretty_print()
|
|
||||||
|
|
||||||
if as_raw:
|
|
||||||
return msg_list
|
|
||||||
|
|
||||||
return msg_list[-1].content
|
|
||||||
|
|
||||||
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
|
|
||||||
"""Async streaming using LangGraph's astream method."""
|
|
||||||
|
|
||||||
async def text_iterator():
|
|
||||||
async for chunk, metadata in self.workflow.astream(
|
|
||||||
{"inp": nargs},
|
|
||||||
stream_mode="messages",
|
|
||||||
subgraphs=True,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
if isinstance(metadata, tuple):
|
|
||||||
chunk, metadata = metadata
|
|
||||||
|
|
||||||
tags = metadata.get("tags")
|
|
||||||
if not (tags in self.streamable_tags):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
|
||||||
yield chunk.content
|
|
||||||
|
|
||||||
text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
|
|
||||||
async for chunk in text_releaser.release(text_iterator()):
|
|
||||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
def _validate_input(self, *nargs, **kwargs):
|
|
||||||
print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
|
|
||||||
for e in nargs[0]["messages"]:
|
|
||||||
if isinstance(e, HumanMessage):
|
|
||||||
e.pretty_print()
|
|
||||||
print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
|
|
||||||
print(f"\033[93m model used: {self.config.llm_name}\033[0m")
|
|
||||||
|
|
||||||
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
|
||||||
assert len(kwargs) == 0, "due to inp assumptions"
|
|
||||||
|
|
||||||
def _get_chat_tools(self, man:ToolManager):
|
def _get_chat_tools(self, man:ToolManager):
|
||||||
return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
|
return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
|
||||||
|
|
||||||
@@ -320,7 +218,8 @@ if __name__ == "__main__":
|
|||||||
},{"configurable": {"thread_id": "3"}}
|
},{"configurable": {"thread_id": "3"}}
|
||||||
|
|
||||||
for chunk in route.invoke(*nargs, as_stream=True):
|
for chunk in route.invoke(*nargs, as_stream=True):
|
||||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
# print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
||||||
|
|||||||
Reference in New Issue
Block a user