510 lines
21 KiB
Python
510 lines
21 KiB
Python
import time
|
|
import asyncio
|
|
import threading
|
|
from dataclasses import dataclass, field
|
|
from typing import Iterator, AsyncIterator, Any, List, Tuple, Optional
|
|
from collections import deque
|
|
|
|
|
|
@dataclass
class ReleaseState:
    """Mutable bookkeeping shared by the steps of a single release run."""

    # Index of the next buffer entry the consumer will pick up.
    read_idx: int = 0
    # Index of the next buffer entry to be considered for yielding.
    yield_idx: int = 0
    # True while text is being released in delay mode.
    in_delay_mode: bool = False
    # Offsets in the accumulated text from which the next key search starts.
    start_key_search_pos: int = 0
    end_key_search_pos: int = 0
    # (start, end) offsets of every complete key occurrence found so far.
    found_start_keys: List[Tuple[int, int]] = field(default_factory=list)
    found_end_keys: List[Tuple[int, int]] = field(default_factory=list)
|
|
|
|
|
|
class TextReleaser:
    """
    Controls the release rate of streaming text chunks based on special key markers.

    This class consumes a text stream and yields chunks with controlled timing:
    - Before start_key: yields chunks immediately as they arrive
    - After start_key: yields the text one character at a time, sleeping
      WAIT_TIME seconds after each character
    - When end_key appears: skips pending delayed chunks that end before it
      and resumes immediate yielding after end_key
    - Both start_key and end_key are filtered out and never yielded
    - Non-string items are passed through unchanged, in stream order

    Keys can span multiple chunks (e.g., "[CHAT", "TY_OUT]") and are detected across boundaries.
    Keys always start with "[" character, allowing the class to buffer potential key prefixes.

    Example flow:
        Input:  ["Hello ", "[CHATTY_OUT]", "Thinking...", "[TOOL_OUT]", " Result: 42"]
        Output: ["Hello "] -> delay -> ["Thinking..."] -> skip to -> [" Result: 42"]
    """

    KEY_START_CHAR = "["

    def __init__(self, start_key: Optional[str] = None, end_key: Optional[str] = None,
                 wait_time: float = 0.15):
        """
        Initialize the TextReleaser.

        Args:
            start_key: Marker that triggers delayed yielding (e.g., "[CHATTY_OUT]")
            end_key: Marker that ends delayed yielding and skips pending chunks (e.g., "[TOOL_OUT]")
            wait_time: Seconds to sleep after each character yielded in delay mode
        """
        self.start_key = start_key
        self.end_key = end_key
        self.WAIT_TIME = wait_time  # seconds slept after each delayed character

        # Internal state for producer-consumer pattern
        self._buffer: deque = deque()  # stores (chunk, chunk_start_pos, chunk_end_pos)
        self._lock = threading.Lock()
        self._producer_done = threading.Event()
        self._accumulated_text = ""  # full text accumulated so far for key detection
        self._producer_error: Optional[BaseException] = None  # raised by the source iterator, if any

    def _is_prefix_of_key(self, text: str) -> bool:
        """Check if text is a prefix of start_key or end_key."""
        return any(key and key.startswith(text) for key in (self.start_key, self.end_key))

    def _find_key_range(self, key: str, after_pos: int = 0) -> Optional[Tuple[int, int]]:
        """Find the (start, end) offsets of key in the accumulated text at or after after_pos."""
        if not key:
            return None
        idx = self._accumulated_text.find(key, after_pos)
        if idx == -1:
            return None
        return (idx, idx + len(key))

    def _producer(self, text_iterator: Iterator[Any]):
        """
        Consume text_iterator and store chunks into the buffer as they arrive.

        String chunks are appended to the accumulated text and stored with their
        offsets; any other item is stored with (None, None) offsets. The done
        event is always set, even if the iterator raises, so the consumer cannot
        wait forever; the exception is kept for release() to re-raise.
        """
        try:
            for chunk in text_iterator:
                with self._lock:
                    if isinstance(chunk, str):
                        start_pos = len(self._accumulated_text)
                        self._accumulated_text += chunk
                        self._buffer.append((chunk, start_pos, len(self._accumulated_text)))
                    else:
                        self._buffer.append((chunk, None, None))
        except Exception as exc:
            self._producer_error = exc
        finally:
            self._producer_done.set()

    def _chunk_overlaps_range(self, chunk_start: int, chunk_end: int, range_start: int, range_end: int) -> bool:
        """Check if the half-open chunk span overlaps the half-open range."""
        return chunk_start < range_end and range_start < chunk_end

    def _search_for_keys(self, state: "ReleaseState", accumulated_len: int) -> None:
        """Record every complete start_key/end_key occurrence that ends within accumulated_len."""
        searches = (
            (self.start_key, state.found_start_keys, "start_key_search_pos"),
            (self.end_key, state.found_end_keys, "end_key_search_pos"),
        )
        for key, found, pos_attr in searches:
            while True:
                key_range = self._find_key_range(key, getattr(state, pos_attr))
                if key_range is None or key_range[1] > accumulated_len:
                    break
                found.append(key_range)
                # Resume the next search right after this occurrence.
                setattr(state, pos_attr, key_range[1])

    def _find_potential_key_position(self, accumulated: str) -> int:
        """Find position of potential incomplete key at end of accumulated text. Returns -1 if none."""
        max_key_len = max(len(self.start_key or ""), len(self.end_key or ""))
        total = len(accumulated)
        # Only suffixes shorter than the longest key are examined.
        for search_start in range(max(0, total - max_key_len + 1), total):
            suffix = accumulated[search_start:]
            if suffix.startswith(self.KEY_START_CHAR) and self._is_prefix_of_key(suffix):
                return search_start
        return -1

    def _get_safe_end_pos(self, accumulated: str, producer_done: bool) -> int:
        """Determine the safe position up to which we can yield chunks."""
        potential_key_pos = self._find_potential_key_position(accumulated)
        if potential_key_pos >= 0 and not producer_done:
            return potential_key_pos
        return len(accumulated)

    def _update_delay_mode(self, state: "ReleaseState", y_start: int, y_end: int) -> None:
        """
        Update delay mode from the most recent key that ends at or before y_end.

        Delay mode is on when the latest such key is a start_key and off when it
        is an end_key, so a start_key that follows an earlier end_key re-enables
        delay. If no key ends at or before y_end the mode is left unchanged.
        y_start is accepted for interface compatibility; only y_end is needed.
        """
        last_start = max((end for _, end in state.found_start_keys if end <= y_end), default=-1)
        last_end = max((end for _, end in state.found_end_keys if end <= y_end), default=-1)
        if last_start < 0 and last_end < 0:
            return
        state.in_delay_mode = last_start > last_end

    def _should_skip_to_end_key(self, state: "ReleaseState", y_end: int) -> bool:
        """Check if chunk should be skipped because it's before an end_key in delay mode."""
        if not state.in_delay_mode:
            return False
        return any(y_end <= key_start for key_start, _ in state.found_end_keys)

    def _get_text_to_yield(self, y_start: int, y_end: int, state: "ReleaseState") -> Optional[str]:
        """
        Given a chunk's position range, return the text that should be yielded.

        Returns None if the entire chunk should be skipped.
        Handles partial overlaps with keys by extracting non-key portions.
        """
        overlapping = sorted(
            (r for r in state.found_start_keys + state.found_end_keys
             if self._chunk_overlaps_range(y_start, y_end, r[0], r[1])),
            key=lambda r: r[0],
        )
        if not overlapping:
            return self._accumulated_text[y_start:y_end]

        # Collect the stretches of the chunk that lie between key occurrences.
        parts = []
        cursor = y_start
        for key_start, key_end in overlapping:
            if key_start > cursor:
                parts.append(self._accumulated_text[cursor:min(key_start, y_end)])
            cursor = max(cursor, key_end)
        if cursor < y_end:
            parts.append(self._accumulated_text[cursor:y_end])
        return "".join(parts) if parts else None

    def _try_get_next_chunk(self, state: "ReleaseState") -> Tuple[Optional[Tuple[Any, Optional[int], Optional[int]]], str, bool]:
        """Try to get the next chunk from buffer. Returns (chunk_data, accumulated_text, producer_done)."""
        with self._lock:
            chunk_data = None
            if state.read_idx < len(self._buffer):
                chunk_data = self._buffer[state.read_idx]
            return chunk_data, self._accumulated_text, self._producer_done.is_set()

    def _get_chunk_at_yield_idx(self, state: "ReleaseState") -> Optional[Tuple[Any, Optional[int], Optional[int]]]:
        """Get chunk data at current yield index."""
        with self._lock:
            if state.yield_idx < len(self._buffer):
                return self._buffer[state.yield_idx]
            return None

    def _release_ready(self, state: "ReleaseState", safe_end_pos: int, producer_done: bool) -> Iterator[Any]:
        """
        Yield every already-read buffer entry that is safe to release, in order.

        Stops at the first string chunk that extends past safe_end_pos while the
        producer is still running, because it may contain the beginning of a key.
        Non-string entries are yielded unchanged once everything before them has
        been released.
        """
        while state.yield_idx < state.read_idx:
            entry = self._get_chunk_at_yield_idx(state)
            if entry is None:
                break
            item, y_start, y_end = entry

            if y_start is None:  # Non-string item: pass through as-is
                state.yield_idx += 1
                yield item
                continue

            if y_end > safe_end_pos and not producer_done:
                break

            self._update_delay_mode(state, y_start, y_end)
            state.yield_idx += 1
            if self._should_skip_to_end_key(state, y_end):
                continue

            text_to_yield = self._get_text_to_yield(y_start, y_end, state)
            if not text_to_yield:
                continue

            if state.in_delay_mode:
                # Yield character by character with delay
                for char in text_to_yield:
                    yield char
                    time.sleep(self.WAIT_TIME)
            else:
                # Yield entire chunk immediately
                yield text_to_yield

    def release(self, text_iterator: Iterator[Any]) -> Iterator[Any]:
        """
        Yields chunks from text_iterator with the following behavior:
        - Before start_key: yield chunks immediately (but hold back if potential key prefix)
        - After start_key (until end_key): yield with WAIT_TIME delay
        - start_key and end_key are never yielded, but text around them in same chunk is yielded
        - When end_key is seen: skip all pending chunks and resume after end_key
        - Keys can span multiple chunks, chunks are held until key is confirmed or ruled out
        - Non-string items are yielded unchanged, in their original position

        The iterator is consumed on a daemon thread. Once it is exhausted, any
        chunks still held back are flushed. If the iterator raised, that
        exception is re-raised here after the flush.
        """
        # Reset instance state for safe reuse
        self._buffer.clear()
        self._producer_done.clear()
        self._accumulated_text = ""
        self._producer_error = None

        producer_thread = threading.Thread(target=self._producer, args=(text_iterator,), daemon=True)
        producer_thread.start()

        state = ReleaseState()

        while True:
            chunk_data, accumulated, producer_done = self._try_get_next_chunk(state)

            if chunk_data is None:
                # The done flag was read under the same lock as the buffer, and
                # it is only set after the last append, so nothing more can arrive.
                if producer_done:
                    break
                time.sleep(0.01)
                continue

            state.read_idx += 1
            self._search_for_keys(state, len(accumulated))
            safe_end_pos = self._get_safe_end_pos(accumulated, producer_done)
            yield from self._release_ready(state, safe_end_pos, producer_done)

        # Final flush: chunks held back for a possible key prefix must still be
        # released when the stream ends without completing that key.
        with self._lock:
            accumulated = self._accumulated_text
        self._search_for_keys(state, len(accumulated))
        yield from self._release_ready(state, len(accumulated), True)

        if self._producer_error is not None:
            raise self._producer_error
