support cartool

This commit is contained in:
2025-10-14 18:12:27 +08:00
parent a2954e1724
commit 783864a3f8

View File

@@ -4,6 +4,7 @@ from typing import Type, Literal
import tyro import tyro
from fastmcp import FastMCP from fastmcp import FastMCP
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastmcp.tools.tool import FunctionTool
from loguru import logger from loguru import logger
from lang_agent.rag.simple import SimpleRagConfig from lang_agent.rag.simple import SimpleRagConfig
@@ -11,6 +12,8 @@ from lang_agent.base import LangToolBase
from lang_agent.config import InstantiateConfig from lang_agent.config import InstantiateConfig
from catering_end.lang_tool import CartToolConfig, CartTool
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class MCPServerConfig(InstantiateConfig): class MCPServerConfig(InstantiateConfig):
@@ -30,6 +33,8 @@ class MCPServerConfig(InstantiateConfig):
# tool configs here # tool configs here
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig) rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
class MCPServer: class MCPServer:
def __init__(self, config: MCPServerConfig): def __init__(self, config: MCPServerConfig):
@@ -39,12 +44,14 @@ class MCPServer:
def _register_tool_fnc(self, tool:LangToolBase): def _register_tool_fnc(self, tool:LangToolBase):
for fnc in tool.get_tool_fnc(): for fnc in tool.get_tool_fnc():
if isinstance(fnc, FunctionTool):
fnc = fnc.fn
self.mcp.tool(fnc) self.mcp.tool(fnc)
def register_mcp_functions(self): def register_mcp_functions(self):
# NOTE: add config here for new tools; too stupid to do this automatically # NOTE: add config here for new tools; too stupid to do this automatically
tool_configs = [self.config.rag_config] tool_configs = [self.config.rag_config, self.config.cart_config]
for tool_conf in tool_configs: for tool_conf in tool_configs:
if tool_conf.use_tool: if tool_conf.use_tool:
logger.info(f"using tool:{tool_conf._target}") logger.info(f"using tool:{tool_conf._target}")