support async
This commit is contained in:
@@ -128,7 +128,7 @@ async def application_responses(
|
||||
user_msg = last.get("content") if isinstance(last, dict) else str(last)
|
||||
|
||||
# Invoke pipeline (non-stream) then stream-chunk it to the client
|
||||
result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
@@ -206,7 +206,7 @@ async def application_completion(
|
||||
last = messages[-1]
|
||||
user_msg = last.get("content") if isinstance(last, dict) else str(last)
|
||||
|
||||
result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
|
||||
@@ -89,6 +89,46 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
|
||||
|
||||
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"):
|
||||
"""
|
||||
Async version: Stream chunks from pipeline and format as SSE.
|
||||
Accumulates text and sends incremental updates.
|
||||
DashScope SDK expects accumulated text in each chunk (not deltas).
|
||||
"""
|
||||
created_time = int(time.time())
|
||||
accumulated_text = ""
|
||||
|
||||
async for chunk in chunk_generator:
|
||||
if chunk:
|
||||
accumulated_text += chunk
|
||||
data = {
|
||||
"request_id": response_id,
|
||||
"code": 200,
|
||||
"message": "OK",
|
||||
"output": {
|
||||
"text": accumulated_text,
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
},
|
||||
"is_end": False,
|
||||
}
|
||||
yield f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
# Final message with complete text
|
||||
final = {
|
||||
"request_id": response_id,
|
||||
"code": 200,
|
||||
"message": "OK",
|
||||
"output": {
|
||||
"text": accumulated_text,
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
},
|
||||
"is_end": True,
|
||||
}
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
|
||||
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
|
||||
async def application_responses(
|
||||
@@ -137,15 +177,15 @@ async def application_responses(
|
||||
response_id = f"appcmpl-{os.urandom(12).hex()}"
|
||||
|
||||
if stream:
|
||||
# Use actual streaming from pipeline
|
||||
chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
# Use async streaming from pipeline
|
||||
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
return StreamingResponse(
|
||||
sse_chunks_from_stream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
|
||||
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming: get full result
|
||||
result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
# Non-streaming: get full result using async
|
||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
@@ -217,15 +257,15 @@ async def application_completion(
|
||||
response_id = f"appcmpl-{os.urandom(12).hex()}"
|
||||
|
||||
if stream:
|
||||
# Use actual streaming from pipeline
|
||||
chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
# Use async streaming from pipeline
|
||||
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
return StreamingResponse(
|
||||
sse_chunks_from_stream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
|
||||
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming: get full result
|
||||
result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
# Non-streaming: get full result using async
|
||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
|
||||
@@ -91,6 +91,47 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int):
|
||||
"""
|
||||
Async version: Stream chunks from pipeline and format as OpenAI SSE.
|
||||
"""
|
||||
async for chunk in chunk_generator:
|
||||
if chunk:
|
||||
data = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": chunk
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
# Final message
|
||||
final = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
try:
|
||||
@@ -121,15 +162,15 @@ async def chat_completions(request: Request):
|
||||
created_time = int(time.time())
|
||||
|
||||
if stream:
|
||||
# Use actual streaming from pipeline
|
||||
chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
# Use async streaming from pipeline
|
||||
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
return StreamingResponse(
|
||||
sse_chunks_from_stream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
|
||||
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming: get full result
|
||||
result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
# Non-streaming: get full result using async
|
||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
|
||||
@@ -113,13 +113,13 @@ def main():
|
||||
print(f"\nUsing base_url = {BASE_URL}\n")
|
||||
|
||||
# Test both streaming and non-streaming
|
||||
# streaming_result = test_streaming()
|
||||
streaming_result = test_streaming()
|
||||
non_streaming_result = test_non_streaming()
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("SUMMARY")
|
||||
print("="*60)
|
||||
# print(f"Streaming response length: {len(streaming_result)}")
|
||||
print(f"Streaming response length: {len(streaming_result)}")
|
||||
print(f"Non-streaming response length: {len(non_streaming_result)}")
|
||||
print("\nBoth tests completed successfully!")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Callable, Tuple, Dict
|
||||
from typing import List, Callable, Tuple, Dict, AsyncIterator
|
||||
from abc import ABC, abstractmethod
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
@@ -26,6 +26,10 @@ class GraphBase(ABC):
|
||||
def invoke(self, *nargs, **kwargs):
|
||||
pass
|
||||
|
||||
async def ainvoke(self, *nargs, **kwargs):
|
||||
"""Async version of invoke. Subclasses should override for true async support."""
|
||||
raise NotImplementedError("Subclass should implement ainvoke for async support")
|
||||
|
||||
def show_graph(self, ret_img:bool=False):
|
||||
#NOTE: just a useful tool for debugging; has zero useful functionality
|
||||
|
||||
@@ -59,4 +63,8 @@ class ToolNodeBase(GraphBase):
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, inp)->Dict[str, List[BaseMessage]]:
|
||||
pass
|
||||
pass
|
||||
|
||||
async def ainvoke(self, inp)->Dict[str, List[BaseMessage]]:
|
||||
"""Async version of invoke. Subclasses should override for true async support."""
|
||||
raise NotImplementedError("Subclass should implement ainvoke for async support")
|
||||
@@ -1,7 +1,8 @@
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator, Any, List, Tuple, Optional
|
||||
from typing import Iterator, AsyncIterator, Any, List, Tuple, Optional
|
||||
from collections import deque
|
||||
|
||||
|
||||
@@ -256,3 +257,219 @@ class TextReleaser:
|
||||
else:
|
||||
# Yield entire chunk immediately
|
||||
yield text_to_yield
|
||||
|
||||
|
||||
class AsyncTextReleaser:
|
||||
"""
|
||||
Async version of TextReleaser for use with async generators.
|
||||
|
||||
Controls the release rate of streaming text chunks based on special key markers.
|
||||
Uses asyncio instead of threading for non-blocking operation.
|
||||
"""
|
||||
KEY_START_CHAR = "["
|
||||
|
||||
def __init__(self, start_key: str = None, end_key: str = None):
|
||||
"""
|
||||
Initialize the AsyncTextReleaser.
|
||||
|
||||
Args:
|
||||
start_key: Marker that triggers delayed yielding (e.g., "[CHATTY_OUT]")
|
||||
end_key: Marker that ends delayed yielding and skips pending chunks (e.g., "[TOOL_OUT]")
|
||||
"""
|
||||
self.start_key = start_key
|
||||
self.end_key = end_key
|
||||
self.WAIT_TIME = 0.15 # sec/word in chinese
|
||||
self._accumulated_text = ""
|
||||
|
||||
def _is_prefix_of_key(self, text: str) -> bool:
|
||||
"""Check if text is a prefix of start_key or end_key."""
|
||||
if self.start_key and self.start_key.startswith(text):
|
||||
return True
|
||||
if self.end_key and self.end_key.startswith(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _find_key_range(self, key: str, after_pos: int = 0) -> Optional[Tuple[int, int]]:
|
||||
"""Find the start and end position of a key in accumulated text after given position."""
|
||||
if not key:
|
||||
return None
|
||||
idx = self._accumulated_text.find(key, after_pos)
|
||||
if idx == -1:
|
||||
return None
|
||||
return (idx, idx + len(key))
|
||||
|
||||
def _chunk_overlaps_range(self, chunk_start: int, chunk_end: int, range_start: int, range_end: int) -> bool:
|
||||
"""Check if chunk overlaps with a given range."""
|
||||
return not (chunk_end <= range_start or chunk_start >= range_end)
|
||||
|
||||
def _search_for_keys(self, state: ReleaseState, accumulated_len: int) -> None:
|
||||
"""Search for complete start_key and end_key occurrences in accumulated text."""
|
||||
while True:
|
||||
key_range = self._find_key_range(self.start_key, state.start_key_search_pos)
|
||||
if key_range and key_range[1] <= accumulated_len:
|
||||
state.found_start_keys.append(key_range)
|
||||
state.start_key_search_pos = key_range[1]
|
||||
else:
|
||||
break
|
||||
|
||||
while True:
|
||||
key_range = self._find_key_range(self.end_key, state.end_key_search_pos)
|
||||
if key_range and key_range[1] <= accumulated_len:
|
||||
state.found_end_keys.append(key_range)
|
||||
state.end_key_search_pos = key_range[1]
|
||||
else:
|
||||
break
|
||||
|
||||
def _find_potential_key_position(self, accumulated: str) -> int:
|
||||
"""Find position of potential incomplete key at end of accumulated text."""
|
||||
max_key_len = max(len(self.start_key or ""), len(self.end_key or ""))
|
||||
for search_start in range(max(0, len(accumulated) - max_key_len + 1), len(accumulated)):
|
||||
suffix = accumulated[search_start:]
|
||||
if suffix.startswith(self.KEY_START_CHAR) and self._is_prefix_of_key(suffix):
|
||||
return search_start
|
||||
return -1
|
||||
|
||||
def _get_safe_end_pos(self, accumulated: str, producer_done: bool) -> int:
|
||||
"""Determine the safe position up to which we can yield chunks."""
|
||||
potential_key_pos = self._find_potential_key_position(accumulated)
|
||||
if potential_key_pos >= 0 and not producer_done:
|
||||
return potential_key_pos
|
||||
return len(accumulated)
|
||||
|
||||
def _update_delay_mode(self, state: ReleaseState, y_start: int, y_end: int) -> None:
|
||||
"""Update delay mode based on chunk position relative to keys."""
|
||||
if not state.in_delay_mode and state.found_start_keys:
|
||||
for sk_range in state.found_start_keys:
|
||||
if y_start >= sk_range[1] or (y_start < sk_range[1] <= y_end):
|
||||
state.in_delay_mode = True
|
||||
break
|
||||
|
||||
if state.in_delay_mode and state.found_end_keys:
|
||||
for ek_range in state.found_end_keys:
|
||||
if y_start >= ek_range[1] or (y_start < ek_range[1] <= y_end):
|
||||
state.in_delay_mode = False
|
||||
break
|
||||
|
||||
def _should_skip_to_end_key(self, state: ReleaseState, y_end: int) -> bool:
|
||||
"""Check if chunk should be skipped because it's before an end_key in delay mode."""
|
||||
if not state.in_delay_mode:
|
||||
return False
|
||||
for ek_range in state.found_end_keys:
|
||||
if y_end <= ek_range[0]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_text_to_yield(self, y_start: int, y_end: int, state: ReleaseState) -> Optional[str]:
|
||||
"""
|
||||
Given a chunk's position range, return the text that should be yielded.
|
||||
Returns None if the entire chunk should be skipped.
|
||||
"""
|
||||
all_key_ranges = state.found_start_keys + state.found_end_keys
|
||||
|
||||
relevant_ranges = sorted(
|
||||
[r for r in all_key_ranges if self._chunk_overlaps_range(y_start, y_end, r[0], r[1])],
|
||||
key=lambda x: x[0]
|
||||
)
|
||||
|
||||
if not relevant_ranges:
|
||||
return self._accumulated_text[y_start:y_end]
|
||||
|
||||
result_parts = []
|
||||
current_pos = y_start
|
||||
|
||||
for key_start, key_end in relevant_ranges:
|
||||
text_start = max(current_pos, y_start)
|
||||
text_end = min(key_start, y_end)
|
||||
if text_end > text_start:
|
||||
result_parts.append(self._accumulated_text[text_start:text_end])
|
||||
current_pos = max(current_pos, key_end)
|
||||
|
||||
if current_pos < y_end:
|
||||
result_parts.append(self._accumulated_text[current_pos:y_end])
|
||||
|
||||
return "".join(result_parts) if result_parts else None
|
||||
|
||||
async def release(self, text_iterator: AsyncIterator[str]) -> AsyncIterator[str]:
|
||||
"""
|
||||
Async version of release that works with async generators.
|
||||
|
||||
Yields chunks from text_iterator with the following behavior:
|
||||
- Before start_key: yield chunks immediately (but hold back if potential key prefix)
|
||||
- After start_key (until end_key): yield with WAIT_TIME delay
|
||||
- start_key and end_key are never yielded
|
||||
- When end_key is seen: skip all pending chunks and resume after end_key
|
||||
"""
|
||||
buffer: deque = deque() # stores (chunk, chunk_start_pos, chunk_end_pos)
|
||||
state = ReleaseState()
|
||||
producer_done = False
|
||||
|
||||
async def consume_and_process():
|
||||
nonlocal producer_done
|
||||
|
||||
async for chunk in text_iterator:
|
||||
start_pos = len(self._accumulated_text)
|
||||
self._accumulated_text += chunk
|
||||
end_pos = len(self._accumulated_text)
|
||||
buffer.append((chunk, start_pos, end_pos))
|
||||
|
||||
# Process available chunks
|
||||
self._search_for_keys(state, len(self._accumulated_text))
|
||||
safe_end_pos = self._get_safe_end_pos(self._accumulated_text, False)
|
||||
|
||||
while state.yield_idx < len(buffer):
|
||||
chunk_at_yield = buffer[state.yield_idx]
|
||||
y_chunk, y_start, y_end = chunk_at_yield
|
||||
|
||||
if y_end > safe_end_pos:
|
||||
break
|
||||
|
||||
self._update_delay_mode(state, y_start, y_end)
|
||||
|
||||
if self._should_skip_to_end_key(state, y_end):
|
||||
state.yield_idx += 1
|
||||
continue
|
||||
|
||||
state.yield_idx += 1
|
||||
text_to_yield = self._get_text_to_yield(y_start, y_end, state)
|
||||
|
||||
if not text_to_yield:
|
||||
continue
|
||||
|
||||
if state.in_delay_mode:
|
||||
for char in text_to_yield:
|
||||
yield char
|
||||
await asyncio.sleep(self.WAIT_TIME)
|
||||
else:
|
||||
yield text_to_yield
|
||||
|
||||
producer_done = True
|
||||
|
||||
# Process remaining chunks after producer is done
|
||||
self._search_for_keys(state, len(self._accumulated_text))
|
||||
safe_end_pos = self._get_safe_end_pos(self._accumulated_text, True)
|
||||
|
||||
while state.yield_idx < len(buffer):
|
||||
chunk_at_yield = buffer[state.yield_idx]
|
||||
y_chunk, y_start, y_end = chunk_at_yield
|
||||
|
||||
self._update_delay_mode(state, y_start, y_end)
|
||||
|
||||
if self._should_skip_to_end_key(state, y_end):
|
||||
state.yield_idx += 1
|
||||
continue
|
||||
|
||||
state.yield_idx += 1
|
||||
text_to_yield = self._get_text_to_yield(y_start, y_end, state)
|
||||
|
||||
if not text_to_yield:
|
||||
continue
|
||||
|
||||
if state.in_delay_mode:
|
||||
for char in text_to_yield:
|
||||
yield char
|
||||
await asyncio.sleep(self.WAIT_TIME)
|
||||
else:
|
||||
yield text_to_yield
|
||||
|
||||
async for chunk in consume_and_process():
|
||||
yield chunk
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
|
||||
import tyro
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
@@ -14,7 +14,7 @@ from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||
from lang_agent.base import GraphBase, ToolNodeBase
|
||||
from lang_agent.graphs.graph_states import State
|
||||
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
|
||||
from lang_agent.components.text_releaser import TextReleaser
|
||||
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||
@@ -110,7 +110,55 @@ class RoutingGraph(GraphBase):
|
||||
if as_raw:
|
||||
return msg_list
|
||||
|
||||
return msg_list[-1].content
|
||||
return msg_list[-1].content
|
||||
|
||||
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||
"""Async version of invoke using LangGraph's native async support."""
|
||||
self._validate_input(*nargs, **kwargs)
|
||||
|
||||
if as_stream:
|
||||
# Stream messages from the workflow asynchronously
|
||||
print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
|
||||
return self._astream_result(*nargs, **kwargs)
|
||||
else:
|
||||
state = await self.workflow.ainvoke({"inp": nargs})
|
||||
|
||||
msg_list = jax.tree.leaves(state)
|
||||
|
||||
for e in msg_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
e.pretty_print()
|
||||
|
||||
if as_raw:
|
||||
return msg_list
|
||||
|
||||
return msg_list[-1].content
|
||||
|
||||
async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
|
||||
"""Async streaming using LangGraph's astream method."""
|
||||
streamable_tags = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
|
||||
|
||||
async def text_iterator():
|
||||
async for chunk, metadata in self.workflow.astream(
|
||||
{"inp": nargs},
|
||||
stream_mode="messages",
|
||||
subgraphs=True,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(metadata, tuple):
|
||||
chunk, metadata = metadata
|
||||
|
||||
tags = metadata.get("tags")
|
||||
if not (tags in streamable_tags):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||
yield chunk.content
|
||||
|
||||
text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
|
||||
async for chunk in text_releaser.release(text_iterator()):
|
||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||
yield chunk
|
||||
|
||||
def _validate_input(self, *nargs, **kwargs):
|
||||
print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
|
||||
|
||||
@@ -140,6 +140,11 @@ class Pipeline:
|
||||
for chunk in out:
|
||||
yield chunk
|
||||
|
||||
async def _astream_res(self, out):
|
||||
"""Async version of _stream_res for async generators."""
|
||||
async for chunk in out:
|
||||
yield chunk
|
||||
|
||||
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None):
|
||||
"""
|
||||
as_stream (bool): if true, enable the thing to be streamable
|
||||
@@ -158,4 +163,48 @@ class Pipeline:
|
||||
# Yield chunks from the generator
|
||||
return self._stream_res(out)
|
||||
else:
|
||||
return out
|
||||
return out
|
||||
|
||||
async def ainvoke(self, *nargs, **kwargs):
|
||||
"""Async version of invoke using LangGraph's native async support."""
|
||||
out = await self.graph.ainvoke(*nargs, **kwargs)
|
||||
|
||||
# If streaming, return async generator
|
||||
if kwargs.get("as_stream"):
|
||||
return self._astream_res(out)
|
||||
|
||||
# Non-streaming path
|
||||
if kwargs.get("as_raw"):
|
||||
return out
|
||||
|
||||
if isinstance(out, SystemMessage) or isinstance(out, HumanMessage):
|
||||
return out.content
|
||||
|
||||
if isinstance(out, list):
|
||||
return out[-1].content
|
||||
|
||||
if isinstance(out, str):
|
||||
return out
|
||||
|
||||
assert 0, "something is wrong"
|
||||
|
||||
async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None):
|
||||
"""
|
||||
Async version of chat using LangGraph's native async support.
|
||||
|
||||
as_stream (bool): if true, enable the thing to be streamable
|
||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||
"""
|
||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
||||
u = DEFAULT_PROMPT
|
||||
|
||||
thread_id = thread_id if thread_id is not None else 3
|
||||
inp_data = {"messages":[SystemMessage(u),
|
||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
if as_stream:
|
||||
# Return async generator for streaming
|
||||
out = await self.ainvoke(*inp_data, as_stream=True, as_raw=as_raw)
|
||||
return self._astream_res(out)
|
||||
else:
|
||||
return await self.ainvoke(*inp_data, as_stream=False, as_raw=as_raw)
|
||||
Reference in New Issue
Block a user