make tool_dict for better usage
This commit is contained in:
@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
|
|||||||
from lang_agent.config import InstantiateConfig, ToolConfig
|
from lang_agent.config import InstantiateConfig, ToolConfig
|
||||||
from lang_agent.base import LangToolBase
|
from lang_agent.base import LangToolBase
|
||||||
|
|
||||||
## import tool configs
|
|
||||||
from lang_agent.rag.simple import SimpleRagConfig
|
from lang_agent.rag.simple import SimpleRagConfig
|
||||||
from lang_agent.dummy.calculator import CalculatorConfig
|
from lang_agent.dummy.calculator import CalculatorConfig
|
||||||
from catering_end.lang_tool import CartToolConfig, CartTool
|
from catering_end.lang_tool import CartToolConfig, CartTool
|
||||||
|
|
||||||
# from langchain.tools import StructuredTool
|
|
||||||
from langchain_core.tools.structured import StructuredTool
|
from langchain_core.tools.structured import StructuredTool
|
||||||
|
import jax
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolManagerConfig(InstantiateConfig):
|
class ToolManagerConfig(InstantiateConfig):
|
||||||
_target: Type = field(default_factory=lambda: ToolManager)
|
_target: Type = field(default_factory=lambda: ToolManager)
|
||||||
|
|
||||||
# tool configs here;
|
# tool configs here; MUST HAVE 'config' in name and must be dataclass
|
||||||
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
|
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
|
||||||
|
|
||||||
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
||||||
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
|
|||||||
|
|
||||||
return sync_wrapper
|
return sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
def __init__(self, config:ToolManagerConfig):
|
def __init__(self, config:ToolManagerConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -89,31 +89,42 @@ class ToolManager:
|
|||||||
"""instantiate all object with tools"""
|
"""instantiate all object with tools"""
|
||||||
|
|
||||||
self.tool_fncs = []
|
self.tool_fncs = []
|
||||||
|
self.tool_dict = {}
|
||||||
tool_configs = self._get_tool_config()
|
tool_configs = self._get_tool_config()
|
||||||
for tool_conf in tool_configs:
|
for tool_conf in tool_configs:
|
||||||
|
tool_name = tool_conf.get_name()[:-6]
|
||||||
if tool_conf.use_tool:
|
if tool_conf.use_tool:
|
||||||
logger.info(f"making tool:{tool_conf._target}")
|
logger.info(f"making tool:{tool_name}")
|
||||||
self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
|
fnc_list = self._get_tool_fnc(tool_conf.setup())
|
||||||
|
self.tool_fncs.extend(fnc_list)
|
||||||
|
self.tool_dict[tool_name] = fnc_list
|
||||||
else:
|
else:
|
||||||
logger.info(f"skipping tool:{tool_conf._target}")
|
logger.info(f"skipping tool:{tool_name}")
|
||||||
|
|
||||||
|
|
||||||
def get_tool_fncs(self):
|
def get_tool_fncs(self):
|
||||||
return self.tool_fncs
|
return self.tool_fncs
|
||||||
|
|
||||||
|
def get_tool_dict(self):
|
||||||
|
return self.tool_dict
|
||||||
|
|
||||||
def get_langchain_tools(self):
|
|
||||||
out = []
|
def fnc_to_structool(self, func):
|
||||||
for func in self.get_tool_fncs():
|
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
out.append(
|
return StructuredTool.from_function(
|
||||||
StructuredTool.from_function(
|
|
||||||
func=async_to_sync(func),
|
func=async_to_sync(func),
|
||||||
coroutine=func)
|
coroutine=func)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
out.append(
|
return StructuredTool.from_function(func=func)
|
||||||
StructuredTool.from_function(func=func)
|
|
||||||
)
|
|
||||||
|
def get_list_langchain_tools(self):
|
||||||
|
out = []
|
||||||
|
for func in self.get_tool_fncs():
|
||||||
|
out.append(self.fnc_to_structool(func))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def get_dict_langchain_tools(self):
|
||||||
|
return jax.tree_util.tree_map(self.fnc_to_structool, self.tool_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user