Merge branch 'main' of https://github.com/tangledup-ai/langchain-agent
This commit is contained in:
10
lang_agent/__init__.py
Normal file
10
lang_agent/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
LangChain Agent - 智能代理系统
|
||||
|
||||
这是一个基于LangChain和LangGraph构建的智能代理系统,集成了RAG(检索增强生成)、
|
||||
工具调用和WebSocket通信功能。项目主要用于茶饮场景的智能对话和订单处理,
|
||||
支持多种工具调用和远程MCP(Model Context Protocol)服务器集成。
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "LangChain Agent Team"
|
||||
@@ -12,7 +12,11 @@ load_dotenv()
|
||||
|
||||
## NOTE: base classes taken from nerfstudio
|
||||
class PrintableConfig:
|
||||
"""Printable Config defining str function"""
|
||||
"""
|
||||
Printable Config defining str function
|
||||
定义 __str__ 方法的可打印配置类
|
||||
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
lines = [self.__class__.__name__ + ":"]
|
||||
@@ -43,25 +47,52 @@ class PrintableConfig:
|
||||
# Base instantiate configs
|
||||
@dataclass
|
||||
class InstantiateConfig(PrintableConfig):
|
||||
"""Config class for instantiating an the class specified in the _target attribute."""
|
||||
"""
|
||||
Config class for instantiating an the class specified in the _target attribute.
|
||||
|
||||
用于实例化 _target 属性指定的类的配置类
|
||||
|
||||
"""
|
||||
|
||||
_target: Type
|
||||
|
||||
def setup(self, **kwargs) -> Any:
|
||||
"""Returns the instantiated object using the config."""
|
||||
"""
|
||||
Returns the instantiated object using the config.
|
||||
|
||||
使用配置返回实例化的对象
|
||||
|
||||
"""
|
||||
return self._target(self, **kwargs)
|
||||
|
||||
def save_config(self, filename: str) -> None:
|
||||
"""Save the config to a YAML file."""
|
||||
"""
|
||||
Save the config to a YAML file.
|
||||
|
||||
将配置保存到 YAML 文件
|
||||
|
||||
"""
|
||||
def mask_value(key, value):
|
||||
# Apply masking if key is secret-like
|
||||
"""
|
||||
Apply masking if key is secret-like
|
||||
如果键是敏感的,应用掩码
|
||||
|
||||
检查键是否敏感(如包含 "secret" 或 "api_key"),如果是,则对值进行掩码处理
|
||||
|
||||
"""
|
||||
if isinstance(value, str) and self.is_secrete(key):
|
||||
sval = str(value)
|
||||
return sval[:3] + "*" * (len(sval) - 6) + sval[-3:]
|
||||
return value
|
||||
|
||||
def to_masked_serializable(obj):
|
||||
# Recursively convert dataclasses and containers to serializable with masked secrets
|
||||
|
||||
"""
|
||||
Recursively convert dataclasses and containers to serializable with masked secrets
|
||||
|
||||
递归地将数据类和容器转换为可序列化的格式,同时对敏感信息进行掩码处理
|
||||
|
||||
"""
|
||||
if is_dataclass(obj):
|
||||
out = {}
|
||||
for k, v in vars(obj).items():
|
||||
@@ -115,10 +146,19 @@ class KeyConfig(InstantiateConfig):
|
||||
@dataclass
|
||||
class ToolConfig(InstantiateConfig):
|
||||
use_tool:bool = True
|
||||
"""specify to use tool or not"""
|
||||
"""
|
||||
specify to use tool or not
|
||||
|
||||
指定是否使用工具
|
||||
"""
|
||||
|
||||
def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig:
|
||||
"""load and overwrite config from file"""
|
||||
"""
|
||||
load and overwrite config from file
|
||||
|
||||
从文件加载并覆盖配置
|
||||
|
||||
"""
|
||||
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
|
||||
|
||||
config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
|
||||
@@ -127,36 +167,84 @@ def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig:
|
||||
def is_default(instance, field_):
|
||||
"""
|
||||
Check if the value of a field in a dataclass instance is the default value.
|
||||
|
||||
检查数据类实例中字段的值是否为默认值
|
||||
|
||||
"""
|
||||
value = getattr(instance, field_.name)
|
||||
|
||||
if field_.default is not MISSING:
|
||||
# Compare with default value
|
||||
# Compare with default value
|
||||
"""
|
||||
与默认值进行比较
|
||||
|
||||
如果字段有默认值,则将当前值与默认值进行比较
|
||||
|
||||
"""
|
||||
return value == field_.default
|
||||
elif field_.default_factory is not MISSING:
|
||||
# Compare with value generated by the default factory
|
||||
"""
|
||||
与默认工厂生成的值进行比较
|
||||
|
||||
如果字段有默认工厂,则将当前值与默认工厂生成的值进行比较
|
||||
|
||||
"""
|
||||
return value == field_.default_factory()
|
||||
else:
|
||||
# No default value specified
|
||||
return False
|
||||
|
||||
def ovewrite_config(loaded_conf, inp_conf):
|
||||
"""for non-default values in inp_conf, overwrite the corresponding values in loaded_conf"""
|
||||
"""
|
||||
for non-default values in inp_conf, overwrite the corresponding values in loaded_conf
|
||||
|
||||
对于 inp_conf 中的非默认值,覆盖 loaded_conf 中对应的配置
|
||||
|
||||
"""
|
||||
if not (is_dataclass(loaded_conf) and is_dataclass(inp_conf)):
|
||||
return loaded_conf
|
||||
|
||||
for field_ in fields(loaded_conf):
|
||||
field_name = field_.name
|
||||
# if field_name in inp_conf:
|
||||
"""
|
||||
if field_name in inp_conf:
|
||||
|
||||
如果字段名在 inp_conf 中,则进行覆盖
|
||||
|
||||
"""
|
||||
current_value = getattr(inp_conf, field_name)
|
||||
new_value = getattr(inp_conf, field_name) #inp_conf[field_name]
|
||||
new_value = getattr(inp_conf, field_name)
|
||||
|
||||
"""
|
||||
inp_conf[field_name]
|
||||
从 inp_conf 中获取字段值
|
||||
|
||||
如果字段名在 inp_conf 中,则获取其值
|
||||
|
||||
"""
|
||||
|
||||
if is_dataclass(current_value):
|
||||
# Recurse for nested dataclasses
|
||||
|
||||
"""
|
||||
Recurse for nested dataclasses
|
||||
|
||||
递归处理嵌套的数据类
|
||||
|
||||
如果当前值是数据类,则递归调用 ovewrite_config 进行合并
|
||||
|
||||
"""
|
||||
merged_value = ovewrite_config(current_value, new_value)
|
||||
setattr(loaded_conf, field_name, merged_value)
|
||||
elif not is_default(inp_conf, field_):
|
||||
# Overwrite only if the current value is not default
|
||||
"""
|
||||
Overwrite only if the current value is not default
|
||||
|
||||
仅在当前值不是默认值时进行覆盖
|
||||
|
||||
如果 inp_conf 中的字段值不是默认值,则覆盖 loaded_conf 中的对应值
|
||||
|
||||
"""
|
||||
setattr(loaded_conf, field_name, new_value)
|
||||
|
||||
return loaded_conf
|
||||
|
||||
9
lang_agent/dummy/__init__.py
Normal file
9
lang_agent/dummy/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
示例工具模块
|
||||
|
||||
该模块包含各种示例工具的实现,用于演示代理系统的工具调用能力。
|
||||
"""
|
||||
|
||||
from .calculator import Calculator
|
||||
|
||||
__all__ = ["Calculator"]
|
||||
@@ -13,7 +13,7 @@ from lang_agent.config import InstantiateConfig, ToolConfig
|
||||
from lang_agent.dummy.calculator import Calculator, CalculatorConfig
|
||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||
|
||||
from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
# from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
@@ -22,7 +22,7 @@ class MCPServerConfig(InstantiateConfig):
|
||||
|
||||
server_name:str = "langserver"
|
||||
|
||||
host: str = "6.6.6.136"
|
||||
host: str = "127.0.0.1"
|
||||
"""host of server"""
|
||||
|
||||
port: int = 50051
|
||||
|
||||
14
lang_agent/rag/__init__.py
Normal file
14
lang_agent/rag/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
RAG (Retrieval Augmented Generation) 模块
|
||||
|
||||
该模块提供了检索增强生成的功能,包括:
|
||||
- 嵌入向量生成和存储
|
||||
- 相似度搜索和文档检索
|
||||
- 基于FAISS的向量数据库支持
|
||||
- 阿里云DashScope嵌入服务集成
|
||||
"""
|
||||
|
||||
from .emb import QwenEmbeddings
|
||||
from .simple import SimpleRag
|
||||
|
||||
__all__ = ["QwenEmbeddings", "SimpleRag"]
|
||||
@@ -7,7 +7,6 @@ import asyncio
|
||||
import os.path as osp
|
||||
from loguru import logger
|
||||
from fastmcp.tools.tool import Tool
|
||||
|
||||
from lang_agent.config import InstantiateConfig, ToolConfig
|
||||
from lang_agent.base import LangToolBase
|
||||
from lang_agent.client_tool_manager import ClientToolManagerConfig
|
||||
@@ -15,11 +14,8 @@ from lang_agent.client_tool_manager import ClientToolManagerConfig
|
||||
from lang_agent.rag.simple import SimpleRagConfig
|
||||
# from lang_agent.dummy.calculator import CalculatorConfig
|
||||
# from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from lang_agent.client_tool_manager import ClientToolManager
|
||||
import jax
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class ToolManagerConfig(InstantiateConfig):
|
||||
|
||||
Reference in New Issue
Block a user