use key config
This commit is contained in:
@@ -9,29 +9,28 @@ import os
|
|||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
from langchain_core.messages import SystemMessage, HumanMessage
|
||||||
# from langgraph.prebuilt import create_react_agent
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig
|
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineConfig(InstantiateConfig):
|
class PipelineConfig(KeyConfig):
|
||||||
_target: Type = field(default_factory=lambda: Pipeline)
|
_target: Type = field(default_factory=lambda: Pipeline)
|
||||||
|
|
||||||
config_f: str = None
|
config_f: str = None
|
||||||
"""path to config file"""
|
"""path to config file"""
|
||||||
|
|
||||||
llm_name: str = "qwen-plus"
|
llm_name: str = None
|
||||||
"""name of llm"""
|
"""name of llm; use default for qwen-plus"""
|
||||||
|
|
||||||
llm_provider:str = "openai"
|
llm_provider:str = None
|
||||||
"""provider of the llm"""
|
"""provider of the llm; use default for openai"""
|
||||||
|
|
||||||
api_key:str = None
|
|
||||||
"""api key for llm"""
|
|
||||||
|
|
||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
"""base url; could be used to overwrite the baseurl in llm provider"""
|
||||||
@@ -45,14 +44,6 @@ class PipelineConfig(InstantiateConfig):
|
|||||||
# NOTE: For reference
|
# NOTE: For reference
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.api_key == "wrong-key" or self.api_key is None:
|
|
||||||
# logger.info("wrong embedding key, using simple retrieval method")
|
|
||||||
self.api_key = os.environ.get("ALI_API_KEY")
|
|
||||||
if self.api_key is None:
|
|
||||||
logger.error(f"no ALI_API_KEY provided for embedding")
|
|
||||||
else:
|
|
||||||
logger.info("ALI_API_KEY loaded from environ")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -63,6 +54,11 @@ class Pipeline:
|
|||||||
self.populate_module()
|
self.populate_module()
|
||||||
|
|
||||||
def populate_module(self):
|
def populate_module(self):
|
||||||
|
if self.config.llm_name is None:
|
||||||
|
logger.info(f"setting llm_provider to default")
|
||||||
|
self.config.llm_name = "qwen-turbo"
|
||||||
|
self.config.llm_provider = "openai"
|
||||||
|
|
||||||
self.llm = init_chat_model(model=self.config.llm_name,
|
self.llm = init_chat_model(model=self.config.llm_name,
|
||||||
model_provider=self.config.llm_provider,
|
model_provider=self.config.llm_provider,
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
@@ -149,30 +145,6 @@ class Pipeline:
|
|||||||
return out['messages'][-1].content
|
return out['messages'][-1].content
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# pipeline:Pipeline = PipelineConfig().setup()
|
|
||||||
|
|
||||||
# # u = pipeline.chat("查查光与尘这杯茶的特点", as_stream=True)
|
|
||||||
# pipeline.chat("我想和红茶有什么推荐的吗", as_stream=True)
|
|
||||||
|
|
||||||
# # pipeline.chat("我叫什么名字", as_stream=True)
|
|
||||||
|
|
||||||
|
|
||||||
# def main():
|
|
||||||
# pipeline_config = PipelineConfig()
|
|
||||||
# pipeline: Pipeline = pipeline_config.setup()
|
|
||||||
|
|
||||||
# # 进行循环对话
|
|
||||||
# while True:
|
|
||||||
# try:
|
|
||||||
# user_input = input("请讲:")
|
|
||||||
# if user_input.lower() == "exit":
|
|
||||||
# break
|
|
||||||
# response = pipeline.chat(user_input, as_stream=True)
|
|
||||||
# print(f"回答: {response}")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"对话过程中出现错误: {e}")
|
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
@@ -189,18 +161,12 @@ def main():
|
|||||||
|
|
||||||
# 进行循环对话
|
# 进行循环对话
|
||||||
while True:
|
while True:
|
||||||
try:
|
|
||||||
user_input = input("请讲:")
|
user_input = input("请讲:")
|
||||||
if user_input.lower() == "exit":
|
if user_input.lower() == "exit":
|
||||||
break
|
break
|
||||||
response = pipeline.chat(user_input, as_stream=True)
|
response = pipeline.chat(user_input, as_stream=True)
|
||||||
print(f"回答: {response}")
|
print(f"回答: {response}")
|
||||||
except KeyboardInterrupt:
|
|
||||||
# Handle Ctrl+C during input
|
|
||||||
print("\n程序正在退出...")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"对话过程中出现错误: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -9,35 +9,21 @@ from langchain_community.vectorstores import FAISS
|
|||||||
from langchain_core.documents.base import Document
|
from langchain_core.documents.base import Document
|
||||||
|
|
||||||
from lang_agent.rag.emb import QwenEmbeddings
|
from lang_agent.rag.emb import QwenEmbeddings
|
||||||
from lang_agent.config import ToolConfig
|
from lang_agent.config import ToolConfig, KeyConfig
|
||||||
from lang_agent.base import LangToolBase
|
from lang_agent.base import LangToolBase
|
||||||
|
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimpleRagConfig(ToolConfig):
|
class SimpleRagConfig(ToolConfig, KeyConfig):
|
||||||
_target: Type = field(default_factory=lambda: SimpleRag)
|
_target: Type = field(default_factory=lambda: SimpleRag)
|
||||||
|
|
||||||
model_name:str = "text-embedding-v4"
|
model_name:str = "text-embedding-v4"
|
||||||
"""embedding model name"""
|
"""embedding model name"""
|
||||||
|
|
||||||
api_key:str = "wrong-key"
|
|
||||||
"""api_key for model; for generic text splitting; give a wrong key <-- wrong, MUST have api key"""
|
|
||||||
|
|
||||||
folder_path:str = "/home/smith/projects/work/langchain-agent/assets/xiaozhan_emb"
|
folder_path:str = "/home/smith/projects/work/langchain-agent/assets/xiaozhan_emb"
|
||||||
"""path to local database"""
|
"""path to local database"""
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.api_key == "wrong-key":
|
|
||||||
# logger.info("wrong embedding key, using simple retrieval method")
|
|
||||||
self.api_key = os.environ.get("ALI_API_KEY")
|
|
||||||
if self.api_key is None:
|
|
||||||
logger.error(f"no ALI_API_KEY provided for embedding")
|
|
||||||
else:
|
|
||||||
logger.info("ALI_API_KEY loaded from environ")
|
|
||||||
|
|
||||||
logger.info(f"using {self.folder_path} as database")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleRag(LangToolBase):
|
class SimpleRag(LangToolBase):
|
||||||
|
|||||||
Reference in New Issue
Block a user