diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 0128479..9414c8c 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -118,9 +118,7 @@ class Pipeline: # If streaming, yield chunks from the generator if kwargs.get("as_stream"): - for chunk in out: - yield chunk - return + return self._stream_res(out) # Non-streaming path if kwargs.get("as_raw"): @@ -165,6 +163,9 @@ class Pipeline: def get_ws_url(self): return f"ws://{self.config.host}:{self.config.port}" + def _stream_res(self, out:list): + for chunk in out: + yield chunk def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:int = None): # NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph @@ -178,7 +179,6 @@ class Pipeline: if as_stream: # Yield chunks from the generator - for chunk in out: - yield chunk + return self._stream_res(out) else: return out \ No newline at end of file