make wait_time a param

This commit is contained in:
2026-01-05 22:41:40 +08:00
parent 18494b1aea
commit 9f554bcf94
2 changed files with 112 additions and 11 deletions

View File

@@ -4,9 +4,13 @@ from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from loguru import logger
import jax
from langgraph.graph.state import CompiledStateGraph
from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages.base import BaseMessageChunk
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
class LangToolBase(ABC):
@@ -21,14 +25,111 @@ class LangToolBase(ABC):
class GraphBase(ABC):
    """Abstract base for LangGraph-backed graph runners.

    Subclasses wire up a compiled LangGraph workflow and expose sync/async
    invocation with optional token streaming.
    """

    # Compiled LangGraph state machine executed by invoke()/ainvoke().
    workflow: CompiledStateGraph
    # Tag lists whose streamed chunks are forwarded to the caller;
    # a chunk's metadata["tags"] must match one entry exactly.
    streamable_tags: List[List[str]]

    @abstractmethod
    def invoke(self, *nargs, **kwargs):
        """Run the workflow; must be implemented by subclasses."""
        pass
def _stream_result(self, *nargs, **kwargs):
async def ainvoke(self, *nargs, **kwargs):
"""Async version of invoke. Subclasses should override for true async support."""
raise NotImplementedError("Subclass should implement ainvoke for async support")
def text_iterator():
for chunk, metadata in self.workflow.stream({"inp": nargs},
stream_mode="messages",
subgraphs=True,
**kwargs):
if isinstance(metadata, tuple):
chunk, metadata = metadata
tags = metadata.get("tags")
if not (tags in self.streamable_tags):
continue
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
yield chunk.content
text_releaser = TextReleaser(*self.tool_node.get_delay_keys())
logger.info("streaming output")
for chunk in text_releaser.release(text_iterator()):
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
yield chunk
# NOTE: DEFAULT IMPLEMENTATION; Overide to support your class
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
self._validate_input(*nargs, **kwargs)
if as_stream:
# Stream messages from the workflow
print("\033[93m====================STREAM OUTPUT=============================\033[0m")
return self._stream_result(*nargs, **kwargs)
else:
state = self.workflow.invoke({"inp": nargs})
msg_list = jax.tree.leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):
e.pretty_print()
if as_raw:
return msg_list
return msg_list[-1].content
# NOTE: DEFAULT IMPLEMENTATION; Overide to support your class
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
"""Async version of invoke using LangGraph's native async support."""
self._validate_input(*nargs, **kwargs)
if as_stream:
# Stream messages from the workflow asynchronously
print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
return self._astream_result(*nargs, **kwargs)
else:
state = await self.workflow.ainvoke({"inp": nargs})
msg_list = jax.tree.leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):
e.pretty_print()
if as_raw:
return msg_list
return msg_list[-1].content
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
"""Async streaming using LangGraph's astream method."""
async def text_iterator():
async for chunk, metadata in self.workflow.astream(
{"inp": nargs},
stream_mode="messages",
subgraphs=True,
**kwargs
):
if isinstance(metadata, tuple):
chunk, metadata = metadata
tags = metadata.get("tags")
if not (tags in self.streamable_tags):
continue
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
yield chunk.content
text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
async for chunk in text_releaser.release(text_iterator()):
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
yield chunk
def _validate_input(self, *nargs, **kwargs):
print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
for e in nargs[0]["messages"]:
if isinstance(e, HumanMessage):
e.pretty_print()
print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
print(f"\033[93m model used: {self.config.llm_name}\033[0m")
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
assert len(kwargs) == 0, "due to inp assumptions"
def show_graph(self, ret_img:bool=False):
#NOTE: just a useful tool for debugging; has zero useful functionality