support stream

This commit is contained in:
2026-01-30 10:38:12 +08:00
parent 609e31c9ad
commit 2872e91dd1

View File

@@ -1,29 +1,41 @@
import tyro
import asyncio
from loguru import logger
from typing import Annotated
from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config import load_tyro_conf
def main(conf:PipelineConfig):
def main(
conf: PipelineConfig,
stream: Annotated[bool, tyro.conf.arg(name="stream")] = True,
):
"""Demo chat script for langchain-agent pipeline.
Args:
conf: Pipeline configuration
stream: Enable streaming mode for chat responses
"""
if conf.config_f is not None:
conf = load_tyro_conf(conf.config_f)
pipeline:Pipeline = conf.setup()
pipeline: Pipeline = conf.setup()
while True:
user_input = input("请讲:")
if user_input.lower() == "exit":
break
response = pipeline.chat(user_input)
# print(f"回答: {response}")
# # out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
# out = pipeline.chat("你叫什么名字,我今天心情不好,而且天气也不好,我想去外面玩,帮我计划一下", as_stream=True)
# # out = pipeline.chat("testing", as_stream=True)
# print("=========== final ==========")
# print(out)
if stream:
# Streaming mode: print chunks as they arrive
print("回答: ", end="", flush=True)
for chunk in pipeline.chat(user_input, as_stream=True):
print(chunk, end="", flush=True)
print() # New line after streaming completes
else:
# Non-streaming mode: print full response
response = pipeline.chat(user_input, as_stream=False)
print(f"回答: {response}")
if __name__ == "__main__":
main(tyro.cli(PipelineConfig))
tyro.cli(main)