support async

This commit is contained in:
2025-12-29 21:59:11 +08:00
parent 53ecbebb0a
commit ab5dda1f21
8 changed files with 429 additions and 26 deletions

View File

@@ -1,7 +1,8 @@
import time
import asyncio
import threading
from dataclasses import dataclass, field
from typing import Iterator, Any, List, Tuple, Optional
from typing import Iterator, AsyncIterator, Any, List, Tuple, Optional
from collections import deque
@@ -256,3 +257,219 @@ class TextReleaser:
else:
# Yield entire chunk immediately
yield text_to_yield
class AsyncTextReleaser:
    """
    Async version of TextReleaser for use with async generators.

    Controls the release rate of streaming text chunks based on special key
    markers. Uses asyncio instead of threading, so waiting between characters
    does not block the event loop.

    Per-stream bookkeeping (found key ranges, search positions, delay mode)
    lives in a ``ReleaseState`` object created by ``release``; the text seen
    so far is kept in ``self._accumulated_text``.
    """

    # Retained for backward compatibility. Held-back detection no longer
    # depends on it: any suffix that is a prefix of a configured key is held
    # back, whatever its first character is.
    KEY_START_CHAR = "["

    def __init__(self, start_key: Optional[str] = None, end_key: Optional[str] = None):
        """
        Initialize the AsyncTextReleaser.

        Args:
            start_key: Marker that triggers delayed yielding (e.g., "[CHATTY_OUT]")
            end_key: Marker that ends delayed yielding and skips pending chunks (e.g., "[TOOL_OUT]")
        """
        self.start_key = start_key
        self.end_key = end_key
        # Seconds slept after each character emitted in delay mode
        # (original note: sec/word in Chinese).
        self.WAIT_TIME = 0.15
        self._accumulated_text = ""

    def _is_prefix_of_key(self, text: str) -> bool:
        """Return True if ``text`` is a prefix of start_key or end_key."""
        return any(
            key and key.startswith(text)
            for key in (self.start_key, self.end_key)
        )

    def _find_key_range(self, key: Optional[str], after_pos: int = 0) -> Optional[Tuple[int, int]]:
        """
        Locate ``key`` in the accumulated text at or after ``after_pos``.

        Returns:
            ``(start, end)`` offsets (end exclusive) of the first occurrence,
            or None when ``key`` is empty/None or does not occur.
        """
        if not key:
            return None
        idx = self._accumulated_text.find(key, after_pos)
        if idx == -1:
            return None
        return (idx, idx + len(key))

    def _chunk_overlaps_range(self, chunk_start: int, chunk_end: int, range_start: int, range_end: int) -> bool:
        """Return True if half-open span [chunk_start, chunk_end) intersects [range_start, range_end)."""
        return chunk_end > range_start and chunk_start < range_end

    def _search_for_keys(self, state: "ReleaseState", accumulated_len: int) -> None:
        """
        Record every complete start_key / end_key occurrence not yet recorded.

        Found ranges are appended to ``state.found_start_keys`` /
        ``state.found_end_keys`` and the matching search position is advanced
        past each hit, so repeated calls only scan new text.
        """
        while True:
            key_range = self._find_key_range(self.start_key, state.start_key_search_pos)
            if not key_range or key_range[1] > accumulated_len:
                break
            state.found_start_keys.append(key_range)
            state.start_key_search_pos = key_range[1]

        while True:
            key_range = self._find_key_range(self.end_key, state.end_key_search_pos)
            if not key_range or key_range[1] > accumulated_len:
                break
            state.found_end_keys.append(key_range)
            state.end_key_search_pos = key_range[1]

    def _find_potential_key_position(self, accumulated: str) -> int:
        """
        Find where a possibly incomplete key begins at the end of ``accumulated``.

        Only the last ``max_key_len - 1`` characters are inspected. The
        earliest suffix that is a prefix of start_key or end_key wins.

        Returns:
            The offset of that suffix, or -1 if no suffix could be the
            beginning of a key (including when no keys are configured).
        """
        max_key_len = max(len(self.start_key or ""), len(self.end_key or ""))
        total = len(accumulated)
        for search_start in range(max(0, total - max_key_len + 1), total):
            # The suffix is never empty here, so a match implies it shares the
            # key's first character; no separate first-character filter needed.
            if self._is_prefix_of_key(accumulated[search_start:]):
                return search_start
        return -1

    def _get_safe_end_pos(self, accumulated: str, producer_done: bool) -> int:
        """
        Return the offset up to which chunks may be released.

        While the producer is still running, text from a potential partial key
        onwards is withheld; once it is done everything is releasable.
        """
        if not producer_done:
            potential_key_pos = self._find_potential_key_position(accumulated)
            if potential_key_pos >= 0:
                return potential_key_pos
        return len(accumulated)

    def _update_delay_mode(self, state: "ReleaseState", y_start: int, y_end: int) -> None:
        """
        Update ``state.in_delay_mode`` for the chunk spanning [y_start, y_end).

        Delay mode turns on when any recorded start_key ends at or before the
        chunk's end, and then turns off when any recorded end_key ends at or
        before the chunk's end.

        NOTE(review): every recorded key is considered, not just the most
        recent one, so after the first end_key has been passed the mode ends
        up off for all later chunks even if another start_key follows.
        Confirm that single-cycle behavior is intended.
        """
        if not state.in_delay_mode and any(
            y_start >= key_end or y_start < key_end <= y_end
            for _, key_end in state.found_start_keys
        ):
            state.in_delay_mode = True

        if state.in_delay_mode and any(
            y_start >= key_end or y_start < key_end <= y_end
            for _, key_end in state.found_end_keys
        ):
            state.in_delay_mode = False

    def _should_skip_to_end_key(self, state: "ReleaseState", y_end: int) -> bool:
        """Return True if, in delay mode, the chunk ends at or before the start of a recorded end_key."""
        if not state.in_delay_mode:
            return False
        return any(y_end <= key_start for key_start, _ in state.found_end_keys)

    def _get_text_to_yield(self, y_start: int, y_end: int, state: "ReleaseState") -> Optional[str]:
        """
        Given a chunk's position range, return the text that should be yielded.

        Any part of the chunk that overlaps a recorded start_key or end_key
        range is cut out. Returns None if nothing is left after cutting; a
        chunk that touches no key is returned verbatim (possibly empty).
        """
        relevant_ranges = sorted(
            (
                r
                for r in state.found_start_keys + state.found_end_keys
                if self._chunk_overlaps_range(y_start, y_end, r[0], r[1])
            ),
            key=lambda r: r[0],
        )
        if not relevant_ranges:
            return self._accumulated_text[y_start:y_end]

        result_parts = []
        current_pos = y_start
        for key_start, key_end in relevant_ranges:
            # Keep the text between the previous key (or chunk start) and this key.
            text_end = min(key_start, y_end)
            if text_end > current_pos:
                result_parts.append(self._accumulated_text[current_pos:text_end])
            current_pos = max(current_pos, key_end)

        if current_pos < y_end:
            result_parts.append(self._accumulated_text[current_pos:y_end])
        return "".join(result_parts) if result_parts else None

    async def _drain(self, buffer: deque, state: "ReleaseState", producer_done: bool) -> AsyncIterator[str]:
        """
        Yield text for every buffered chunk that is currently safe to release.

        ``buffer`` holds ``(start, end)`` offsets of pending chunks in arrival
        order; released or skipped chunks are popped from its left end.
        """
        self._search_for_keys(state, len(self._accumulated_text))
        safe_end_pos = self._get_safe_end_pos(self._accumulated_text, producer_done)

        while buffer:
            y_start, y_end = buffer[0]
            if y_end > safe_end_pos:
                # This chunk may contain the beginning of a key; wait for more text.
                break
            buffer.popleft()

            self._update_delay_mode(state, y_start, y_end)
            if self._should_skip_to_end_key(state, y_end):
                continue

            text_to_yield = self._get_text_to_yield(y_start, y_end, state)
            if not text_to_yield:
                continue

            if state.in_delay_mode:
                # Throttled output: one character at a time with a pause after each.
                for char in text_to_yield:
                    yield char
                    await asyncio.sleep(self.WAIT_TIME)
            else:
                yield text_to_yield

    async def release(self, text_iterator: AsyncIterator[str]) -> AsyncIterator[str]:
        """
        Async version of release that works with async generators.

        Yields chunks from text_iterator with the following behavior:
        - Before start_key: yield chunks immediately (but hold back if potential key prefix)
        - After start_key (until end_key): yield character by character with WAIT_TIME delay
        - start_key and end_key are never yielded
        - When end_key is seen: skip all pending chunks and resume after end_key

        Each call starts from a clean slate, so one instance can be reused
        for several streams one after another.

        NOTE(review): the source iterator is not advanced while delayed
        characters are being emitted, so chunks only become "pending" when
        they are held back as a potential key prefix. Confirm whether
        read-ahead during the delay is required.

        Args:
            text_iterator: Async iterator producing text chunks.

        Yields:
            Text pieces with key markers removed.
        """
        # Drop text left over from a previous call; otherwise old key offsets
        # would be found again and applied to this stream.
        self._accumulated_text = ""
        buffer: deque = deque()  # (chunk_start_pos, chunk_end_pos) of pending chunks
        state = ReleaseState()

        async for chunk in text_iterator:
            start_pos = len(self._accumulated_text)
            self._accumulated_text += chunk
            buffer.append((start_pos, len(self._accumulated_text)))
            async for piece in self._drain(buffer, state, producer_done=False):
                yield piece

        # Producer finished: nothing more can complete a partial key, flush the rest.
        async for piece in self._drain(buffer, state, producer_done=True):
            yield piece