support stream
This commit is contained in:
@@ -1,29 +1,41 @@
|
|||||||
import tyro
|
import tyro
|
||||||
import asyncio
|
from typing import Annotated
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
from lang_agent.config import load_tyro_conf
|
from lang_agent.config import load_tyro_conf
|
||||||
|
|
||||||
def main(conf:PipelineConfig):
|
|
||||||
|
def main(
|
||||||
|
conf: PipelineConfig,
|
||||||
|
stream: Annotated[bool, tyro.conf.arg(name="stream")] = True,
|
||||||
|
):
|
||||||
|
"""Demo chat script for langchain-agent pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conf: Pipeline configuration
|
||||||
|
stream: Enable streaming mode for chat responses
|
||||||
|
"""
|
||||||
if conf.config_f is not None:
|
if conf.config_f is not None:
|
||||||
conf = load_tyro_conf(conf.config_f)
|
conf = load_tyro_conf(conf.config_f)
|
||||||
|
|
||||||
pipeline:Pipeline = conf.setup()
|
pipeline: Pipeline = conf.setup()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
user_input = input("请讲:")
|
user_input = input("请讲:")
|
||||||
if user_input.lower() == "exit":
|
if user_input.lower() == "exit":
|
||||||
break
|
break
|
||||||
response = pipeline.chat(user_input)
|
|
||||||
# print(f"回答: {response}")
|
if stream:
|
||||||
|
# Streaming mode: print chunks as they arrive
|
||||||
# # out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
|
print("回答: ", end="", flush=True)
|
||||||
# out = pipeline.chat("你叫什么名字,我今天心情不好,而且天气也不好,我想去外面玩,帮我计划一下", as_stream=True)
|
for chunk in pipeline.chat(user_input, as_stream=True):
|
||||||
# # out = pipeline.chat("testing", as_stream=True)
|
print(chunk, end="", flush=True)
|
||||||
# print("=========== final ==========")
|
print() # New line after streaming completes
|
||||||
# print(out)
|
else:
|
||||||
|
# Non-streaming mode: print full response
|
||||||
|
response = pipeline.chat(user_input, as_stream=False)
|
||||||
|
print(f"回答: {response}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main(tyro.cli(PipelineConfig))
|
tyro.cli(main)
|
||||||
Reference in New Issue
Block a user