diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 4eef942..45bb46c 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -9,29 +9,28 @@ import os from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage, HumanMessage -# from langgraph.prebuilt import create_react_agent + from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver -from lang_agent.config import InstantiateConfig +from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig + + @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class PipelineConfig(InstantiateConfig): +class PipelineConfig(KeyConfig): _target: Type = field(default_factory=lambda: Pipeline) config_f: str = None """path to config file""" - llm_name: str = "qwen-plus" - """name of llm""" + llm_name: str = None + """name of llm; defaults to qwen-turbo when unset""" - llm_provider:str = "openai" - """provider of the llm""" - - api_key:str = None - """api key for llm""" + llm_provider:str = None + """provider of the llm; use default for openai""" base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" """base url; could be used to overwrite the baseurl in llm provider""" @@ -45,14 +44,6 @@ class PipelineConfig(InstantiateConfig): # NOTE: For reference tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) - def __post_init__(self): - if self.api_key == "wrong-key" or self.api_key is None: - # logger.info("wrong embedding key, using simple retrieval method") - self.api_key = os.environ.get("ALI_API_KEY") - if self.api_key is None: - logger.error(f"no ALI_API_KEY provided for embedding") - else: - logger.info("ALI_API_KEY loaded from environ") @@ -63,6 +54,11 @@ class Pipeline: self.populate_module() def populate_module(self): + if self.config.llm_name is None: + logger.info("setting llm_name and llm_provider to defaults") + self.config.llm_name = 
"qwen-turbo" + self.config.llm_provider = "openai" + self.llm = init_chat_model(model=self.config.llm_name, model_provider=self.config.llm_provider, api_key=self.config.api_key, @@ -149,30 +145,6 @@ class Pipeline: return out['messages'][-1].content -# if __name__ == "__main__": -# pipeline:Pipeline = PipelineConfig().setup() - -# # u = pipeline.chat("查查光与尘这杯茶的特点", as_stream=True) -# pipeline.chat("我想和红茶有什么推荐的吗", as_stream=True) - -# # pipeline.chat("我叫什么名字", as_stream=True) - - -# def main(): -# pipeline_config = PipelineConfig() -# pipeline: Pipeline = pipeline_config.setup() - -# # 进行循环对话 -# while True: -# try: -# user_input = input("请讲:") -# if user_input.lower() == "exit": -# break -# response = pipeline.chat(user_input, as_stream=True) -# print(f"回答: {response}") -# except Exception as e: -# logger.error(f"对话过程中出现错误: {e}") - import signal import sys def signal_handler(sig, frame): @@ -189,18 +161,12 @@ def main(): # 进行循环对话 while True: - try: - user_input = input("请讲:") - if user_input.lower() == "exit": - break - response = pipeline.chat(user_input, as_stream=True) - print(f"回答: {response}") - except KeyboardInterrupt: - # Handle Ctrl+C during input - print("\n程序正在退出...") + + user_input = input("请讲:") + if user_input.lower() == "exit": break - except Exception as e: - logger.error(f"对话过程中出现错误: {e}") + response = pipeline.chat(user_input, as_stream=True) + print(f"回答: {response}") if __name__ == "__main__": diff --git a/lang_agent/rag/simple.py b/lang_agent/rag/simple.py index 8167310..bb64487 100644 --- a/lang_agent/rag/simple.py +++ b/lang_agent/rag/simple.py @@ -9,35 +9,21 @@ from langchain_community.vectorstores import FAISS from langchain_core.documents.base import Document from lang_agent.rag.emb import QwenEmbeddings -from lang_agent.config import ToolConfig +from lang_agent.config import ToolConfig, KeyConfig from lang_agent.base import LangToolBase @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class SimpleRagConfig(ToolConfig): +class 
SimpleRagConfig(ToolConfig, KeyConfig): _target: Type = field(default_factory=lambda: SimpleRag) model_name:str = "text-embedding-v4" """embedding model name""" - api_key:str = "wrong-key" - """api_key for model; for generic text splitting; give a wrong key <-- wrong, MUST have api key""" - folder_path:str = "/home/smith/projects/work/langchain-agent/assets/xiaozhan_emb" """path to local database""" - def __post_init__(self): - if self.api_key == "wrong-key": - # logger.info("wrong embedding key, using simple retrieval method") - self.api_key = os.environ.get("ALI_API_KEY") - if self.api_key is None: - logger.error(f"no ALI_API_KEY provided for embedding") - else: - logger.info("ALI_API_KEY loaded from environ") - - logger.info(f"using {self.folder_path} as database") - class SimpleRag(LangToolBase):