From ab5dda1f21b0996516b0d33b5bb4282c426fa84c Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 29 Dec 2025 21:59:11 +0800 Subject: [PATCH] support async --- .../fake_stream_server_dashscopy.py | 4 +- fastapi_server/server_dashscope.py | 60 ++++- fastapi_server/server_openai.py | 51 +++- fastapi_server/test_openai_client.py | 4 +- lang_agent/base.py | 12 +- lang_agent/components/text_releaser.py | 219 +++++++++++++++++- lang_agent/graphs/routing.py | 54 ++++- lang_agent/pipeline.py | 51 +++- 8 files changed, 429 insertions(+), 26 deletions(-) diff --git a/fastapi_server/fake_stream_server_dashscopy.py b/fastapi_server/fake_stream_server_dashscopy.py index 1ba3ca3..59bb656 100644 --- a/fastapi_server/fake_stream_server_dashscopy.py +++ b/fastapi_server/fake_stream_server_dashscopy.py @@ -128,7 +128,7 @@ async def application_responses( user_msg = last.get("content") if isinstance(last, dict) else str(last) # Invoke pipeline (non-stream) then stream-chunk it to the client - result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id) + result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) if not isinstance(result_text, str): result_text = str(result_text) @@ -206,7 +206,7 @@ async def application_completion( last = messages[-1] user_msg = last.get("content") if isinstance(last, dict) else str(last) - result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id) + result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) if not isinstance(result_text, str): result_text = str(result_text) diff --git a/fastapi_server/server_dashscope.py b/fastapi_server/server_dashscope.py index a304152..3b6582b 100644 --- a/fastapi_server/server_dashscope.py +++ b/fastapi_server/server_dashscope.py @@ -89,6 +89,46 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen yield f"data: {json.dumps(final)}\n\n" +async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"): + """ + Async version: Stream chunks from pipeline and format as SSE. + Accumulates text and sends incremental updates. + DashScope SDK expects accumulated text in each chunk (not deltas). 
+ """ + created_time = int(time.time()) + accumulated_text = "" + + async for chunk in chunk_generator: + if chunk: + accumulated_text += chunk + data = { + "request_id": response_id, + "code": 200, + "message": "OK", + "output": { + "text": accumulated_text, + "created": created_time, + "model": model, + }, + "is_end": False, + } + yield f"data: {json.dumps(data)}\n\n" + + # Final message with complete text + final = { + "request_id": response_id, + "code": 200, + "message": "OK", + "output": { + "text": accumulated_text, + "created": created_time, + "model": model, + }, + "is_end": True, + } + yield f"data: {json.dumps(final)}\n\n" + + @app.post("/v1/apps/{app_id}/sessions/{session_id}/responses") @app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses") async def application_responses( @@ -137,15 +177,15 @@ async def application_responses( response_id = f"appcmpl-{os.urandom(12).hex()}" if stream: - # Use actual streaming from pipeline - chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id) + # Use async streaming from pipeline + chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) return StreamingResponse( - sse_chunks_from_stream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name), + sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name), media_type="text/event-stream", ) - # Non-streaming: get full result - result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id) + # Non-streaming: get full result using async + result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) if not isinstance(result_text, str): result_text = str(result_text) @@ -217,15 +257,15 @@ async def application_completion( response_id = f"appcmpl-{os.urandom(12).hex()}" if stream: - # Use actual streaming from pipeline - chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id) + # Use async streaming from pipeline + chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) return StreamingResponse( - sse_chunks_from_stream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name), + sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name), media_type="text/event-stream", ) - # Non-streaming: get full result - result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id) + # Non-streaming: get full result using async + result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) if not isinstance(result_text, str): result_text = str(result_text) diff --git a/fastapi_server/server_openai.py b/fastapi_server/server_openai.py index b04c12b..1fa4c0d 100644 --- a/fastapi_server/server_openai.py +++ b/fastapi_server/server_openai.py @@ -91,6 +91,47 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create yield "data: [DONE]\n\n" +async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int): + """ + Async version: Stream chunks from pipeline and format as OpenAI SSE. 
+    """
+    async for chunk in chunk_generator:
+        if chunk:
+            data = {
+                "id": response_id,
+                "object": "chat.completion.chunk",
+                "created": created_time,
+                "model": model,
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {
+                            "content": chunk
+                        },
+                        "finish_reason": None
+                    }
+                ]
+            }
+            yield f"data: {json.dumps(data)}\n\n"
+
+    # Final message
+    final = {
+        "id": response_id,
+        "object": "chat.completion.chunk",
+        "created": created_time,
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {},
+                "finish_reason": "stop"
+            }
+        ]
+    }
+    yield f"data: {json.dumps(final)}\n\n"
+    yield "data: [DONE]\n\n"
+
+
 @app.post("/v1/chat/completions")
 async def chat_completions(request: Request):
     try:
@@ -121,15 +162,15 @@ async def chat_completions(request: Request):
     created_time = int(time.time())
 
     if stream:
-        # Use actual streaming from pipeline
-        chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id)
+        # Use async streaming from pipeline
+        chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
         return StreamingResponse(
-            sse_chunks_from_stream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
+            sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
             media_type="text/event-stream",
         )
 
-    # Non-streaming: get full result
-    result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
+    # Non-streaming: get full result using async
+    result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
     if not isinstance(result_text, str):
         result_text = str(result_text)
 
diff --git a/fastapi_server/test_openai_client.py b/fastapi_server/test_openai_client.py
index 25620b7..57cb809 100644
--- a/fastapi_server/test_openai_client.py
+++ b/fastapi_server/test_openai_client.py
@@ -113,13 +113,13 @@ def main():
     print(f"\nUsing base_url = {BASE_URL}\n")
 
     # Test both streaming and non-streaming
-    # streaming_result = test_streaming()
+    streaming_result = test_streaming()
     non_streaming_result = test_non_streaming()
 
     print("\n" + "="*60)
     print("SUMMARY")
     print("="*60)
-    # print(f"Streaming response length: {len(streaming_result)}")
+    print(f"Streaming response length: {len(streaming_result)}")
     print(f"Non-streaming response length: {len(non_streaming_result)}")
 
     print("\nBoth tests completed successfully!")
diff --git a/lang_agent/base.py b/lang_agent/base.py
index d862793..6bccc79 100644
--- a/lang_agent/base.py
+++ b/lang_agent/base.py
@@ -1,4 +1,4 @@
-from typing import List, Callable, Tuple, Dict
+from typing import List, Callable, Tuple, Dict, AsyncIterator
 from abc import ABC, abstractmethod
 from PIL import Image
 from io import BytesIO
@@ -26,6 +26,10 @@ class GraphBase(ABC):
     def invoke(self, *nargs, **kwargs):
         pass
 
+    async def ainvoke(self, *nargs, **kwargs):
+        """Async version of invoke. Subclasses should override for true async support."""
+        raise NotImplementedError("Subclasses must implement ainvoke for async support")
+
     def show_graph(self, ret_img:bool=False):
 
         #NOTE: just a useful tool for debugging; has zero useful functionality
@@ -59,4 +63,8 @@ class ToolNodeBase(GraphBase):
 
     @abstractmethod
     def invoke(self, inp)->Dict[str, List[BaseMessage]]:
-        pass
\ No newline at end of file
+        pass
+
+    async def ainvoke(self, inp)->Dict[str, List[BaseMessage]]:
+        """Async version of invoke. Subclasses should override for true async support."""
+        raise NotImplementedError("Subclasses must implement ainvoke for async support")
\ No newline at end of file
diff --git a/lang_agent/components/text_releaser.py b/lang_agent/components/text_releaser.py
index 1d10521..3203d14 100644
--- a/lang_agent/components/text_releaser.py
+++ b/lang_agent/components/text_releaser.py
@@ -1,7 +1,8 @@
 import time
+import asyncio
 import threading
 from dataclasses import dataclass, field
-from typing import Iterator, Any, List, Tuple, Optional
+from typing import Iterator, AsyncIterator, Any, List, Tuple, Optional
 from collections import deque
 
 
@@ -256,3 +257,219 @@ class TextReleaser:
         else:
             # Yield entire chunk immediately
             yield text_to_yield
+
+
+class AsyncTextReleaser:
+    """
+    Async version of TextReleaser for use with async generators.
+
+    Controls the release rate of streaming text chunks based on special key markers.
+    Uses asyncio instead of threading for non-blocking operation.
+    """
+    KEY_START_CHAR = "["
+
+    def __init__(self, start_key: str = None, end_key: str = None):
+        """
+        Initialize the AsyncTextReleaser.
+
+        Args:
+            start_key: Marker that triggers delayed yielding (e.g., "[CHATTY_OUT]")
+            end_key: Marker that ends delayed yielding and skips pending chunks (e.g., "[TOOL_OUT]")
+        """
+        self.start_key = start_key
+        self.end_key = end_key
+        self.WAIT_TIME = 0.15  # delay per released character (roughly one word in Chinese)
+        self._accumulated_text = ""
+
+    def _is_prefix_of_key(self, text: str) -> bool:
+        """Check if text is a prefix of start_key or end_key."""
+        if self.start_key and self.start_key.startswith(text):
+            return True
+        if self.end_key and self.end_key.startswith(text):
+            return True
+        return False
+
+    def _find_key_range(self, key: str, after_pos: int = 0) -> Optional[Tuple[int, int]]:
+        """Find the start and end position of a key in accumulated text after given position."""
+        if not key:
+            return None
+        idx = self._accumulated_text.find(key, after_pos)
+        if idx == -1:
+            return None
+        return (idx, idx + len(key))
+
+    def _chunk_overlaps_range(self, chunk_start: int, chunk_end: int, range_start: int, range_end: int) -> bool:
+        """Check if chunk overlaps with a given range."""
+        return not (chunk_end <= range_start or chunk_start >= range_end)
+
+    def _search_for_keys(self, state: ReleaseState, accumulated_len: int) -> None:
+        """Search for complete start_key and end_key occurrences in accumulated text."""
+        while True:
+            key_range = self._find_key_range(self.start_key, state.start_key_search_pos)
+            if key_range and key_range[1] <= accumulated_len:
+                state.found_start_keys.append(key_range)
+                state.start_key_search_pos = key_range[1]
+            else:
+                break
+
+        while True:
+            key_range = self._find_key_range(self.end_key, state.end_key_search_pos)
+            if key_range and key_range[1] <= accumulated_len:
+                state.found_end_keys.append(key_range)
+                state.end_key_search_pos = key_range[1]
+            else:
+                break
+
+    def _find_potential_key_position(self, accumulated: str) -> int:
+        """Find position of potential incomplete key at end of accumulated text."""
+        max_key_len = max(len(self.start_key or ""), len(self.end_key or ""))
+        for search_start in range(max(0, len(accumulated) - max_key_len + 1), len(accumulated)):
+            suffix = accumulated[search_start:]
+            if suffix.startswith(self.KEY_START_CHAR) and self._is_prefix_of_key(suffix):
+                return search_start
+        return -1
+
+    def _get_safe_end_pos(self, accumulated: str, producer_done: bool) -> int:
+        """Determine the safe position up to which we can yield chunks."""
+        potential_key_pos = self._find_potential_key_position(accumulated)
+        if potential_key_pos >= 0 and not producer_done:
+            return potential_key_pos
+        return len(accumulated)
+
+    def _update_delay_mode(self, state: ReleaseState, y_start: int, y_end: int) -> None:
+        """Update delay mode based on chunk position relative to keys."""
+        if not state.in_delay_mode and state.found_start_keys:
+            for sk_range in state.found_start_keys:
+                if y_start >= sk_range[1] or (y_start < sk_range[1] <= y_end):
+                    state.in_delay_mode = True
+                    break
+
+        if state.in_delay_mode and state.found_end_keys:
+            for ek_range in state.found_end_keys:
+                if y_start >= ek_range[1] or (y_start < ek_range[1] <= y_end):
+                    state.in_delay_mode = False
+                    break
+
+    def _should_skip_to_end_key(self, state: ReleaseState, y_end: int) -> bool:
+        """Check if chunk should be skipped because it's before an end_key in delay mode."""
+        if not state.in_delay_mode:
+            return False
+        for ek_range in state.found_end_keys:
+            if y_end <= ek_range[0]:
+                return True
+        return False
+
+    def _get_text_to_yield(self, y_start: int, y_end: int, state: ReleaseState) -> Optional[str]:
+        """
+        Given a chunk's position range, return the text that should be yielded.
+        Returns None if the entire chunk should be skipped.
+        """
+        all_key_ranges = state.found_start_keys + state.found_end_keys
+
+        relevant_ranges = sorted(
+            [r for r in all_key_ranges if self._chunk_overlaps_range(y_start, y_end, r[0], r[1])],
+            key=lambda x: x[0]
+        )
+
+        if not relevant_ranges:
+            return self._accumulated_text[y_start:y_end]
+
+        result_parts = []
+        current_pos = y_start
+
+        for key_start, key_end in relevant_ranges:
+            text_start = max(current_pos, y_start)
+            text_end = min(key_start, y_end)
+            if text_end > text_start:
+                result_parts.append(self._accumulated_text[text_start:text_end])
+            current_pos = max(current_pos, key_end)
+
+        if current_pos < y_end:
+            result_parts.append(self._accumulated_text[current_pos:y_end])
+
+        return "".join(result_parts) if result_parts else None
+
+    async def release(self, text_iterator: AsyncIterator[str]) -> AsyncIterator[str]:
+        """
+        Async version of release that works with async generators.
+
+        Yields chunks from text_iterator with the following behavior:
+        - Before start_key: yield chunks immediately (but hold back if potential key prefix)
+        - After start_key (until end_key): yield with WAIT_TIME delay
+        - start_key and end_key are never yielded
+        - When end_key is seen: skip all pending chunks and resume after end_key
+        """
+        buffer: deque = deque()  # stores (chunk, chunk_start_pos, chunk_end_pos)
+        state = ReleaseState()
+        producer_done = False
+
+        async def consume_and_process():
+            nonlocal producer_done
+
+            async for chunk in text_iterator:
+                start_pos = len(self._accumulated_text)
+                self._accumulated_text += chunk
+                end_pos = len(self._accumulated_text)
+                buffer.append((chunk, start_pos, end_pos))
+
+                # Process available chunks
+                self._search_for_keys(state, len(self._accumulated_text))
+                safe_end_pos = self._get_safe_end_pos(self._accumulated_text, False)
+
+                while state.yield_idx < len(buffer):
+                    chunk_at_yield = buffer[state.yield_idx]
+                    y_chunk, y_start, y_end = chunk_at_yield
+
+                    if y_end > safe_end_pos:
+                        break
+
+                    self._update_delay_mode(state, y_start, y_end)
+
+                    if self._should_skip_to_end_key(state, y_end):
+                        state.yield_idx += 1
+                        continue
+
+                    state.yield_idx += 1
+                    text_to_yield = self._get_text_to_yield(y_start, y_end, state)
+
+                    if not text_to_yield:
+                        continue
+
+                    if state.in_delay_mode:
+                        for char in text_to_yield:
+                            yield char
+                            await asyncio.sleep(self.WAIT_TIME)
+                    else:
+                        yield text_to_yield
+
+            producer_done = True
+
+            # Process remaining chunks after producer is done
+            self._search_for_keys(state, len(self._accumulated_text))
+            safe_end_pos = self._get_safe_end_pos(self._accumulated_text, True)
+
+            while state.yield_idx < len(buffer):
+                chunk_at_yield = buffer[state.yield_idx]
+                y_chunk, y_start, y_end = chunk_at_yield
+
+                self._update_delay_mode(state, y_start, y_end)
+
+                if self._should_skip_to_end_key(state, y_end):
+                    state.yield_idx += 1
+                    continue
+
+                state.yield_idx += 1
+                text_to_yield = self._get_text_to_yield(y_start, y_end, state)
+
+                if not text_to_yield:
+                    continue
+
+                if state.in_delay_mode:
+                    for char in text_to_yield:
+                        yield char
+                        await asyncio.sleep(self.WAIT_TIME)
+                else:
+                    yield text_to_yield
+
+        async for chunk in consume_and_process():
+            yield chunk
diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py
index e53fd5d..10e2cfa 100644
--- a/lang_agent/graphs/routing.py
+++ b/lang_agent/graphs/routing.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field, is_dataclass
-from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
+from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
 import tyro
 from pydantic import BaseModel, Field
 from loguru import logger
@@ -14,7 +14,7 @@ from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
 from lang_agent.base import GraphBase, ToolNodeBase
 from lang_agent.graphs.graph_states import State
 from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
-from lang_agent.components.text_releaser import TextReleaser
+from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
 
 from langchain.chat_models import init_chat_model
 from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
@@ -110,7 +110,55 @@ class RoutingGraph(GraphBase):
         if as_raw:
             return msg_list
 
-        return msg_list[-1].content
+        return msg_list[-1].content
+
+    async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
+        """Async version of invoke using LangGraph's native async support."""
+        self._validate_input(*nargs, **kwargs)
+
+        if as_stream:
+            # Stream messages from the workflow asynchronously
+            print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
+            return self._astream_result(*nargs, **kwargs)
+        else:
+            state = await self.workflow.ainvoke({"inp": nargs})
+
+            msg_list = jax.tree.leaves(state)
+
+            for e in msg_list:
+                if isinstance(e, BaseMessage):
+                    e.pretty_print()
+
+            if as_raw:
+                return msg_list
+
+            return msg_list[-1].content
+
+    async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
+        """Async streaming using LangGraph's astream method."""
+        streamable_tags = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
+
+        async def text_iterator():
+            async for chunk, metadata in self.workflow.astream(
+                {"inp": nargs},
+                stream_mode="messages",
+                subgraphs=True,
+                **kwargs
+            ):
+                if isinstance(metadata, tuple):
+                    chunk, metadata = metadata
+
+                tags = metadata.get("tags")
+                if tags not in streamable_tags:
+                    continue
+
+                if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+                    yield chunk.content
+
+        text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
+        async for chunk in text_releaser.release(text_iterator()):
+            print(f"\033[92m{chunk}\033[0m", end="", flush=True)
+            yield chunk
 
     def _validate_input(self, *nargs, **kwargs):
         print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py
index 6a48a1e..3401e21 100644
--- a/lang_agent/pipeline.py
+++ b/lang_agent/pipeline.py
@@ -140,6 +140,11 @@ class Pipeline:
         for chunk in out:
             yield chunk
 
+    async def _astream_res(self, out):
+        """Async version of _stream_res for async generators."""
+        async for chunk in out:
+            yield chunk
+
     def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None):
         """
         as_stream (bool): if true, enable the thing to be streamable
@@ -158,4 +163,48 @@ class Pipeline:
             # Yield chunks from the generator
             return self._stream_res(out)
         else:
-            return out
\ No newline at end of file
+            return out
+
+    async def ainvoke(self, *nargs, **kwargs):
+        """Async version of invoke using LangGraph's native async support."""
+        out = await self.graph.ainvoke(*nargs, **kwargs)
+
+        # If streaming, return async generator
+        if kwargs.get("as_stream"):
+            return self._astream_res(out)
+
+        # Non-streaming path
+        if kwargs.get("as_raw"):
+            return out
+
+        if isinstance(out, (SystemMessage, HumanMessage)):
+            return out.content
+
+        if isinstance(out, list):
+            return out[-1].content
+
+        if isinstance(out, str):
+            return out
+
+        raise TypeError(f"unexpected graph output type: {type(out)}")
+
+    async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None):
+        """
+        Async version of chat using LangGraph's native async support.
+
+        as_stream (bool): if True, return an async generator that streams text chunks
+        as_raw (bool): return the full dialogue as List[SystemMessage, HumanMessage, ToolMessage]
+        """
+        # NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
+        u = DEFAULT_PROMPT
+
+        thread_id = thread_id if thread_id is not None else 3
+        inp_data = {"messages":[SystemMessage(u),
+                    HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id}}
+
+        if as_stream:
+            # Return async generator for streaming
+            out = await self.ainvoke(*inp_data, as_stream=True, as_raw=as_raw)
+            return self._astream_res(out)
+        else:
+            return await self.ainvoke(*inp_data, as_stream=False, as_raw=as_raw)
\ No newline at end of file
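---

Usage note (not part of the patch): a minimal async client sketch for exercising the new streaming path end to end. It assumes the OpenAI-compatible server from fastapi_server/server_openai.py is running on localhost:8000 and that httpx is installed; the URL, port, and model name are illustrative. It parses the SSE chunks exactly as sse_chunks_from_astream emits them (one "data: {...}" line per chunk, terminated by "data: [DONE]").

import asyncio
import json

import httpx


async def main():
    payload = {
        "model": "qwen-flash",
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
    }
    # timeout=None: the stream stays open for as long as the pipeline yields
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/v1/chat/completions", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                # Final chunk carries an empty delta, so default to ""
                delta = chunk["choices"][0]["delta"].get("content", "")
                print(delta, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())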