修改提示词、添加获取远程tool_manager.py实现获取远程mcp工具功能、添加实验数据集
This commit is contained in:
@@ -13,9 +13,10 @@ from lang_agent.base import LangToolBase
|
||||
|
||||
from lang_agent.rag.simple import SimpleRagConfig
|
||||
from lang_agent.dummy.calculator import CalculatorConfig
|
||||
from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
# from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from lang_agent.client_tool_manager import ClientToolManager
|
||||
import jax
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@@ -26,7 +27,7 @@ class ToolManagerConfig(InstantiateConfig):
|
||||
# tool configs here; MUST HAVE 'config' in name and must be dataclass
|
||||
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
|
||||
|
||||
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
||||
# cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
||||
|
||||
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
||||
|
||||
@@ -64,6 +65,7 @@ class ToolManager:
|
||||
self.config = config
|
||||
|
||||
self.tool_fncs = [] # list of functions that should be turned into tools
|
||||
self.client_tool_manager = [] # 用于获取 MCP 工具
|
||||
self.populate_modules()
|
||||
|
||||
def _get_tool_config(self)->List[ToolConfig]:
|
||||
@@ -99,11 +101,27 @@ class ToolManager:
|
||||
else:
|
||||
logger.info(f"skipping tool:{tool_name}")
|
||||
|
||||
try:
|
||||
from lang_agent.client_tool_manager import ClientToolManagerConfig
|
||||
client_config = ClientToolManagerConfig()
|
||||
self.client_tool_manager = ClientToolManager(client_config)
|
||||
logger.info("Successfully initialized client_tool_manager for MCP tools")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize client_tool_manager: {e}")
|
||||
self.client_tool_manager = []
|
||||
self._build_langchain_tools()
|
||||
|
||||
|
||||
def get_tool_fncs(self):
|
||||
return self.tool_fncs
|
||||
all_tools = []
|
||||
all_tools.extend(self.tool_fncs)
|
||||
if self.client_tool_manager is not None:
|
||||
try:
|
||||
mcp_tools = self.client_tool_manager.get_tools()
|
||||
all_tools.extend(mcp_tools)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get MCP tools: {e}")
|
||||
return all_tools
|
||||
|
||||
def get_tool_dict(self):
|
||||
return self.tool_dict
|
||||
@@ -113,20 +131,33 @@ class ToolManager:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return StructuredTool.from_function(
|
||||
func=async_to_sync(func),
|
||||
coroutine=func)
|
||||
|
||||
coroutine=func)
|
||||
else:
|
||||
return StructuredTool.from_function(func=func)
|
||||
|
||||
def _build_langchain_tools(self):
|
||||
self.langchain_tools = []
|
||||
for func in self.get_tool_fncs():
|
||||
self.langchain_tools.append(self.fnc_to_structool(func))
|
||||
if isinstance(func, StructuredTool):
|
||||
self.langchain_tools.append(func)
|
||||
else:
|
||||
self.langchain_tools.append(self.fnc_to_structool(func))
|
||||
|
||||
return self.langchain_tools
|
||||
|
||||
def get_list_langchain_tools(self)->List[StructuredTool]:
|
||||
return self.langchain_tools
|
||||
all_langchain_tools = []
|
||||
all_langchain_tools.extend(self.langchain_tools)
|
||||
# 如果有 client_tool_manager,添加 MCP 工具(已经是 LangChain 格式)
|
||||
if self.client_tool_manager:
|
||||
try:
|
||||
# 获取 MCP 工具(已经是 StructuredTool 格式)
|
||||
mcp_tools = self.client_tool_manager.get_tools()
|
||||
all_langchain_tools.extend(mcp_tools)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get MCP tools: {e}")
|
||||
|
||||
return all_langchain_tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user