simple rag tool

This commit is contained in:
2025-10-11 14:13:57 +08:00
parent 852db6e6cb
commit a35a0914c4

View File

@@ -1,9 +1,18 @@
from dataclasses import dataclass, field
from typing import Type
import tyro
from mcp.server.fastmcp import FastMCP
from typing import List
import tyro
from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document
from lang_agent.rag.emb import QwenEmbeddings
mcp = FastMCP("Rag")
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class SimpleRagConfig:
@@ -15,7 +24,7 @@ class SimpleRagConfig:
api_key:str = "wrong-key"
"""api_key for model; for generic text splitting; give a wrong key"""
database_path:str = None
folder_path:str = None
"""path to local database"""
@@ -24,3 +33,25 @@ class SimpleRag:
self.config = config
self.emb = QwenEmbeddings(self.config.api_key,
self.config.model_name)
self.vec_store = FAISS.load_local(
folder_path=self.config.folder_path,
embeddings=self.emb,
allow_dangerous_deserialization=True # Required for LangChain >= 0.1.1
)
# self.retriever = self.vec_store.as_retriever(search_kwargs={"k":3})
@mcp.tool()
def retrieve(self, query:str):
retrieved_docs:List[Document] = self.vec_store.search(query, search_kwargs={"k":3})
serialized = "\n\n".join(
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
for doc in retrieved_docs
)
return serialized, retrieved_docs
if __name__ == "__main__":
config = tyro.cli(SimpleRagConfig)
rag = SimpleRag(config)
mcp.run(transport="stdio")