support async
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, List
|
||||
import functools
|
||||
from typing import Type, List, Callable, Any
|
||||
import tyro
|
||||
import json
|
||||
import inspect
|
||||
import asyncio
|
||||
import os.path as osp
|
||||
from loguru import logger
|
||||
@@ -30,6 +31,33 @@ class ToolManagerConfig(InstantiateConfig):
|
||||
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
||||
|
||||
|
||||
def async_to_sync(async_func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator that converts an async function to a sync function.
|
||||
|
||||
Args:
|
||||
async_func: The async function to convert
|
||||
|
||||
Returns:
|
||||
A synchronous wrapper function
|
||||
"""
|
||||
@functools.wraps(async_func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Handle nested event loops (e.g., in Jupyter)
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
else:
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one
|
||||
return asyncio.run(async_func(*args, **kwargs))
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config:ToolManagerConfig):
|
||||
self.config = config
|
||||
@@ -74,4 +102,18 @@ class ToolManager:
|
||||
|
||||
|
||||
def get_langchain_tools(self):
|
||||
return [StructuredTool.from_function(func=func) for func in self.get_tool_fncs()]
|
||||
out = []
|
||||
for func in self.get_tool_fncs():
|
||||
if inspect.iscoroutinefunction(func):
|
||||
out.append(
|
||||
StructuredTool.from_function(
|
||||
func=async_to_sync(func),
|
||||
coroutine=func)
|
||||
)
|
||||
else:
|
||||
out.append(
|
||||
StructuredTool.from_function(func=func)
|
||||
)
|
||||
|
||||
return out
|
||||
# return [StructuredTool.from_function(func=func) for func in self.get_tool_fncs()]
|
||||
Reference in New Issue
Block a user