simple rag tool

This commit is contained in:
2025-10-11 14:13:57 +08:00
parent 852db6e6cb
commit a35a0914c4

View File

@@ -1,9 +1,18 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type from typing import Type
import tyro import tyro
from mcp.server.fastmcp import FastMCP
from typing import List
import tyro
from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document
from lang_agent.rag.emb import QwenEmbeddings from lang_agent.rag.emb import QwenEmbeddings
mcp = FastMCP("Rag")
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class SimpleRagConfig: class SimpleRagConfig:
@@ -15,7 +24,7 @@ class SimpleRagConfig:
api_key:str = "wrong-key" api_key:str = "wrong-key"
"""api_key for model; for generic text splitting; give a wrong key""" """api_key for model; for generic text splitting; give a wrong key"""
database_path:str = None folder_path:str = None
"""path to local database""" """path to local database"""
@@ -24,3 +33,25 @@ class SimpleRag:
self.config = config self.config = config
self.emb = QwenEmbeddings(self.config.api_key, self.emb = QwenEmbeddings(self.config.api_key,
self.config.model_name) self.config.model_name)
self.vec_store = FAISS.load_local(
folder_path=self.config.folder_path,
embeddings=self.emb,
allow_dangerous_deserialization=True # Required for LangChain >= 0.1.1
)
# self.retriever = self.vec_store.as_retriever(search_kwargs={"k":3})
# NOTE(review): this decorator is applied to the plain function at class
# definition time, so the registered tool presumably exposes `self` as a
# parameter; registering the bound method after construction (e.g.
# `mcp.add_tool(rag.retrieve)`) may be what is intended -- confirm.
@mcp.tool()
def retrieve(self, query:str):
    """Retrieve the stored documents most similar to ``query``.

    Args:
        query: Free-text query to look up in the vector store.

    Returns:
        A ``(serialized, retrieved_docs)`` tuple: ``serialized`` is one
        "Source: ...\\nContent: ..." entry per document, joined by blank
        lines; ``retrieved_docs`` is the list of matching documents.
    """
    # `VectorStore.search` needs an explicit `search_type` and forwards extra
    # keyword arguments verbatim, so the previous `search_kwargs={"k": 3}`
    # call could not run and never set the result count. Call the similarity
    # search directly and pass `k` as a keyword.
    retrieved_docs:List[Document] = self.vec_store.similarity_search(query, k=3)
    serialized = "\n\n".join(
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs
if __name__ == "__main__":
    # Parse the command line into a SimpleRagConfig and build the RAG
    # instance from it in one step.
    rag = SimpleRag(tyro.cli(SimpleRagConfig))
    # Serve the registered MCP tools over standard input/output.
    mcp.run(transport="stdio")