auto get tool list

This commit is contained in:
2025-10-14 20:38:22 +08:00
parent 0935c0ff0b
commit cf7a686fe7

View File

@@ -1,5 +1,5 @@
# https://gofastmcp.com/patterns/decorating-methods # https://gofastmcp.com/patterns/decorating-methods
from dataclasses import dataclass, field from dataclasses import dataclass, field, is_dataclass
from typing import Type, Literal from typing import Type, Literal
import tyro import tyro
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -9,7 +9,8 @@ from loguru import logger
from lang_agent.rag.simple import SimpleRagConfig from lang_agent.rag.simple import SimpleRagConfig
from lang_agent.base import LangToolBase from lang_agent.base import LangToolBase
from lang_agent.config import InstantiateConfig from lang_agent.config import InstantiateConfig, ToolConfig
from lang_agent.dummy.calculator import Calculator, CalculatorConfig
from catering_end.lang_tool import CartToolConfig, CartTool from catering_end.lang_tool import CartToolConfig, CartTool
@@ -30,17 +31,19 @@ class MCPServerConfig(InstantiateConfig):
transport:Literal["stdio", "sse", "streamable-http"] = "streamable-http" transport:Literal["stdio", "sse", "streamable-http"] = "streamable-http"
"""transport method""" """transport method"""
# tool configs here # tool configs here;
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig) rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
cart_config: CartToolConfig = field(default_factory=CartToolConfig) cart_config: CartToolConfig = field(default_factory=CartToolConfig)
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
class MCPServer: class MCPServer:
def __init__(self, config: MCPServerConfig): def __init__(self, config: MCPServerConfig):
self.config = config self.config = config
self.mcp = FastMCP(self.config.server_name) self.mcp = FastMCP(self.config.server_name)
self.register_mcp_functions() # self.register_mcp_functions()
def _register_tool_fnc(self, tool:LangToolBase): def _register_tool_fnc(self, tool:LangToolBase):
for fnc in tool.get_tool_fnc(): for fnc in tool.get_tool_fnc():
@@ -48,6 +51,15 @@ class MCPServer:
fnc = fnc.fn fnc = fnc.fn
self.mcp.tool(fnc) self.mcp.tool(fnc)
def _get_tool_config(self):
tool_confs = []
for e in dir(self.config):
el = getattr(self.config, e)
if ("config" in e) and is_dataclass(el):
tool_confs.append(el)
return tool_confs
def register_mcp_functions(self): def register_mcp_functions(self):
# NOTE: add config here for new tools; too stupid to do this automatically # NOTE: add config here for new tools; too stupid to do this automatically
@@ -76,3 +88,9 @@ class MCPServer:
self.mcp.run(transport=self.config.transport, self.mcp.run(transport=self.config.transport,
host=self.config.host, host=self.config.host,
port=self.config.port) port=self.config.port)
if __name__ == "__main__":
conf:MCPServer = MCPServerConfig().setup()
tool_conf = conf._get_tool_config()
for e in tool_conf:
print(e)