make rag into tool

This commit is contained in:
2025-10-13 19:21:23 +08:00
parent 1f9f8cb76c
commit 2b6ffffafa

View File

@@ -1,22 +1,22 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type from typing import Type, List
import tyro import tyro
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from typing import List
import tyro import tyro
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document from langchain_core.documents.base import Document
from lang_agent.rag.emb import QwenEmbeddings from lang_agent.rag.emb import QwenEmbeddings
from lang_agent.config import InstantiateConfig from lang_agent.config import ToolConfig
from lang_agent.base import LangToolBase
mcp = FastMCP("Rag") mcp = FastMCP("Rag")
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class SimpleRagConfig(InstantiateConfig): class SimpleRagConfig(ToolConfig):
_target: Type = field(default_factory=lambda: SimpleRag) _target: Type = field(default_factory=lambda: SimpleRag)
model_name:str = "text-embedding-v4" model_name:str = "text-embedding-v4"
@@ -29,7 +29,7 @@ class SimpleRagConfig(InstantiateConfig):
"""path to local database""" """path to local database"""
class SimpleRag: class SimpleRag(LangToolBase):
def __init__(self, config:SimpleRagConfig): def __init__(self, config:SimpleRagConfig):
self.config = config self.config = config
self.emb = QwenEmbeddings(self.config.api_key, self.emb = QwenEmbeddings(self.config.api_key,
@@ -65,6 +65,9 @@ class SimpleRag:
) )
return serialized, retrieved_docs return serialized, retrieved_docs
def get_tool_fnc(self):
return [self.retrieve]
if __name__ == "__main__": if __name__ == "__main__":
# config = tyro.cli(SimpleRagConfig) # config = tyro.cli(SimpleRagConfig)