diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py index 12ae52b..83c8d14 100644 --- a/lang_agent/tool_manager.py +++ b/lang_agent/tool_manager.py @@ -89,7 +89,6 @@ class ToolManager: """instantiate all object with tools""" self.tool_fncs = [] - self.tool_dict = {} tool_configs = self._get_tool_config() for tool_conf in tool_configs: tool_name = tool_conf.get_name()[:-6] @@ -97,9 +96,10 @@ class ToolManager: logger.info(f"making tool:{tool_name}") fnc_list = self._get_tool_fnc(tool_conf.setup()) self.tool_fncs.extend(fnc_list) - self.tool_dict[tool_name] = fnc_list else: logger.info(f"skipping tool:{tool_name}") + + self._build_langchain_tools() def get_tool_fncs(self): @@ -118,13 +118,12 @@ class ToolManager: else: return StructuredTool.from_function(func=func) + def _build_langchain_tools(self): + self.langchain_tools = [] + for func in self.get_tool_fncs(): + self.langchain_tools.append(self.fnc_to_structool(func)) + + return self.langchain_tools def get_list_langchain_tools(self): - out = [] - for func in self.get_tool_fncs(): - out.append(self.fnc_to_structool(func)) - - return out - - def get_dict_langchain_tools(self): - return jax.tree_util.tree_map(self.fnc_to_structool, self.tool_dict) + return self.langchain_tools \ No newline at end of file