diff --git a/scripts/demo_chat.py b/scripts/demo_chat.py index 36f23a2..8c56f4f 100644 --- a/scripts/demo_chat.py +++ b/scripts/demo_chat.py @@ -1,5 +1,7 @@ import tyro from typing import Annotated +import uuid + from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.config import load_tyro_conf @@ -19,7 +21,7 @@ def main( conf = load_tyro_conf(conf.config_f) pipeline: Pipeline = conf.setup() - + thread_id = str(uuid.uuid4()) while True: user_input = input("请讲:") if user_input.lower() == "exit": @@ -28,12 +30,12 @@ def main( if stream: # Streaming mode: print chunks as they arrive print("回答: ", end="", flush=True) - for chunk in pipeline.chat(user_input, as_stream=True): + for chunk in pipeline.chat(user_input, as_stream=True, thread_id=thread_id): print(chunk, end="", flush=True) print() # New line after streaming completes else: # Non-streaming mode: print full response - response = pipeline.chat(user_input, as_stream=False) + response = pipeline.chat(user_input, as_stream=False, thread_id=thread_id) print(f"回答: {response}")