make tool_dict for better usage

This commit is contained in:
2025-10-29 11:52:40 +08:00
parent 666e0c4d23
commit 8f6b181ff8

View File

@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
from lang_agent.config import InstantiateConfig, ToolConfig
from lang_agent.base import LangToolBase
## import tool configs
from lang_agent.rag.simple import SimpleRagConfig
from lang_agent.dummy.calculator import CalculatorConfig
from catering_end.lang_tool import CartToolConfig, CartTool
# from langchain.tools import StructuredTool
from langchain_core.tools.structured import StructuredTool
import jax
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class ToolManagerConfig(InstantiateConfig):
_target: Type = field(default_factory=lambda: ToolManager)
# tool configs here;
# tool configs here; MUST HAVE 'config' in name and must be dataclass
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
return sync_wrapper
class ToolManager:
def __init__(self, config:ToolManagerConfig):
self.config = config
@@ -89,31 +89,42 @@ class ToolManager:
"""instantiate all object with tools"""
self.tool_fncs = []
self.tool_dict = {}
tool_configs = self._get_tool_config()
for tool_conf in tool_configs:
tool_name = tool_conf.get_name()[:-6]
if tool_conf.use_tool:
logger.info(f"making tool:{tool_conf._target}")
self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
logger.info(f"making tool:{tool_name}")
fnc_list = self._get_tool_fnc(tool_conf.setup())
self.tool_fncs.extend(fnc_list)
self.tool_dict[tool_name] = fnc_list
else:
logger.info(f"skipping tool:{tool_conf._target}")
logger.info(f"skipping tool:{tool_name}")
def get_tool_fncs(self):
return self.tool_fncs
def get_tool_dict(self):
return self.tool_dict
def fnc_to_structool(self, func):
if inspect.iscoroutinefunction(func):
return StructuredTool.from_function(
func=async_to_sync(func),
coroutine=func)
else:
return StructuredTool.from_function(func=func)
def get_langchain_tools(self):
def get_list_langchain_tools(self):
out = []
for func in self.get_tool_fncs():
if inspect.iscoroutinefunction(func):
out.append(
StructuredTool.from_function(
func=async_to_sync(func),
coroutine=func)
)
else:
out.append(
StructuredTool.from_function(func=func)
)
out.append(self.fnc_to_structool(func))
return out
return out
def get_dict_langchain_tools(self):
return jax.tree_util.tree_map(self.fnc_to_structool, self.tool_dict)