auto get tool list
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# https://gofastmcp.com/patterns/decorating-methods
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, Literal
|
||||
import tyro
|
||||
from fastmcp import FastMCP
|
||||
@@ -9,7 +9,8 @@ from loguru import logger
|
||||
|
||||
from lang_agent.rag.simple import SimpleRagConfig
|
||||
from lang_agent.base import LangToolBase
|
||||
from lang_agent.config import InstantiateConfig
|
||||
from lang_agent.config import InstantiateConfig, ToolConfig
|
||||
from lang_agent.dummy.calculator import Calculator, CalculatorConfig
|
||||
|
||||
|
||||
from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
@@ -30,17 +31,19 @@ class MCPServerConfig(InstantiateConfig):
|
||||
transport:Literal["stdio", "sse", "streamable-http"] = "streamable-http"
|
||||
"""transport method"""
|
||||
|
||||
# tool configs here
|
||||
# tool configs here;
|
||||
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
|
||||
|
||||
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
||||
|
||||
calc_config: CalculatorConfig = field(default_factory=CalculatorConfig)
|
||||
|
||||
|
||||
class MCPServer:
|
||||
def __init__(self, config: MCPServerConfig):
|
||||
self.config = config
|
||||
self.mcp = FastMCP(self.config.server_name)
|
||||
self.register_mcp_functions()
|
||||
# self.register_mcp_functions()
|
||||
|
||||
def _register_tool_fnc(self, tool:LangToolBase):
|
||||
for fnc in tool.get_tool_fnc():
|
||||
@@ -48,6 +51,15 @@ class MCPServer:
|
||||
fnc = fnc.fn
|
||||
self.mcp.tool(fnc)
|
||||
|
||||
def _get_tool_config(self):
|
||||
tool_confs = []
|
||||
for e in dir(self.config):
|
||||
el = getattr(self.config, e)
|
||||
if ("config" in e) and is_dataclass(el):
|
||||
tool_confs.append(el)
|
||||
|
||||
return tool_confs
|
||||
|
||||
def register_mcp_functions(self):
|
||||
|
||||
# NOTE: add config here for new tools; too stupid to do this automatically
|
||||
@@ -76,3 +88,9 @@ class MCPServer:
|
||||
self.mcp.run(transport=self.config.transport,
|
||||
host=self.config.host,
|
||||
port=self.config.port)
|
||||
|
||||
if __name__ == "__main__":
|
||||
conf:MCPServer = MCPServerConfig().setup()
|
||||
tool_conf = conf._get_tool_config()
|
||||
for e in tool_conf:
|
||||
print(e)
|
||||
Reference in New Issue
Block a user