make wait_time a param
This commit is contained in:
@@ -4,9 +4,13 @@ from PIL import Image
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
import jax
|
||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage, HumanMessage
|
||||||
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
|
|
||||||
|
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
|
||||||
|
|
||||||
|
|
||||||
class LangToolBase(ABC):
|
class LangToolBase(ABC):
|
||||||
@@ -21,14 +25,111 @@ class LangToolBase(ABC):
|
|||||||
|
|
||||||
class GraphBase(ABC):
|
class GraphBase(ABC):
|
||||||
workflow: CompiledStateGraph
|
workflow: CompiledStateGraph
|
||||||
|
streamable_tags: List[List[str]]
|
||||||
|
|
||||||
@abstractmethod
|
def _stream_result(self, *nargs, **kwargs):
|
||||||
def invoke(self, *nargs, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def ainvoke(self, *nargs, **kwargs):
|
def text_iterator():
|
||||||
"""Async version of invoke. Subclasses should override for true async support."""
|
for chunk, metadata in self.workflow.stream({"inp": nargs},
|
||||||
raise NotImplementedError("Subclass should implement ainvoke for async support")
|
stream_mode="messages",
|
||||||
|
subgraphs=True,
|
||||||
|
**kwargs):
|
||||||
|
if isinstance(metadata, tuple):
|
||||||
|
chunk, metadata = metadata
|
||||||
|
|
||||||
|
tags = metadata.get("tags")
|
||||||
|
if not (tags in self.streamable_tags):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
yield chunk.content
|
||||||
|
|
||||||
|
text_releaser = TextReleaser(*self.tool_node.get_delay_keys())
|
||||||
|
logger.info("streaming output")
|
||||||
|
for chunk in text_releaser.release(text_iterator()):
|
||||||
|
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# NOTE: DEFAULT IMPLEMENTATION; Overide to support your class
|
||||||
|
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||||
|
self._validate_input(*nargs, **kwargs)
|
||||||
|
|
||||||
|
if as_stream:
|
||||||
|
# Stream messages from the workflow
|
||||||
|
print("\033[93m====================STREAM OUTPUT=============================\033[0m")
|
||||||
|
return self._stream_result(*nargs, **kwargs)
|
||||||
|
else:
|
||||||
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
|
msg_list = jax.tree.leaves(state)
|
||||||
|
|
||||||
|
for e in msg_list:
|
||||||
|
if isinstance(e, BaseMessage):
|
||||||
|
e.pretty_print()
|
||||||
|
|
||||||
|
if as_raw:
|
||||||
|
return msg_list
|
||||||
|
|
||||||
|
return msg_list[-1].content
|
||||||
|
|
||||||
|
# NOTE: DEFAULT IMPLEMENTATION; Overide to support your class
|
||||||
|
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||||
|
"""Async version of invoke using LangGraph's native async support."""
|
||||||
|
self._validate_input(*nargs, **kwargs)
|
||||||
|
|
||||||
|
if as_stream:
|
||||||
|
# Stream messages from the workflow asynchronously
|
||||||
|
print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
|
||||||
|
return self._astream_result(*nargs, **kwargs)
|
||||||
|
else:
|
||||||
|
state = await self.workflow.ainvoke({"inp": nargs})
|
||||||
|
|
||||||
|
msg_list = jax.tree.leaves(state)
|
||||||
|
|
||||||
|
for e in msg_list:
|
||||||
|
if isinstance(e, BaseMessage):
|
||||||
|
e.pretty_print()
|
||||||
|
|
||||||
|
if as_raw:
|
||||||
|
return msg_list
|
||||||
|
|
||||||
|
return msg_list[-1].content
|
||||||
|
|
||||||
|
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
|
||||||
|
"""Async streaming using LangGraph's astream method."""
|
||||||
|
|
||||||
|
async def text_iterator():
|
||||||
|
async for chunk, metadata in self.workflow.astream(
|
||||||
|
{"inp": nargs},
|
||||||
|
stream_mode="messages",
|
||||||
|
subgraphs=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if isinstance(metadata, tuple):
|
||||||
|
chunk, metadata = metadata
|
||||||
|
|
||||||
|
tags = metadata.get("tags")
|
||||||
|
if not (tags in self.streamable_tags):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
yield chunk.content
|
||||||
|
|
||||||
|
text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
|
||||||
|
async for chunk in text_releaser.release(text_iterator()):
|
||||||
|
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def _validate_input(self, *nargs, **kwargs):
|
||||||
|
print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
|
||||||
|
for e in nargs[0]["messages"]:
|
||||||
|
if isinstance(e, HumanMessage):
|
||||||
|
e.pretty_print()
|
||||||
|
print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
|
||||||
|
print(f"\033[93m model used: {self.config.llm_name}\033[0m")
|
||||||
|
|
||||||
|
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
||||||
|
assert len(kwargs) == 0, "due to inp assumptions"
|
||||||
|
|
||||||
def show_graph(self, ret_img:bool=False):
|
def show_graph(self, ret_img:bool=False):
|
||||||
#NOTE: just a useful tool for debugging; has zero useful functionality
|
#NOTE: just a useful tool for debugging; has zero useful functionality
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TextReleaser:
|
|||||||
"""
|
"""
|
||||||
KEY_START_CHAR = "["
|
KEY_START_CHAR = "["
|
||||||
|
|
||||||
def __init__(self, start_key: str = None, end_key: str = None):
|
def __init__(self, start_key: str = None, end_key: str = None, wait_time:float=0.15):
|
||||||
"""
|
"""
|
||||||
Initialize the TextReleaser.
|
Initialize the TextReleaser.
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class TextReleaser:
|
|||||||
"""
|
"""
|
||||||
self.start_key = start_key
|
self.start_key = start_key
|
||||||
self.end_key = end_key
|
self.end_key = end_key
|
||||||
self.WAIT_TIME = 0.15 # sec/word in chinese
|
self.WAIT_TIME = wait_time # sec/word in chinese
|
||||||
|
|
||||||
# Internal state for producer-consumer pattern
|
# Internal state for producer-consumer pattern
|
||||||
self._buffer: deque = deque() # stores (chunk, chunk_start_pos, chunk_end_pos)
|
self._buffer: deque = deque() # stores (chunk, chunk_start_pos, chunk_end_pos)
|
||||||
@@ -273,7 +273,7 @@ class AsyncTextReleaser:
|
|||||||
"""
|
"""
|
||||||
KEY_START_CHAR = "["
|
KEY_START_CHAR = "["
|
||||||
|
|
||||||
def __init__(self, start_key: str = None, end_key: str = None):
|
def __init__(self, start_key: str = None, end_key: str = None, wait_time:float = 0.15):
|
||||||
"""
|
"""
|
||||||
Initialize the AsyncTextReleaser.
|
Initialize the AsyncTextReleaser.
|
||||||
|
|
||||||
@@ -283,7 +283,7 @@ class AsyncTextReleaser:
|
|||||||
"""
|
"""
|
||||||
self.start_key = start_key
|
self.start_key = start_key
|
||||||
self.end_key = end_key
|
self.end_key = end_key
|
||||||
self.WAIT_TIME = 0.15 # sec/word in chinese
|
self.WAIT_TIME = wait_time # sec/word in chinese
|
||||||
self._accumulated_text = ""
|
self._accumulated_text = ""
|
||||||
|
|
||||||
def _is_prefix_of_key(self, text: str) -> bool:
|
def _is_prefix_of_key(self, text: str) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user