Make `wait_time` a constructor parameter of TextReleaser and AsyncTextReleaser (default 0.15)

This commit is contained in:
2026-01-05 22:41:40 +08:00
parent 18494b1aea
commit 9f554bcf94
2 changed files with 112 additions and 11 deletions

View File

@@ -4,9 +4,13 @@ from PIL import Image
from io import BytesIO from io import BytesIO
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from loguru import logger from loguru import logger
import jax
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages.base import BaseMessageChunk
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
class LangToolBase(ABC): class LangToolBase(ABC):
@@ -21,14 +25,111 @@ class LangToolBase(ABC):
class GraphBase(ABC): class GraphBase(ABC):
workflow: CompiledStateGraph workflow: CompiledStateGraph
streamable_tags: List[List[str]]
@abstractmethod def _stream_result(self, *nargs, **kwargs):
def invoke(self, *nargs, **kwargs):
pass
async def ainvoke(self, *nargs, **kwargs): def text_iterator():
"""Async version of invoke. Subclasses should override for true async support.""" for chunk, metadata in self.workflow.stream({"inp": nargs},
raise NotImplementedError("Subclass should implement ainvoke for async support") stream_mode="messages",
subgraphs=True,
**kwargs):
if isinstance(metadata, tuple):
chunk, metadata = metadata
tags = metadata.get("tags")
if not (tags in self.streamable_tags):
continue
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
yield chunk.content
text_releaser = TextReleaser(*self.tool_node.get_delay_keys())
logger.info("streaming output")
for chunk in text_releaser.release(text_iterator()):
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
yield chunk
# NOTE: DEFAULT IMPLEMENTATION; Override in your subclass to customize behaviour
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
    """Run the compiled workflow synchronously.

    Args:
        *nargs: positional state payload; wrapped as ``{"inp": nargs}`` for the graph.
        as_stream: when True, return the streaming generator instead of invoking.
        as_raw: when True, return the full list of state leaves instead of the
            final message content.
        **kwargs: forwarded to the streaming path (validated empty otherwise).
    """
    self._validate_input(*nargs, **kwargs)
    if as_stream:
        # Streaming path: hand back the generator; the banner marks stream start.
        print("\033[93m====================STREAM OUTPUT=============================\033[0m")
        return self._stream_result(*nargs, **kwargs)
    final_state = self.workflow.invoke({"inp": nargs})
    # Flatten the (possibly nested) state container and pretty-print any messages.
    leaves = jax.tree.leaves(final_state)
    for leaf in leaves:
        if isinstance(leaf, BaseMessage):
            leaf.pretty_print()
    # NOTE(review): assumes the last leaf is a message with `.content` — verify
    # against the graph's state schema.
    return leaves if as_raw else leaves[-1].content
# NOTE: DEFAULT IMPLEMENTATION; Override in your subclass to customize behaviour
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
    """Async counterpart of ``invoke`` built on LangGraph's native async API.

    Args:
        *nargs: positional state payload; wrapped as ``{"inp": nargs}`` for the graph.
        as_stream: when True, return the async streaming generator instead of awaiting.
        as_raw: when True, return the full list of state leaves instead of the
            final message content.
        **kwargs: forwarded to the streaming path (validated empty otherwise).
    """
    self._validate_input(*nargs, **kwargs)
    if as_stream:
        # Hand back the async generator; the banner marks stream start.
        print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
        return self._astream_result(*nargs, **kwargs)
    final_state = await self.workflow.ainvoke({"inp": nargs})
    # Flatten the (possibly nested) state container and pretty-print any messages.
    leaves = jax.tree.leaves(final_state)
    for leaf in leaves:
        if isinstance(leaf, BaseMessage):
            leaf.pretty_print()
    # NOTE(review): assumes the last leaf is a message with `.content` — verify
    # against the graph's state schema.
    return leaves if as_raw else leaves[-1].content
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
    """Yield streamed text pieces from the workflow via ``astream``.

    Filters message chunks by ``streamable_tags``, paces the output through
    an ``AsyncTextReleaser``, echoes each piece to the terminal in green,
    and re-yields it to the caller.
    """
    async def _raw_text():
        stream = self.workflow.astream(
            {"inp": nargs}, stream_mode="messages", subgraphs=True, **kwargs
        )
        async for item, meta in stream:
            # With subgraphs=True an event may arrive as (namespace, (chunk, metadata));
            # unwrap the inner pair when that happens.
            if isinstance(meta, tuple):
                item, meta = meta
            if meta.get("tags") not in self.streamable_tags:
                continue
            if isinstance(item, (BaseMessageChunk, BaseMessage)) and getattr(item, "content", None):
                yield item.content
    releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
    async for piece in releaser.release(_raw_text()):
        # Echo to the terminal (green) while also yielding to the caller.
        print(f"\033[92m{piece}\033[0m", end="", flush=True)
        yield piece
