"""Tool manager: instantiates configured tool objects and exposes their
functions both as plain callables and as LangChain ``StructuredTool``s."""

from dataclasses import dataclass, field, is_dataclass
import functools
from typing import Type, List, Dict, Callable, Any
import tyro
import inspect
import asyncio
import os.path as osp
from loguru import logger
from fastmcp.tools.tool import Tool

from lang_agent.config import InstantiateConfig, ToolConfig
from lang_agent.base import LangToolBase
from lang_agent.rag.simple import SimpleRagConfig
from lang_agent.dummy.calculator import CalculatorConfig
# from catering_end.lang_tool import CartToolConfig, CartTool

from langchain_core.tools.structured import StructuredTool
import jax


@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class ToolManagerConfig(InstantiateConfig):
    """Configuration for :class:`ToolManager`.

    ``ToolManager`` discovers tool configs by scanning this object's
    attributes: an attribute is picked up when its name contains
    ``"config"`` and its value is a dataclass instance.
    """

    _target: Type = field(default_factory=lambda: ToolManager)

    # tool configs here; MUST HAVE 'config' in name and must be dataclass
    rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)

    # cart_config: CartToolConfig = field(default_factory=CartToolConfig)

    calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)


def async_to_sync(async_func: Callable) -> Callable:
    """
    Decorator that converts an async function to a sync function.

    The wrapper runs the coroutine on the current thread's event loop when
    one is available. If that loop is already running (e.g. in Jupyter),
    ``nest_asyncio`` is applied so the loop can be re-entered. When no usable
    loop exists (none set, or the current one is closed), ``asyncio.run`` is
    used with a fresh loop.

    Exceptions raised by ``async_func`` itself -- including ``RuntimeError``
    -- propagate to the caller unchanged; the coroutine is never executed a
    second time as a "fallback".

    Args:
        async_func: The async function to convert

    Returns:
        A synchronous wrapper function
    """

    @functools.wraps(async_func)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        # Only the loop lookup is guarded. Guarding the execution as well
        # would swallow a RuntimeError raised by the coroutine and re-run it.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None

        if loop is None or loop.is_closed():
            # No usable event loop in this thread: run on a fresh one.
            return asyncio.run(async_func(*args, **kwargs))

        if loop.is_running():
            # Handle nested event loops (e.g., in Jupyter)
            import nest_asyncio

            nest_asyncio.apply()

        return loop.run_until_complete(async_func(*args, **kwargs))

    return sync_wrapper


class ToolManager:
    """Instantiates every enabled tool and collects its callable functions.

    Attributes:
        config: The ``ToolManagerConfig`` this manager was built from.
        tool_fncs: Flat list of all tool functions, in discovery order.
        tool_dict: Mapping of tool name -> list of that tool's functions.
        langchain_tools: ``StructuredTool`` wrappers, one per tool function.
    """

    def __init__(self, config: ToolManagerConfig):
        self.config = config
        self.tool_fncs = []  # list of functions that should be turned into tools
        # Initialised here so get_tool_dict() never hits a missing attribute.
        self.tool_dict = {}
        self.langchain_tools = []

        self.populate_modules()

    def _get_tool_config(self) -> List[ToolConfig]:
        """Return the tool configs found on ``self.config``.

        An attribute qualifies when its name contains ``"config"`` and its
        value is a dataclass *instance*. Attributes are visited in ``dir()``
        (alphabetical) order.
        """
        tool_confs = []
        for attr_name in dir(self.config):
            # Check the name first so unrelated attributes are not evaluated.
            if "config" not in attr_name:
                continue
            candidate = getattr(self.config, attr_name)
            # is_dataclass() is also True for dataclass classes; only
            # instances can be set up as tools.
            if is_dataclass(candidate) and not isinstance(candidate, type):
                tool_confs.append(candidate)
        return tool_confs

    def _get_tool_fnc(self, tool_obj: LangToolBase) -> List[Callable]:
        """Return the plain callables exposed by ``tool_obj``.

        Entries that are fastmcp ``Tool`` objects are unwrapped to their
        underlying function (``.fn``); everything else is kept as-is.
        """
        return [
            fnc.fn if isinstance(fnc, Tool) else fnc
            for fnc in tool_obj.get_tool_fnc()
        ]

    def populate_modules(self):
        """instantiate all object with tools"""
        self.tool_fncs = []
        self.tool_dict = {}

        for tool_conf in self._get_tool_config():
            # Drops the last 6 characters of the config name
            # (presumably a "Config" suffix -- TODO confirm against get_name()).
            tool_name = tool_conf.get_name()[:-6]

            if not tool_conf.use_tool:
                logger.info(f"skipping tool:{tool_name}")
                continue

            logger.info(f"making tool:{tool_name}")
            fnc_list = self._get_tool_fnc(tool_conf.setup())
            self.tool_fncs.extend(fnc_list)
            # Two configs may share a name; accumulate instead of overwriting.
            self.tool_dict.setdefault(tool_name, []).extend(fnc_list)

        self._build_langchain_tools()

    def get_tool_fncs(self):
        """Return the flat list of all collected tool functions."""
        return self.tool_fncs

    def get_tool_dict(self) -> Dict[str, List[Callable]]:
        """Return the mapping of tool name -> list of that tool's functions.

        Only tools whose config has ``use_tool`` enabled are present.
        """
        return self.tool_dict

    def fnc_to_structool(self, func):
        """Wrap ``func`` in a LangChain ``StructuredTool``.

        Coroutine functions are registered with both a synchronous entry
        point (via :func:`async_to_sync`) and the original coroutine, so the
        tool can be invoked either way.
        """
        if inspect.iscoroutinefunction(func):
            return StructuredTool.from_function(
                func=async_to_sync(func),
                coroutine=func,
            )
        return StructuredTool.from_function(func=func)

    def _build_langchain_tools(self):
        """(Re)build ``self.langchain_tools`` from the collected functions."""
        self.langchain_tools = [
            self.fnc_to_structool(func) for func in self.get_tool_fncs()
        ]
        return self.langchain_tools

    def get_list_langchain_tools(self) -> List[StructuredTool]:
        """Return the ``StructuredTool`` wrappers built by ``populate_modules``."""
        return self.langchain_tools


if __name__ == "__main__":
    man: ToolManager = ToolManagerConfig().setup()
    for lang_tool in man.get_list_langchain_tools():
        print(lang_tool.name)