diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py index ce97914..41151f7 100644 --- a/lang_agent/tool_manager.py +++ b/lang_agent/tool_manager.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, field, is_dataclass -from typing import Type, List +import functools +from typing import Type, List, Callable, Any import tyro -import json +import inspect import asyncio import os.path as osp from loguru import logger @@ -30,6 +31,33 @@ class ToolManagerConfig(InstantiateConfig): calc_config: CalculatorConfig = field(default_factory=CalculatorConfig) +def async_to_sync(async_func: Callable) -> Callable: + """ + Decorator that converts an async function to a sync function. + + Args: + async_func: The async function to convert + + Returns: + A synchronous wrapper function + """ + @functools.wraps(async_func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Handle nested event loops (e.g., in Jupyter) + import nest_asyncio + nest_asyncio.apply() + return loop.run_until_complete(async_func(*args, **kwargs)) + else: + return loop.run_until_complete(async_func(*args, **kwargs)) + except RuntimeError: + # No event loop exists, create a new one + return asyncio.run(async_func(*args, **kwargs)) + + return sync_wrapper + class ToolManager: def __init__(self, config:ToolManagerConfig): self.config = config @@ -74,4 +102,18 @@ class ToolManager: def get_langchain_tools(self): - return [StructuredTool.from_function(func=func) for func in self.get_tool_fncs()] \ No newline at end of file + out = [] + for func in self.get_tool_fncs(): + if inspect.iscoroutinefunction(func): + out.append( + StructuredTool.from_function( + func=async_to_sync(func), + coroutine=func) + ) + else: + out.append( + StructuredTool.from_function(func=func) + ) + + return out + # return [StructuredTool.from_function(func=func) for func in self.get_tool_fncs()] \ No newline at end of file