use another's async to sync
This commit is contained in:
@@ -16,6 +16,8 @@ from lang_agent.dummy.calculator import CalculatorConfig
|
|||||||
# from catering_end.lang_tool import CartToolConfig, CartTool
|
# from catering_end.lang_tool import CartToolConfig, CartTool
|
||||||
from langchain_core.tools.structured import StructuredTool
|
from langchain_core.tools.structured import StructuredTool
|
||||||
from lang_agent.components.client_tool_manager import ClientToolManager
|
from lang_agent.components.client_tool_manager import ClientToolManager
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolManagerConfig(InstantiateConfig):
|
class ToolManagerConfig(InstantiateConfig):
|
||||||
@@ -31,32 +33,32 @@ class ToolManagerConfig(InstantiateConfig):
|
|||||||
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
||||||
|
|
||||||
|
|
||||||
def async_to_sync(async_func: Callable) -> Callable:
|
# def async_to_sync(async_func: Callable) -> Callable:
|
||||||
"""
|
# """
|
||||||
Decorator that converts an async function to a sync function.
|
# Decorator that converts an async function to a sync function.
|
||||||
|
|
||||||
Args:
|
# Args:
|
||||||
async_func: The async function to convert
|
# async_func: The async function to convert
|
||||||
|
|
||||||
Returns:
|
# Returns:
|
||||||
A synchronous wrapper function
|
# A synchronous wrapper function
|
||||||
"""
|
# """
|
||||||
@functools.wraps(async_func)
|
# @functools.wraps(async_func)
|
||||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
# def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
try:
|
# try:
|
||||||
loop = asyncio.get_event_loop()
|
# loop = asyncio.get_event_loop()
|
||||||
if loop.is_running():
|
# if loop.is_running():
|
||||||
# Handle nested event loops (e.g., in Jupyter)
|
# # Handle nested event loops (e.g., in Jupyter)
|
||||||
import nest_asyncio
|
# import nest_asyncio
|
||||||
nest_asyncio.apply()
|
# nest_asyncio.apply()
|
||||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
# return loop.run_until_complete(async_func(*args, **kwargs))
|
||||||
else:
|
# else:
|
||||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
# return loop.run_until_complete(async_func(*args, **kwargs))
|
||||||
except RuntimeError:
|
# except RuntimeError:
|
||||||
# No event loop exists, create a new one
|
# # No event loop exists, create a new one
|
||||||
return asyncio.run(async_func(*args, **kwargs))
|
# return asyncio.run(async_func(*args, **kwargs))
|
||||||
|
|
||||||
return sync_wrapper
|
# return sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
@@ -132,8 +134,13 @@ class ToolManager:
|
|||||||
|
|
||||||
def fnc_to_structool(self, func):
|
def fnc_to_structool(self, func):
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
|
# Wrap async_to_sync result to preserve signature
|
||||||
|
sync_wrapper = async_to_sync(func)
|
||||||
|
@functools.wraps(func)
|
||||||
|
def sync_func(*args, **kwargs):
|
||||||
|
return sync_wrapper(*args, **kwargs)
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
func=async_to_sync(func),
|
func=sync_func,
|
||||||
coroutine=func)
|
coroutine=func)
|
||||||
else:
|
else:
|
||||||
return StructuredTool.from_function(func=func)
|
return StructuredTool.from_function(func=func)
|
||||||
@@ -143,7 +150,11 @@ class ToolManager:
|
|||||||
for func in self.get_tool_fncs():
|
for func in self.get_tool_fncs():
|
||||||
if isinstance(func, StructuredTool):
|
if isinstance(func, StructuredTool):
|
||||||
if hasattr(func, 'coroutine') and func.coroutine is not None and (not hasattr(func, 'func') or func.func is None):
|
if hasattr(func, 'coroutine') and func.coroutine is not None and (not hasattr(func, 'func') or func.func is None):
|
||||||
sync_func = async_to_sync(func.coroutine)
|
# Wrap async_to_sync result to preserve signature
|
||||||
|
sync_wrapper = async_to_sync(func.coroutine)
|
||||||
|
@functools.wraps(func.coroutine)
|
||||||
|
def sync_func(*args, _wrapper=sync_wrapper, **kwargs):
|
||||||
|
return _wrapper(*args, **kwargs)
|
||||||
new_tool = StructuredTool(
|
new_tool = StructuredTool(
|
||||||
name=func.name,
|
name=func.name,
|
||||||
description=func.description,
|
description=func.description,
|
||||||
|
|||||||
Reference in New Issue
Block a user