moved files
This commit is contained in:
75
lang_agent/components/client_tool_manager.py
Normal file
75
lang_agent/components/client_tool_manager.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
import tyro
|
||||
import commentjson
|
||||
import asyncio
|
||||
import os.path as osp
|
||||
from loguru import logger
|
||||
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
|
||||
from lang_agent.config import InstantiateConfig
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class ClientToolManagerConfig(InstantiateConfig):
|
||||
_target: Type = field(default_factory=lambda: ClientToolManager)
|
||||
|
||||
mcp_config_f: str = None
|
||||
"""path to all mcp configurations; expect json file"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mcp_config_f is None:
|
||||
self.mcp_config_f = osp.join(osp.dirname(osp.dirname(__file__)), "configs", "mcp_config.json")
|
||||
logger.warning(f"config_f was not provided. Using default: {self.mcp_config_f}")
|
||||
assert osp.exists(self.mcp_config_f), f"Default config_f {self.mcp_config_f} does not exist."
|
||||
|
||||
assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist."
|
||||
|
||||
|
||||
class ClientToolManager:
|
||||
def __init__(self, config:ClientToolManagerConfig):
|
||||
self.config = config
|
||||
|
||||
self.populate_module()
|
||||
|
||||
def populate_module(self):
|
||||
with open(self.config.mcp_config_f, "r") as f:
|
||||
self.mcp_configs = commentjson.load(f)
|
||||
|
||||
self.cli = MultiServerMCPClient(self.mcp_configs)
|
||||
|
||||
async def aget_tools(self):
|
||||
tools = await self.cli.get_tools()
|
||||
return tools
|
||||
|
||||
def get_tools(self):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# Event loop is already running, we need to run in a thread
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
# Create a new event loop in this thread
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(self.aget_tools())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
tools = future.result()
|
||||
return tools
|
||||
except RuntimeError:
|
||||
# No event loop running, safe to use asyncio.run()
|
||||
tools = asyncio.run(self.aget_tools())
|
||||
return tools
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NOTE: Simple test
|
||||
config = ClientToolManagerConfig()
|
||||
tool_manager = ClientToolManager(config)
|
||||
tools = tool_manager.get_tools()
|
||||
[print(e.name) for e in tools]
|
||||
170
lang_agent/components/tool_manager.py
Normal file
170
lang_agent/components/tool_manager.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
import functools
|
||||
from typing import Type, List, Callable, Any
|
||||
import tyro
|
||||
import inspect
|
||||
import asyncio
|
||||
import os.path as osp
|
||||
from loguru import logger
|
||||
from fastmcp.tools.tool import Tool
|
||||
from lang_agent.config import InstantiateConfig, ToolConfig
|
||||
from lang_agent.base import LangToolBase
|
||||
from lang_agent.client_tool_manager import ClientToolManagerConfig
|
||||
|
||||
from lang_agent.rag.simple import SimpleRagConfig
|
||||
from lang_agent.dummy.calculator import CalculatorConfig
|
||||
# from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from lang_agent.client_tool_manager import ClientToolManager
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class ToolManagerConfig(InstantiateConfig):
|
||||
_target: Type = field(default_factory=lambda: ToolManager)
|
||||
|
||||
client_tool_manager: ClientToolManagerConfig = field(default_factory=ClientToolManagerConfig)
|
||||
|
||||
# tool configs here; MUST HAVE 'config' in name and must be dataclass
|
||||
# rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
|
||||
|
||||
# cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
||||
|
||||
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
||||
|
||||
|
||||
def async_to_sync(async_func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator that converts an async function to a sync function.
|
||||
|
||||
Args:
|
||||
async_func: The async function to convert
|
||||
|
||||
Returns:
|
||||
A synchronous wrapper function
|
||||
"""
|
||||
@functools.wraps(async_func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Handle nested event loops (e.g., in Jupyter)
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
else:
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one
|
||||
return asyncio.run(async_func(*args, **kwargs))
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config:ToolManagerConfig):
|
||||
self.config = config
|
||||
|
||||
self.tool_fncs = [] # list of functions that should be turned into tools
|
||||
self.client_tool_manager = [] # 用于获取 MCP 工具
|
||||
self.populate_modules()
|
||||
|
||||
logger.info("available tools:")
|
||||
for tool in self.get_list_langchain_tools():
|
||||
logger.info(tool.name)
|
||||
|
||||
def _get_tool_config(self)->List[ToolConfig]:
|
||||
tool_confs = []
|
||||
for e in dir(self.config):
|
||||
el = getattr(self.config, e)
|
||||
if ("config" in e) and is_dataclass(el):
|
||||
tool_confs.append(el)
|
||||
|
||||
return tool_confs
|
||||
|
||||
def _get_tool_fnc(self, tool_obj:LangToolBase)->List:
|
||||
fnc_list = []
|
||||
for fnc in tool_obj.get_tool_fnc():
|
||||
if isinstance(fnc, Tool):
|
||||
fnc = fnc.fn
|
||||
fnc_list.append(fnc)
|
||||
|
||||
return fnc_list
|
||||
|
||||
|
||||
def populate_modules(self):
|
||||
"""instantiate all object with tools"""
|
||||
|
||||
self.tool_fncs = []
|
||||
tool_configs = self._get_tool_config()
|
||||
for tool_conf in tool_configs:
|
||||
tool_name = tool_conf.get_name()[:-6]
|
||||
if tool_conf.use_tool:
|
||||
logger.info(f"making tool:{tool_name}")
|
||||
fnc_list = self._get_tool_fnc(tool_conf.setup())
|
||||
self.tool_fncs.extend(fnc_list)
|
||||
else:
|
||||
logger.info(f"skipping tool:{tool_name}")
|
||||
|
||||
try:
|
||||
# client_config = self.config.client_tool_manager
|
||||
# self.client_tool_manager = ClientToolManager(client_config)
|
||||
# self.client_tool_manager = ClientToolManager(self.config.client_tool_manager)
|
||||
self.client_tool_manager:ClientToolManager = self.config.client_tool_manager.setup()
|
||||
logger.info("Successfully initialized client_tool_manager for MCP tools")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize client_tool_manager: {e}")
|
||||
self.client_tool_manager = []
|
||||
self._build_langchain_tools()
|
||||
|
||||
def get_tool_fncs(self):
|
||||
all_tools = []
|
||||
all_tools.extend(self.tool_fncs)
|
||||
if self.client_tool_manager is not None:
|
||||
try:
|
||||
mcp_tools = self.client_tool_manager.get_tools()
|
||||
all_tools.extend(mcp_tools)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get MCP tools: {e}")
|
||||
return all_tools
|
||||
|
||||
def get_tool_dict(self):
|
||||
return self.tool_dict
|
||||
|
||||
|
||||
def fnc_to_structool(self, func):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return StructuredTool.from_function(
|
||||
func=async_to_sync(func),
|
||||
coroutine=func)
|
||||
else:
|
||||
return StructuredTool.from_function(func=func)
|
||||
|
||||
def _build_langchain_tools(self):
|
||||
self.langchain_tools = []
|
||||
for func in self.get_tool_fncs():
|
||||
if isinstance(func, StructuredTool):
|
||||
if hasattr(func, 'coroutine') and func.coroutine is not None and (not hasattr(func, 'func') or func.func is None):
|
||||
sync_func = async_to_sync(func.coroutine)
|
||||
new_tool = StructuredTool(
|
||||
name=func.name,
|
||||
description=func.description,
|
||||
args_schema=func.args_schema,
|
||||
func=sync_func,
|
||||
coroutine=func.coroutine,
|
||||
metadata=func.metadata if hasattr(func, 'metadata') else None,
|
||||
return_direct=func.return_direct if hasattr(func, 'return_direct') else False,
|
||||
)
|
||||
self.langchain_tools.append(new_tool)
|
||||
else:
|
||||
self.langchain_tools.append(func)
|
||||
else:
|
||||
self.langchain_tools.append(self.fnc_to_structool(func))
|
||||
return self.langchain_tools
|
||||
|
||||
def get_list_langchain_tools(self)->List[StructuredTool]:
|
||||
return self.langchain_tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
man: ToolManager = ToolManagerConfig().setup()
|
||||
for lang_tool in man.get_list_langchain_tools():
|
||||
print(lang_tool.name)
|
||||
Reference in New Issue
Block a user