diff --git a/lang_agent/components/tool_manager.py b/lang_agent/components/tool_manager.py index b95717e..5d6c5c7 100644 --- a/lang_agent/components/tool_manager.py +++ b/lang_agent/components/tool_manager.py @@ -16,7 +16,7 @@ from lang_agent.dummy.calculator import CalculatorConfig # from catering_end.lang_tool import CartToolConfig, CartTool from langchain_core.tools.structured import StructuredTool from lang_agent.components.client_tool_manager import ClientToolManager -from asgiref.sync import async_to_sync +# from asgiref.sync import async_to_sync # NOTE: THIS SHT DOES NOT WORK @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass @@ -33,32 +33,32 @@ class ToolManagerConfig(InstantiateConfig): calc_config: CalculatorConfig = field(default_factory=CalculatorConfig) -# def async_to_sync(async_func: Callable) -> Callable: -# """ -# Decorator that converts an async function to a sync function. +def async_to_sync(async_func: Callable) -> Callable: + """ + Decorator that converts an async function to a sync function. -# Args: -# async_func: The async function to convert + Args: + async_func: The async function to convert -# Returns: -# A synchronous wrapper function -# """ -# @functools.wraps(async_func) -# def sync_wrapper(*args: Any, **kwargs: Any) -> Any: -# try: -# loop = asyncio.get_event_loop() -# if loop.is_running(): -# # Handle nested event loops (e.g., in Jupyter) -# import nest_asyncio -# nest_asyncio.apply() -# return loop.run_until_complete(async_func(*args, **kwargs)) -# else: -# return loop.run_until_complete(async_func(*args, **kwargs)) -# except RuntimeError: -# # No event loop exists, create a new one -# return asyncio.run(async_func(*args, **kwargs)) + Returns: + A synchronous wrapper function + """ + @functools.wraps(async_func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Handle nested event loops (e.g., in Jupyter) + import nest_asyncio + nest_asyncio.apply() + return loop.run_until_complete(async_func(*args, **kwargs)) + else: + return loop.run_until_complete(async_func(*args, **kwargs)) + except RuntimeError: + # No event loop exists, create a new one + return asyncio.run(async_func(*args, **kwargs)) -# return sync_wrapper + return sync_wrapper class ToolManager: