support async
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, List
|
import functools
|
||||||
|
from typing import Type, List, Callable, Any
|
||||||
import tyro
|
import tyro
|
||||||
import json
|
import inspect
|
||||||
import asyncio
|
import asyncio
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -30,6 +31,33 @@ class ToolManagerConfig(InstantiateConfig):
|
|||||||
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def async_to_sync(async_func: Callable) -> Callable:
|
||||||
|
"""
|
||||||
|
Decorator that converts an async function to a sync function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_func: The async function to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A synchronous wrapper function
|
||||||
|
"""
|
||||||
|
@functools.wraps(async_func)
|
||||||
|
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
# Handle nested event loops (e.g., in Jupyter)
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||||
|
else:
|
||||||
|
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop exists, create a new one
|
||||||
|
return asyncio.run(async_func(*args, **kwargs))
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
def __init__(self, config:ToolManagerConfig):
|
def __init__(self, config:ToolManagerConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -74,4 +102,18 @@ class ToolManager:
|
|||||||
|
|
||||||
|
|
||||||
def get_langchain_tools(self):
|
def get_langchain_tools(self):
|
||||||
return [StructuredTool.from_function(func=func) for func in self.get_tool_fncs()]
|
out = []
|
||||||
|
for func in self.get_tool_fncs():
|
||||||
|
if inspect.iscoroutinefunction(func):
|
||||||
|
out.append(
|
||||||
|
StructuredTool.from_function(
|
||||||
|
func=async_to_sync(func),
|
||||||
|
coroutine=func)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out.append(
|
||||||
|
StructuredTool.from_function(func=func)
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
# return [StructuredTool.from_function(func=func) for func in self.get_tool_fncs()]
|
||||||
Reference in New Issue
Block a user