use another's async to sync

This commit is contained in:
2025-12-10 22:39:01 +08:00
parent a26eec0fb4
commit c04fe98025

View File

@@ -16,6 +16,8 @@ from lang_agent.dummy.calculator import CalculatorConfig
# from catering_end.lang_tool import CartToolConfig, CartTool # from catering_end.lang_tool import CartToolConfig, CartTool
from langchain_core.tools.structured import StructuredTool from langchain_core.tools.structured import StructuredTool
from lang_agent.components.client_tool_manager import ClientToolManager from lang_agent.components.client_tool_manager import ClientToolManager
from asgiref.sync import async_to_sync
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class ToolManagerConfig(InstantiateConfig): class ToolManagerConfig(InstantiateConfig):
@@ -31,32 +33,32 @@ class ToolManagerConfig(InstantiateConfig):
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig) calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
def async_to_sync(async_func: Callable) -> Callable: # def async_to_sync(async_func: Callable) -> Callable:
""" # """
Decorator that converts an async function to a sync function. # Decorator that converts an async function to a sync function.
Args: # Args:
async_func: The async function to convert # async_func: The async function to convert
Returns: # Returns:
A synchronous wrapper function # A synchronous wrapper function
""" # """
@functools.wraps(async_func) # @functools.wraps(async_func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any: # def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try: # try:
loop = asyncio.get_event_loop() # loop = asyncio.get_event_loop()
if loop.is_running(): # if loop.is_running():
# Handle nested event loops (e.g., in Jupyter) # # Handle nested event loops (e.g., in Jupyter)
import nest_asyncio # import nest_asyncio
nest_asyncio.apply() # nest_asyncio.apply()
return loop.run_until_complete(async_func(*args, **kwargs)) # return loop.run_until_complete(async_func(*args, **kwargs))
else: # else:
return loop.run_until_complete(async_func(*args, **kwargs)) # return loop.run_until_complete(async_func(*args, **kwargs))
except RuntimeError: # except RuntimeError:
# No event loop exists, create a new one # # No event loop exists, create a new one
return asyncio.run(async_func(*args, **kwargs)) # return asyncio.run(async_func(*args, **kwargs))
return sync_wrapper # return sync_wrapper
class ToolManager: class ToolManager:
@@ -132,8 +134,13 @@ class ToolManager:
def fnc_to_structool(self, func): def fnc_to_structool(self, func):
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
# Wrap async_to_sync result to preserve signature
sync_wrapper = async_to_sync(func)
@functools.wraps(func)
def sync_func(*args, **kwargs):
return sync_wrapper(*args, **kwargs)
return StructuredTool.from_function( return StructuredTool.from_function(
func=async_to_sync(func), func=sync_func,
coroutine=func) coroutine=func)
else: else:
return StructuredTool.from_function(func=func) return StructuredTool.from_function(func=func)
@@ -143,7 +150,11 @@ class ToolManager:
for func in self.get_tool_fncs(): for func in self.get_tool_fncs():
if isinstance(func, StructuredTool): if isinstance(func, StructuredTool):
if hasattr(func, 'coroutine') and func.coroutine is not None and (not hasattr(func, 'func') or func.func is None): if hasattr(func, 'coroutine') and func.coroutine is not None and (not hasattr(func, 'func') or func.func is None):
sync_func = async_to_sync(func.coroutine) # Wrap async_to_sync result to preserve signature
sync_wrapper = async_to_sync(func.coroutine)
@functools.wraps(func.coroutine)
def sync_func(*args, _wrapper=sync_wrapper, **kwargs):
return _wrapper(*args, **kwargs)
new_tool = StructuredTool( new_tool = StructuredTool(
name=func.name, name=func.name,
description=func.description, description=func.description,