def _validate_input(self, *nargs, **kwargs):
    """Pretty-print the incoming human messages and validate the call shape.

    Expects ``nargs[0]`` to be a state dict with a ``"messages"`` list holding
    at least a system and a human message, and no extra keyword arguments
    (``invoke``/``ainvoke`` wrap positional args as ``{"inp": nargs}``).

    Raises:
        ValueError: if no state dict is given, it lacks a ``"messages"`` key,
            fewer than two messages are present, or keyword arguments are passed.
    """
    # Guard before indexing so an empty call raises a clear error instead of
    # a bare IndexError/KeyError.
    if not nargs or "messages" not in nargs[0]:
        raise ValueError("expected a state dict with a 'messages' key as the first positional argument")
    messages = nargs[0]["messages"]
    print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
    for msg in messages:
        if isinstance(msg, HumanMessage):
            msg.pretty_print()
    print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
    print(f"\033[93m model used: {self.config.llm_name}\033[0m")
    # `raise` instead of `assert`: assertions are stripped under `python -O`,
    # which would silently disable this validation.
    if len(messages) < 2:
        raise ValueError("need at least 1 system and 1 human message")
    if kwargs:
        raise ValueError("due to inp assumptions")
def show_graph(self, ret_img:bool=False): def show_graph(self, ret_img:bool=False):
#NOTE: just a useful tool for debugging; has zero useful functionality #NOTE: just a useful tool for debugging; has zero useful functionality

View File

@@ -37,7 +37,7 @@ class TextReleaser:
""" """
KEY_START_CHAR = "[" KEY_START_CHAR = "["
def __init__(self, start_key: str = None, end_key: str = None): def __init__(self, start_key: str = None, end_key: str = None, wait_time:float=0.15):
""" """
Initialize the TextReleaser. Initialize the TextReleaser.
@@ -47,7 +47,7 @@ class TextReleaser:
""" """
self.start_key = start_key self.start_key = start_key
self.end_key = end_key self.end_key = end_key
self.WAIT_TIME = 0.15 # sec/word in chinese self.WAIT_TIME = wait_time # sec/word in chinese
# Internal state for producer-consumer pattern # Internal state for producer-consumer pattern
self._buffer: deque = deque() # stores (chunk, chunk_start_pos, chunk_end_pos) self._buffer: deque = deque() # stores (chunk, chunk_start_pos, chunk_end_pos)
@@ -273,7 +273,7 @@ class AsyncTextReleaser:
""" """
KEY_START_CHAR = "[" KEY_START_CHAR = "["
def __init__(self, start_key: str = None, end_key: str = None): def __init__(self, start_key: str = None, end_key: str = None, wait_time:float = 0.15):
""" """
Initialize the AsyncTextReleaser. Initialize the AsyncTextReleaser.
@@ -283,7 +283,7 @@ class AsyncTextReleaser:
""" """
self.start_key = start_key self.start_key = start_key
self.end_key = end_key self.end_key = end_key
self.WAIT_TIME = 0.15 # sec/word in chinese self.WAIT_TIME = wait_time # sec/word in chinese
self._accumulated_text = "" self._accumulated_text = ""
def _is_prefix_of_key(self, text: str) -> bool: def _is_prefix_of_key(self, text: str) -> bool: