support async

This commit is contained in:
2025-12-29 21:59:11 +08:00
parent 53ecbebb0a
commit ab5dda1f21
8 changed files with 429 additions and 26 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Callable, Tuple, Dict
from typing import List, Callable, Tuple, Dict, AsyncIterator
from abc import ABC, abstractmethod
from PIL import Image
from io import BytesIO
@@ -26,6 +26,10 @@ class GraphBase(ABC):
def invoke(self, *nargs, **kwargs):
    """Synchronous entry point for running the graph.

    No-op in the base class — concrete graphs override this.
    (Presumably decorated @abstractmethod in the full file — the decorator
    line is outside this diff view; confirm against the module.)
    """
    pass
async def ainvoke(self, *nargs, **kwargs):
    """Async version of invoke. Subclasses should override for true async support."""
    # Unlike `invoke` (a silent no-op), calling the base async version is an
    # explicit error so missing async support fails loudly.
    raise NotImplementedError("Subclass should implement ainvoke for async support")
def show_graph(self, ret_img:bool=False):
#NOTE: just a useful tool for debugging; has zero useful functionality
@@ -59,4 +63,8 @@ class ToolNodeBase(GraphBase):
@abstractmethod
def invoke(self, inp)->Dict[str, List[BaseMessage]]:
    """Process `inp` and return a mapping of state keys to message lists."""
    pass
    # NOTE(review): duplicated `pass` below — harmless but redundant; it looks
    # like a merge/diff artifact and can be removed.
    pass
async def ainvoke(self, inp)->Dict[str, List[BaseMessage]]:
    """Async version of invoke. Subclasses should override for true async support."""
    # Not abstract: sync-only tool nodes remain valid; only async use raises.
    raise NotImplementedError("Subclass should implement ainvoke for async support")

View File

