update async pipeline to handle non-string
This commit is contained in:
@@ -135,10 +135,13 @@ class Pipeline:
|
||||
else:
|
||||
CONV_STORE.record_message_list(conv_id, chunk)
|
||||
|
||||
async def _astream_res(self, out):
|
||||
async def _astream_res(self, out, conv_id:str=None):
|
||||
"""Async version of _stream_res for async generators."""
|
||||
async for chunk in out:
|
||||
yield chunk
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
else:
|
||||
CONV_STORE.record_message_list(conv_id, chunk)
|
||||
|
||||
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
|
||||
"""
|
||||
@@ -167,9 +170,9 @@ class Pipeline:
|
||||
"""Async version of invoke using LangGraph's native async support."""
|
||||
out = await self.graph.ainvoke(*nargs, **kwargs)
|
||||
|
||||
# If streaming, return async generator
|
||||
# If streaming, return the raw generator (let caller handle wrapping)
|
||||
if kwargs.get("as_stream"):
|
||||
return self._astream_res(out)
|
||||
return out
|
||||
|
||||
# Non-streaming path
|
||||
if kwargs.get("as_raw"):
|
||||
@@ -207,12 +210,13 @@ class Pipeline:
|
||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||
"device_id":device_id}}
|
||||
|
||||
out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)
|
||||
|
||||
if as_stream:
|
||||
# Return async generator for streaming
|
||||
out = await self.ainvoke(*inp_data, as_stream=True, as_raw=as_raw)
|
||||
return self._astream_res(out)
|
||||
# Yield chunks from the generator
|
||||
return self._astream_res(out, thread_id)
|
||||
else:
|
||||
return await self.ainvoke(*inp_data, as_stream=False, as_raw=as_raw)
|
||||
return out
|
||||
|
||||
def clear_memory(self):
|
||||
"""Clear all memory from the graph."""
|
||||
|
||||
Reference in New Issue
Block a user