|
|
|
|
|
|
class AsyncTextReleaser:
    """
    Async version of TextReleaser for use with async generators.

    Controls the release rate of streaming text chunks based on special key markers.
    Uses asyncio instead of threading for non-blocking operation.
    """

    KEY_START_CHAR = "["

    def __init__(self, start_key: Optional[str] = None, end_key: Optional[str] = None,
                 wait_time: float = 0.15):
        """
        Initialize the AsyncTextReleaser.

        Args:
            start_key: Marker that triggers delayed yielding (e.g., "[CHATTY_OUT]")
            end_key: Marker that ends delayed yielding and skips pending chunks (e.g., "[TOOL_OUT]")
            wait_time: Seconds to sleep after each character yielded in delay mode
        """
        self.start_key = start_key
        self.end_key = end_key
        self.WAIT_TIME = wait_time  # seconds slept after each delayed character
        self._accumulated_text = ""  # full text accumulated so far for key detection

    def _is_prefix_of_key(self, text: str) -> bool:
        """Check if text is a prefix of start_key or end_key."""
        return any(key and key.startswith(text) for key in (self.start_key, self.end_key))

    def _find_key_range(self, key: str, after_pos: int = 0) -> Optional[Tuple[int, int]]:
        """Find the (start, end) offsets of key in the accumulated text at or after after_pos."""
        if not key:
            return None
        idx = self._accumulated_text.find(key, after_pos)
        if idx == -1:
            return None
        return (idx, idx + len(key))

    def _chunk_overlaps_range(self, chunk_start: int, chunk_end: int, range_start: int, range_end: int) -> bool:
        """Check if the half-open chunk span overlaps the half-open range."""
        return chunk_start < range_end and range_start < chunk_end

    def _search_for_keys(self, state: "ReleaseState", accumulated_len: int) -> None:
        """Record every complete start_key/end_key occurrence that ends within accumulated_len."""
        searches = (
            (self.start_key, state.found_start_keys, "start_key_search_pos"),
            (self.end_key, state.found_end_keys, "end_key_search_pos"),
        )
        for key, found, pos_attr in searches:
            while True:
                key_range = self._find_key_range(key, getattr(state, pos_attr))
                if key_range is None or key_range[1] > accumulated_len:
                    break
                found.append(key_range)
                # Resume the next search right after this occurrence.
                setattr(state, pos_attr, key_range[1])

    def _find_potential_key_position(self, accumulated: str) -> int:
        """Find position of potential incomplete key at end of accumulated text. Returns -1 if none."""
        max_key_len = max(len(self.start_key or ""), len(self.end_key or ""))
        total = len(accumulated)
        # Only suffixes shorter than the longest key are examined.
        for search_start in range(max(0, total - max_key_len + 1), total):
            suffix = accumulated[search_start:]
            if suffix.startswith(self.KEY_START_CHAR) and self._is_prefix_of_key(suffix):
                return search_start
        return -1

    def _get_safe_end_pos(self, accumulated: str, producer_done: bool) -> int:
        """Determine the safe position up to which we can yield chunks."""
        potential_key_pos = self._find_potential_key_position(accumulated)
        if potential_key_pos >= 0 and not producer_done:
            return potential_key_pos
        return len(accumulated)

    def _update_delay_mode(self, state: "ReleaseState", y_start: int, y_end: int) -> None:
        """
        Update delay mode from the most recent key that ends at or before y_end.

        Delay mode is on when the latest such key is a start_key and off when it
        is an end_key, so a start_key that follows an earlier end_key re-enables
        delay. If no key ends at or before y_end the mode is left unchanged.
        y_start is accepted for interface compatibility; only y_end is needed.
        """
        last_start = max((end for _, end in state.found_start_keys if end <= y_end), default=-1)
        last_end = max((end for _, end in state.found_end_keys if end <= y_end), default=-1)
        if last_start < 0 and last_end < 0:
            return
        state.in_delay_mode = last_start > last_end

    def _should_skip_to_end_key(self, state: "ReleaseState", y_end: int) -> bool:
        """Check if chunk should be skipped because it's before an end_key in delay mode."""
        if not state.in_delay_mode:
            return False
        return any(y_end <= key_start for key_start, _ in state.found_end_keys)

    def _get_text_to_yield(self, y_start: int, y_end: int, state: "ReleaseState") -> Optional[str]:
        """
        Given a chunk's position range, return the text that should be yielded.

        Returns None if the entire chunk should be skipped.
        Handles partial overlaps with keys by extracting non-key portions.
        """
        overlapping = sorted(
            (r for r in state.found_start_keys + state.found_end_keys
             if self._chunk_overlaps_range(y_start, y_end, r[0], r[1])),
            key=lambda r: r[0],
        )
        if not overlapping:
            return self._accumulated_text[y_start:y_end]

        # Collect the stretches of the chunk that lie between key occurrences.
        parts = []
        cursor = y_start
        for key_start, key_end in overlapping:
            if key_start > cursor:
                parts.append(self._accumulated_text[cursor:min(key_start, y_end)])
            cursor = max(cursor, key_end)
        if cursor < y_end:
            parts.append(self._accumulated_text[cursor:y_end])
        return "".join(parts) if parts else None

    async def _release_ready(self, buffer: deque, state: "ReleaseState", safe_end_pos: int) -> AsyncIterator[Any]:
        """
        Yield every buffered entry that is safe to release, in order.

        Stops at the first string chunk that extends past safe_end_pos, because
        it may contain the beginning of a key. Non-string entries are yielded
        unchanged once everything before them has been released.
        """
        while state.yield_idx < len(buffer):
            item, y_start, y_end = buffer[state.yield_idx]

            if y_start is None:  # Non-string item: pass through as-is
                state.yield_idx += 1
                yield item
                continue

            if y_end > safe_end_pos:
                break

            self._update_delay_mode(state, y_start, y_end)
            state.yield_idx += 1
            if self._should_skip_to_end_key(state, y_end):
                continue

            text_to_yield = self._get_text_to_yield(y_start, y_end, state)
            if not text_to_yield:
                continue

            if state.in_delay_mode:
                # Yield character by character with a non-blocking delay
                for char in text_to_yield:
                    yield char
                    await asyncio.sleep(self.WAIT_TIME)
            else:
                yield text_to_yield

    async def release(self, text_iterator: AsyncIterator[Any]) -> AsyncIterator[Any]:
        """
        Async version of release that works with async generators.

        Yields chunks from text_iterator with the following behavior:
        - Before start_key: yield chunks immediately (but hold back if potential key prefix)
        - After start_key (until end_key): yield with WAIT_TIME delay
        - start_key and end_key are never yielded
        - When end_key is seen: skip all pending chunks and resume after end_key
        - Non-string items are yielded unchanged, in their original position

        Once text_iterator is exhausted, chunks still held back for a possible
        key prefix are flushed.
        """
        # Reset instance state for safe reuse
        self._accumulated_text = ""

        buffer: deque = deque()  # stores (chunk, chunk_start_pos, chunk_end_pos)
        state = ReleaseState()

        async for chunk in text_iterator:
            if isinstance(chunk, str):
                start_pos = len(self._accumulated_text)
                self._accumulated_text += chunk
                buffer.append((chunk, start_pos, len(self._accumulated_text)))
            else:
                buffer.append((chunk, None, None))

            # Release whatever is now known not to be part of an unfinished key.
            self._search_for_keys(state, len(self._accumulated_text))
            safe_end_pos = self._get_safe_end_pos(self._accumulated_text, False)
            async for piece in self._release_ready(buffer, state, safe_end_pos):
                yield piece

        # Source exhausted: nothing more can complete a key, so flush the rest.
        self._search_for_keys(state, len(self._accumulated_text))
        safe_end_pos = self._get_safe_end_pos(self._accumulated_text, True)
        async for piece in self._release_ready(buffer, state, safe_end_pos):
            yield piece
|