@@ -1,7 +1,8 @@
import time
import asyncio
import threading
from dataclasses import dataclass, field
from typing import Iterator, Any, List, Tuple, Optional
from typing import Iterator, AsyncIterator, Any, List, Tuple, Optional
from collections import deque
@@ -256,3 +257,219 @@ class TextReleaser:
else:
# Yield entire chunk immediately
yield text_to_yield
class AsyncTextReleaser:
    """
    Async version of TextReleaser for use with async generators.
    Controls the release rate of streaming text chunks based on special key markers.
    Uses asyncio instead of threading for non-blocking operation.

    Positions below always index into `self._accumulated_text`, the
    concatenation of every chunk seen so far. Per-iterator bookkeeping lives
    in a `ReleaseState` (defined elsewhere in this module, alongside the
    threaded TextReleaser): it carries `found_start_keys` / `found_end_keys`
    (lists of (start, end) ranges), `start_key_search_pos` /
    `end_key_search_pos`, `in_delay_mode`, and `yield_idx` — confirm against
    its definition.
    """
    # Every key marker is expected to begin with this character; used to spot
    # a possibly-incomplete key at the tail of the accumulated text.
    KEY_START_CHAR = "["

    def __init__(self, start_key: str = None, end_key: str = None):
        """
        Initialize the AsyncTextReleaser.
        Args:
            start_key: Marker that triggers delayed yielding (e.g., "[CHATTY_OUT]")
            end_key: Marker that ends delayed yielding and skips pending chunks (e.g., "[TOOL_OUT]")
        """
        self.start_key = start_key
        self.end_key = end_key
        # Per-character delay while in delay mode. ("sec/word in chinese" —
        # one CJK character ≈ one word.)
        self.WAIT_TIME = 0.15 # sec/word in chinese
        # Full text observed so far; all (start, end) ranges index into this.
        self._accumulated_text = ""

    def _is_prefix_of_key(self, text: str) -> bool:
        """Check if text is a prefix of start_key or end_key."""
        if self.start_key and self.start_key.startswith(text):
            return True
        if self.end_key and self.end_key.startswith(text):
            return True
        return False

    def _find_key_range(self, key: str, after_pos: int = 0) -> Optional[Tuple[int, int]]:
        """Find the start and end position of a key in accumulated text after given position."""
        if not key:
            return None
        idx = self._accumulated_text.find(key, after_pos)
        if idx == -1:
            return None
        # Half-open range [idx, idx + len(key)) within _accumulated_text.
        return (idx, idx + len(key))

    def _chunk_overlaps_range(self, chunk_start: int, chunk_end: int, range_start: int, range_end: int) -> bool:
        """Check if chunk overlaps with a given range."""
        # Overlap iff neither interval lies entirely before the other.
        return not (chunk_end <= range_start or chunk_start >= range_end)

    def _search_for_keys(self, state: ReleaseState, accumulated_len: int) -> None:
        """Search for complete start_key and end_key occurrences in accumulated text."""
        # Collect every fully-received start_key, resuming where the last
        # search left off so each occurrence is recorded exactly once.
        while True:
            key_range = self._find_key_range(self.start_key, state.start_key_search_pos)
            if key_range and key_range[1] <= accumulated_len:
                state.found_start_keys.append(key_range)
                state.start_key_search_pos = key_range[1]
            else:
                break
        # Same scan for end_key occurrences.
        while True:
            key_range = self._find_key_range(self.end_key, state.end_key_search_pos)
            if key_range and key_range[1] <= accumulated_len:
                state.found_end_keys.append(key_range)
                state.end_key_search_pos = key_range[1]
            else:
                break

    def _find_potential_key_position(self, accumulated: str) -> int:
        """Find position of potential incomplete key at end of accumulated text."""
        max_key_len = max(len(self.start_key or ""), len(self.end_key or ""))
        # Only a suffix shorter than the longest key can still grow into one.
        for search_start in range(max(0, len(accumulated) - max_key_len + 1), len(accumulated)):
            suffix = accumulated[search_start:]
            if suffix.startswith(self.KEY_START_CHAR) and self._is_prefix_of_key(suffix):
                return search_start
        return -1

    def _get_safe_end_pos(self, accumulated: str, producer_done: bool) -> int:
        """Determine the safe position up to which we can yield chunks."""
        potential_key_pos = self._find_potential_key_position(accumulated)
        # Hold back a possible half-received key — unless the producer is
        # finished, in which case the suffix can never complete into a key.
        if potential_key_pos >= 0 and not producer_done:
            return potential_key_pos
        return len(accumulated)

    def _update_delay_mode(self, state: ReleaseState, y_start: int, y_end: int) -> None:
        """Update delay mode based on chunk position relative to keys."""
        # Enter delay mode once the chunk starts at/after, or spans, the end
        # of any recorded start_key.
        if not state.in_delay_mode and state.found_start_keys:
            for sk_range in state.found_start_keys:
                if y_start >= sk_range[1] or (y_start < sk_range[1] <= y_end):
                    state.in_delay_mode = True
                    break
        # Symmetrically, leave delay mode past the end of any end_key.
        if state.in_delay_mode and state.found_end_keys:
            for ek_range in state.found_end_keys:
                if y_start >= ek_range[1] or (y_start < ek_range[1] <= y_end):
                    state.in_delay_mode = False
                    break

    def _should_skip_to_end_key(self, state: ReleaseState, y_end: int) -> bool:
        """Check if chunk should be skipped because it's before an end_key in delay mode."""
        if not state.in_delay_mode:
            return False
        # In delay mode, anything ending before a known end_key is dropped:
        # the end_key marks "skip all pending chunks and resume after it".
        for ek_range in state.found_end_keys:
            if y_end <= ek_range[0]:
                return True
        return False

    def _get_text_to_yield(self, y_start: int, y_end: int, state: ReleaseState) -> Optional[str]:
        """
        Given a chunk's position range, return the text that should be yielded.
        Returns None if the entire chunk should be skipped.
        """
        # Key ranges overlapping this chunk, in positional order; the marker
        # text itself must never be emitted.
        all_key_ranges = state.found_start_keys + state.found_end_keys
        relevant_ranges = sorted(
            [r for r in all_key_ranges if self._chunk_overlaps_range(y_start, y_end, r[0], r[1])],
            key=lambda x: x[0]
        )
        if not relevant_ranges:
            return self._accumulated_text[y_start:y_end]
        # Stitch together the spans of the chunk that fall between key ranges.
        result_parts = []
        current_pos = y_start
        for key_start, key_end in relevant_ranges:
            text_start = max(current_pos, y_start)
            text_end = min(key_start, y_end)
            if text_end > text_start:
                result_parts.append(self._accumulated_text[text_start:text_end])
            current_pos = max(current_pos, key_end)
        if current_pos < y_end:
            result_parts.append(self._accumulated_text[current_pos:y_end])
        return "".join(result_parts) if result_parts else None

    async def release(self, text_iterator: AsyncIterator[str]) -> AsyncIterator[str]:
        """
        Async version of release that works with async generators.
        Yields chunks from text_iterator with the following behavior:
        - Before start_key: yield chunks immediately (but hold back if potential key prefix)
        - After start_key (until end_key): yield with WAIT_TIME delay
        - start_key and end_key are never yielded
        - When end_key is seen: skip all pending chunks and resume after end_key

        NOTE(review): delayed yielding sleeps inside the same loop that pulls
        from `text_iterator`, so upstream chunks are not consumed while
        pacing — unlike the threaded TextReleaser, producer and consumer are
        not decoupled here. Confirm this serialization is acceptable.
        """
        buffer: deque = deque()  # stores (chunk, chunk_start_pos, chunk_end_pos)
        state = ReleaseState()
        producer_done = False

        async def consume_and_process():
            nonlocal producer_done
            async for chunk in text_iterator:
                # Record the chunk's span within the accumulated text.
                start_pos = len(self._accumulated_text)
                self._accumulated_text += chunk
                end_pos = len(self._accumulated_text)
                buffer.append((chunk, start_pos, end_pos))
                # Process available chunks
                self._search_for_keys(state, len(self._accumulated_text))
                # producer_done=False: a trailing "[...”-prefix may still grow
                # into a key, so stop yielding before it.
                safe_end_pos = self._get_safe_end_pos(self._accumulated_text, False)
                while state.yield_idx < len(buffer):
                    chunk_at_yield = buffer[state.yield_idx]
                    y_chunk, y_start, y_end = chunk_at_yield
                    if y_end > safe_end_pos:
                        break
                    self._update_delay_mode(state, y_start, y_end)
                    if self._should_skip_to_end_key(state, y_end):
                        state.yield_idx += 1
                        continue
                    state.yield_idx += 1
                    text_to_yield = self._get_text_to_yield(y_start, y_end, state)
                    if not text_to_yield:
                        continue
                    if state.in_delay_mode:
                        # Pace output character by character.
                        for char in text_to_yield:
                            yield char
                            await asyncio.sleep(self.WAIT_TIME)
                    else:
                        yield text_to_yield
            producer_done = True
            # Process remaining chunks after producer is done
            self._search_for_keys(state, len(self._accumulated_text))
            # producer_done=True: nothing more can arrive, flush everything.
            safe_end_pos = self._get_safe_end_pos(self._accumulated_text, True)
            while state.yield_idx < len(buffer):
                chunk_at_yield = buffer[state.yield_idx]
                y_chunk, y_start, y_end = chunk_at_yield
                self._update_delay_mode(state, y_start, y_end)
                if self._should_skip_to_end_key(state, y_end):
                    state.yield_idx += 1
                    continue
                state.yield_idx += 1
                text_to_yield = self._get_text_to_yield(y_start, y_end, state)
                if not text_to_yield:
                    continue
                if state.in_delay_mode:
                    for char in text_to_yield:
                        yield char
                        await asyncio.sleep(self.WAIT_TIME)
                else:
                    yield text_to_yield

        async for chunk in consume_and_process():
            yield chunk

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
import tyro
from pydantic import BaseModel, Field
from loguru import logger
@@ -14,7 +14,7 @@ from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.graphs.graph_states import State
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
from lang_agent.components.text_releaser import TextReleaser
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
@@ -110,7 +110,55 @@ class RoutingGraph(GraphBase):
if as_raw:
return msg_list
return msg_list[-1].content
return msg_list[-1].content
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
    """Asynchronously run the routing graph (async counterpart of `invoke`).

    Args:
        *nargs: positional inputs forwarded as the workflow's "inp" state.
        as_stream: if True, return an async generator of released text chunks.
        as_raw: if True (non-streaming), return the full message list instead
            of just the last message's content.

    Returns:
        An async generator when `as_stream` is True; otherwise the message
        list (`as_raw=True`) or the final message's content.
    """
    self._validate_input(*nargs, **kwargs)
    if as_stream:
        # Streaming path: hand back the async generator without awaiting it.
        print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
        return self._astream_result(*nargs, **kwargs)
    # Non-streaming path: await the whole workflow once, then flatten the
    # resulting state tree into its message leaves.
    state = await self.workflow.ainvoke({"inp": nargs})
    messages = jax.tree.leaves(state)
    for message in messages:
        if isinstance(message, BaseMessage):
            message.pretty_print()
    return messages if as_raw else messages[-1].content
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
    """Async streaming using LangGraph's astream method.

    Streams message chunks from the workflow, filters them by streamable
    tags, then paces the text through AsyncTextReleaser before yielding.
    """
    # Chunks tagged with the chat LLM's tag list are always streamable, in
    # addition to whatever tags the tool node declares.
    streamable_tags = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
    async def text_iterator():
        async for chunk, metadata in self.workflow.astream(
            {"inp": nargs},
            stream_mode="messages",
            subgraphs=True,
            **kwargs
        ):
            # With subgraphs=True the second element can itself be a
            # (chunk, metadata) pair — unwrap it in that case.
            if isinstance(metadata, tuple):
                chunk, metadata = metadata
            tags = metadata.get("tags")
            # Whole-list membership test: the chunk's tag list must exactly
            # match one of the streamable tag lists.
            if not (tags in streamable_tags):
                continue
            # NOTE(review): BaseMessageChunk is not among the imports visible
            # in this diff — confirm it is imported elsewhere in the module.
            if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
                yield chunk.content
    # Pace the filtered text through the async releaser (delay/skip markers
    # come from the tool node), echoing each chunk in green as it goes out.
    text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
    async for chunk in text_releaser.release(text_iterator()):
        print(f"\033[92m{chunk}\033[0m", end="", flush=True)
        yield chunk
