diff --git a/scripts/demo_chat.py b/scripts/demo_chat.py index 892083c..36f23a2 100644 --- a/scripts/demo_chat.py +++ b/scripts/demo_chat.py @@ -1,29 +1,41 @@ import tyro -import asyncio -from loguru import logger +from typing import Annotated from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.config import load_tyro_conf -def main(conf:PipelineConfig): + +def main( + conf: PipelineConfig, + stream: Annotated[bool, tyro.conf.arg(name="stream")] = True, +): + """Demo chat script for langchain-agent pipeline. + + Args: + conf: Pipeline configuration + stream: Enable streaming mode for chat responses + """ if conf.config_f is not None: conf = load_tyro_conf(conf.config_f) - pipeline:Pipeline = conf.setup() + pipeline: Pipeline = conf.setup() while True: user_input = input("请讲:") if user_input.lower() == "exit": break - response = pipeline.chat(user_input) - # print(f"回答: {response}") - - # # out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True) - # out = pipeline.chat("你叫什么名字,我今天心情不好,而且天气也不好,我想去外面玩,帮我计划一下", as_stream=True) - # # out = pipeline.chat("testing", as_stream=True) - # print("=========== final ==========") - # print(out) + + if stream: + # Streaming mode: print chunks as they arrive + print("回答: ", end="", flush=True) + for chunk in pipeline.chat(user_input, as_stream=True): + print(chunk, end="", flush=True) + print() # New line after streaming completes + else: + # Non-streaming mode: print full response + response = pipeline.chat(user_input, as_stream=False) + print(f"回答: {response}") if __name__ == "__main__": - main(tyro.cli(PipelineConfig)) \ No newline at end of file + tyro.cli(main) \ No newline at end of file