Pipeline: record non-string chunks while async-streaming

Summary of the hunks below (lang_agent/pipeline.py):
  * _astream_res takes an optional conv_id, yields only str chunks, and
    hands every other chunk to CONV_STORE.record_message_list(conv_id, chunk).
  * ainvoke returns the graph output unwrapped when as_stream is set,
    leaving the wrapping to its caller.
  * The caller in the last hunk awaits ainvoke once and, when streaming,
    wraps the result with _astream_res(out, thread_id), so thread_id is
    the value passed along as conv_id.

NOTE(review): this patch was restored from a copy in which every newline
and indent had been collapsed to a single space. Line breaks were placed
to satisfy the hunk-header line counts; indentation, and the exact wrap
points of the two continuation lines that open the last hunk, are
best-effort. If context fails to match, try `git apply --ignore-whitespace`.

diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py
index dd870d8..36c5686 100644
--- a/lang_agent/pipeline.py
+++ b/lang_agent/pipeline.py
@@ -135,10 +135,13 @@ class Pipeline:
             else:
                 CONV_STORE.record_message_list(conv_id, chunk)
 
-    async def _astream_res(self, out):
+    async def _astream_res(self, out, conv_id:str=None):
         """Async version of _stream_res for async generators."""
         async for chunk in out:
-            yield chunk
+            if isinstance(chunk, str):
+                yield chunk
+            else:
+                CONV_STORE.record_message_list(conv_id, chunk)
 
     def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
         """
@@ -167,9 +170,9 @@ class Pipeline:
         """Async version of invoke using LangGraph's native async support."""
         out = await self.graph.ainvoke(*nargs, **kwargs)
 
-        # If streaming, return async generator
+        # If streaming, return the raw generator (let caller handle wrapping)
         if kwargs.get("as_stream"):
-            return self._astream_res(out)
+            return out
 
         # Non-streaming path
         if kwargs.get("as_raw"):
@@ -207,12 +210,13 @@ class Pipeline:
                                 HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
                                                                        "device_id":device_id}}
 
+        out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)
+
         if as_stream:
-            # Return async generator for streaming
-            out = await self.ainvoke(*inp_data, as_stream=True, as_raw=as_raw)
-            return self._astream_res(out)
+            # Yield chunks from the generator
+            return self._astream_res(out, thread_id)
         else:
-            return await self.ainvoke(*inp_data, as_stream=False, as_raw=as_raw)
+            return out
 
     def clear_memory(self):
         """Clear all memory from the graph."""