use key config

This commit is contained in:
2025-10-22 12:10:49 +08:00
parent 193944a3a2
commit ba7577f4a6
2 changed files with 21 additions and 69 deletions

View File

@@ -9,29 +9,28 @@ import os
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.messages import SystemMessage, HumanMessage
# from langgraph.prebuilt import create_react_agent
from langchain.agents import create_agent from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from lang_agent.config import InstantiateConfig from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class PipelineConfig(InstantiateConfig): class PipelineConfig(KeyConfig):
_target: Type = field(default_factory=lambda: Pipeline) _target: Type = field(default_factory=lambda: Pipeline)
config_f: str = None config_f: str = None
"""path to config file""" """path to config file"""
llm_name: str = "qwen-plus" llm_name: str = None
"""name of llm""" """name of llm; use default for qwen-plus"""
llm_provider:str = "openai" llm_provider:str = None
"""provider of the llm""" """provider of the llm; use default for openai"""
api_key:str = None
"""api key for llm"""
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider""" """base url; could be used to overwrite the baseurl in llm provider"""
@@ -45,14 +44,6 @@ class PipelineConfig(InstantiateConfig):
# NOTE: For reference # NOTE: For reference
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
def __post_init__(self):
if self.api_key == "wrong-key" or self.api_key is None:
# logger.info("wrong embedding key, using simple retrieval method")
self.api_key = os.environ.get("ALI_API_KEY")
if self.api_key is None:
logger.error(f"no ALI_API_KEY provided for embedding")
else:
logger.info("ALI_API_KEY loaded from environ")
@@ -63,6 +54,11 @@ class Pipeline:
self.populate_module() self.populate_module()
def populate_module(self): def populate_module(self):
if self.config.llm_name is None:
logger.info(f"setting llm_provider to default")
self.config.llm_name = "qwen-turbo"
self.config.llm_provider = "openai"
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,
model_provider=self.config.llm_provider, model_provider=self.config.llm_provider,
api_key=self.config.api_key, api_key=self.config.api_key,
@@ -149,30 +145,6 @@ class Pipeline:
return out['messages'][-1].content return out['messages'][-1].content
# if __name__ == "__main__":
# pipeline:Pipeline = PipelineConfig().setup()
# # u = pipeline.chat("查查光与尘这杯茶的特点", as_stream=True)
# pipeline.chat("我想和红茶有什么推荐的吗", as_stream=True)
# # pipeline.chat("我叫什么名字", as_stream=True)
# def main():
# pipeline_config = PipelineConfig()
# pipeline: Pipeline = pipeline_config.setup()
# # 进行循环对话
# while True:
# try:
# user_input = input("请讲:")
# if user_input.lower() == "exit":
# break
# response = pipeline.chat(user_input, as_stream=True)
# print(f"回答: {response}")
# except Exception as e:
# logger.error(f"对话过程中出现错误: {e}")
import signal import signal
import sys import sys
def signal_handler(sig, frame): def signal_handler(sig, frame):
@@ -189,18 +161,12 @@ def main():
# 进行循环对话 # 进行循环对话
while True: while True:
try:
user_input = input("请讲:") user_input = input("请讲:")
if user_input.lower() == "exit": if user_input.lower() == "exit":
break
response = pipeline.chat(user_input, as_stream=True)
print(f"回答: {response}")
except KeyboardInterrupt:
# Handle Ctrl+C during input
print("\n程序正在退出...")
break break
except Exception as e: response = pipeline.chat(user_input, as_stream=True)
logger.error(f"对话过程中出现错误: {e}") print(f"回答: {response}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -9,35 +9,21 @@ from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document from langchain_core.documents.base import Document
from lang_agent.rag.emb import QwenEmbeddings from lang_agent.rag.emb import QwenEmbeddings
from lang_agent.config import ToolConfig from lang_agent.config import ToolConfig, KeyConfig
from lang_agent.base import LangToolBase from lang_agent.base import LangToolBase
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class SimpleRagConfig(ToolConfig): class SimpleRagConfig(ToolConfig, KeyConfig):
_target: Type = field(default_factory=lambda: SimpleRag) _target: Type = field(default_factory=lambda: SimpleRag)
model_name:str = "text-embedding-v4" model_name:str = "text-embedding-v4"
"""embedding model name""" """embedding model name"""
api_key:str = "wrong-key"
"""api_key for model; for generic text splitting; give a wrong key <-- wrong, MUST have api key"""
folder_path:str = "/home/smith/projects/work/langchain-agent/assets/xiaozhan_emb" folder_path:str = "/home/smith/projects/work/langchain-agent/assets/xiaozhan_emb"
"""path to local database""" """path to local database"""
def __post_init__(self):
if self.api_key == "wrong-key":
# logger.info("wrong embedding key, using simple retrieval method")
self.api_key = os.environ.get("ALI_API_KEY")
if self.api_key is None:
logger.error(f"no ALI_API_KEY provided for embedding")
else:
logger.info("ALI_API_KEY loaded from environ")
logger.info(f"using {self.folder_path} as database")
class SimpleRag(LangToolBase): class SimpleRag(LangToolBase):