update async pipeline to handle non-string
This commit is contained in:
@@ -135,10 +135,13 @@ class Pipeline:
|
|||||||
else:
|
else:
|
||||||
CONV_STORE.record_message_list(conv_id, chunk)
|
CONV_STORE.record_message_list(conv_id, chunk)
|
||||||
|
|
||||||
async def _astream_res(self, out):
|
async def _astream_res(self, out, conv_id:str=None):
|
||||||
"""Async version of _stream_res for async generators."""
|
"""Async version of _stream_res for async generators."""
|
||||||
async for chunk in out:
|
async for chunk in out:
|
||||||
yield chunk
|
if isinstance(chunk, str):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
CONV_STORE.record_message_list(conv_id, chunk)
|
||||||
|
|
||||||
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
|
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
|
||||||
"""
|
"""
|
||||||
@@ -167,9 +170,9 @@ class Pipeline:
|
|||||||
"""Async version of invoke using LangGraph's native async support."""
|
"""Async version of invoke using LangGraph's native async support."""
|
||||||
out = await self.graph.ainvoke(*nargs, **kwargs)
|
out = await self.graph.ainvoke(*nargs, **kwargs)
|
||||||
|
|
||||||
# If streaming, return async generator
|
# If streaming, return the raw generator (let caller handle wrapping)
|
||||||
if kwargs.get("as_stream"):
|
if kwargs.get("as_stream"):
|
||||||
return self._astream_res(out)
|
return out
|
||||||
|
|
||||||
# Non-streaming path
|
# Non-streaming path
|
||||||
if kwargs.get("as_raw"):
|
if kwargs.get("as_raw"):
|
||||||
@@ -207,12 +210,13 @@ class Pipeline:
|
|||||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||||
"device_id":device_id}}
|
"device_id":device_id}}
|
||||||
|
|
||||||
|
out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)
|
||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
# Return async generator for streaming
|
# Yield chunks from the generator
|
||||||
out = await self.ainvoke(*inp_data, as_stream=True, as_raw=as_raw)
|
return self._astream_res(out, thread_id)
|
||||||
return self._astream_res(out)
|
|
||||||
else:
|
else:
|
||||||
return await self.ainvoke(*inp_data, as_stream=False, as_raw=as_raw)
|
return out
|
||||||
|
|
||||||
def clear_memory(self):
|
def clear_memory(self):
|
||||||
"""Clear all memory from the graph."""
|
"""Clear all memory from the graph."""
|
||||||
|
|||||||
Reference in New Issue
Block a user