From 8f6b181ff834c6cb134b36e0a3641e4145fc75e2 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 11:52:40 +0800 Subject: [PATCH] make tool_dict for better usage --- lang_agent/tool_manager.py | 47 +++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py index 0b21f0a..12ae52b 100644 --- a/lang_agent/tool_manager.py +++ b/lang_agent/tool_manager.py @@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool from lang_agent.config import InstantiateConfig, ToolConfig from lang_agent.base import LangToolBase -## import tool configs from lang_agent.rag.simple import SimpleRagConfig from lang_agent.dummy.calculator import CalculatorConfig from catering_end.lang_tool import CartToolConfig, CartTool -# from langchain.tools import StructuredTool from langchain_core.tools.structured import StructuredTool +import jax @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass class ToolManagerConfig(InstantiateConfig): _target: Type = field(default_factory=lambda: ToolManager) - # tool configs here; + # tool configs here; MUST HAVE 'config' in name and must be dataclass rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig) cart_config: CartToolConfig = field(default_factory=CartToolConfig) @@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable: return sync_wrapper + class ToolManager: def __init__(self, config:ToolManagerConfig): self.config = config @@ -89,31 +89,42 @@ class ToolManager: """instantiate all object with tools""" self.tool_fncs = [] + self.tool_dict = {} tool_configs = self._get_tool_config() for tool_conf in tool_configs: + tool_name = tool_conf.get_name()[:-6] if tool_conf.use_tool: - logger.info(f"making tool:{tool_conf._target}") - self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup())) + logger.info(f"making tool:{tool_name}") + fnc_list = self._get_tool_fnc(tool_conf.setup()) + self.tool_fncs.extend(fnc_list) + self.tool_dict[tool_name] = fnc_list else: - logger.info(f"skipping tool:{tool_conf._target}") + logger.info(f"skipping tool:{tool_name}") def get_tool_fncs(self): return self.tool_fncs + def get_tool_dict(self): + return self.tool_dict + + + def fnc_to_structool(self, func): + if inspect.iscoroutinefunction(func): + return StructuredTool.from_function( + func=async_to_sync(func), + coroutine=func) + + else: + return StructuredTool.from_function(func=func) + - def get_langchain_tools(self): + def get_list_langchain_tools(self): out = [] for func in self.get_tool_fncs(): - if inspect.iscoroutinefunction(func): - out.append( - StructuredTool.from_function( - func=async_to_sync(func), - coroutine=func) - ) - else: - out.append( - StructuredTool.from_function(func=func) - ) + out.append(self.fnc_to_structool(func)) - return out \ No newline at end of file + return out + + def get_dict_langchain_tools(self): + return jax.tree_util.tree_map(self.fnc_to_structool, self.tool_dict)