def _validate_input(self, *nargs, **kwargs):
print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")

View File

@@ -140,6 +140,11 @@ class Pipeline:
for chunk in out:
yield chunk
async def _astream_res(self, out):
"""Async version of _stream_res for async generators."""
async for chunk in out:
yield chunk
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None):
"""
as_stream (bool): if true, enable the thing to be streamable
@@ -158,4 +163,48 @@ class Pipeline:
# Yield chunks from the generator
return self._stream_res(out)
else:
return out
return out
async def ainvoke(self, *nargs, **kwargs):
    """Async version of invoke using LangGraph's native async support.

    Awaits `self.graph.ainvoke` and post-processes the result exactly like
    the sync `invoke` path.

    Args:
        *nargs: forwarded verbatim to the graph.
        **kwargs: forwarded verbatim; `as_stream` and `as_raw` are also
            inspected here to select the return shape.

    Returns:
        An async generator when `as_stream` is truthy; otherwise the raw
        output (`as_raw`), a message's `.content`, the last list element's
        `.content`, or the string itself.

    Raises:
        AssertionError: if the graph returns an unrecognized output type.
    """
    out = await self.graph.ainvoke(*nargs, **kwargs)
    # If streaming, wrap the generator so callers get an async iterator.
    if kwargs.get("as_stream"):
        return self._astream_res(out)
    # Non-streaming path
    if kwargs.get("as_raw"):
        return out
    # Single message, message list, or plain text — mirror sync invoke.
    if isinstance(out, (SystemMessage, HumanMessage)):
        return out.content
    if isinstance(out, list):
        return out[-1].content
    if isinstance(out, str):
        return out
    # Explicit raise instead of `assert 0` so the check survives `python -O`
    # (assert statements are stripped under optimization).
    raise AssertionError("something is wrong")
async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None):
    """Async counterpart of `chat`, driving the graph through `ainvoke`.

    Args:
        inp: user input text.
        as_stream: if True, return an async generator of streamed chunks.
        as_raw: if True, return the full dialogue of
            List[SystemMessage, HumanMessage, ToolMessage] instead of text.
        thread_id: conversation thread id; falls back to 3 when omitted.
    """
    # NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
    system_prompt = DEFAULT_PROMPT
    if thread_id is None:
        thread_id = 3
    messages = {"messages": [SystemMessage(system_prompt), HumanMessage(inp)]}
    config = {"configurable": {"thread_id": thread_id}}
    if not as_stream:
        return await self.ainvoke(messages, config, as_stream=False, as_raw=as_raw)
    # Streaming: ainvoke hands back a generator; re-wrap it for the caller.
    out = await self.ainvoke(messages, config, as_stream=True, as_raw=as_raw)
    return self._astream_res(out)