From a35a0914c4e5630162c3195f45ff1f66f29ba2fa Mon Sep 17 00:00:00 2001
From: goulustis
Date: Sat, 11 Oct 2025 14:13:57 +0800
Subject: [PATCH] simple rag tool

---
Changes since v1 (this section is not part of the commit message):
- retrieve() calls similarity_search(query, k=3); v1 called
  search(query, search_kwargs={"k":3}), which omits the search_type
  argument that search() requires.
- The MCP tool is registered from the bound method in the __main__ block;
  v1 decorated the method in the class body, which exposes `self` as a
  tool input and never ties the tool to the SimpleRag instance.
- Dropped the duplicate `import tyro` and the commented-out retriever line;
  the new file ends with a newline.
- NOTE(review): line breaks and indentation were reconstructed from a
  whitespace-collapsed copy of v1. Verify the context lines against the
  pre-image (or apply with --ignore-whitespace). The post-image hash on the
  `index` line is still the v1 value.

 lang_agent/rag/simple.py | 44 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/lang_agent/rag/simple.py b/lang_agent/rag/simple.py
index 745facd..ed0fecb 100644
--- a/lang_agent/rag/simple.py
+++ b/lang_agent/rag/simple.py
@@ -1,9 +1,18 @@
 from dataclasses import dataclass, field
 from typing import Type
 import tyro
+from mcp.server.fastmcp import FastMCP
+from typing import List
+
+from langchain_community.vectorstores import FAISS
+from langchain_core.documents.base import Document
 
 from lang_agent.rag.emb import QwenEmbeddings
 
+
+# Shared MCP server instance; the __main__ block registers the tool and runs it.
+mcp = FastMCP("Rag")
+
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
 class SimpleRagConfig:
@@ -15,7 +24,7 @@ class SimpleRagConfig:
     api_key:str = "wrong-key"
     """api_key for model; for generic text splitting; give a wrong key"""
 
-    database_path:str = None
+    folder_path:str = None
     """path to local database"""
 
 
@@ -23,4 +32,35 @@ class SimpleRag:
     def __init__(self, config:SimpleRagConfig):
         self.config = config
         self.emb = QwenEmbeddings(self.config.api_key,
-                                  self.config.model_name)
\ No newline at end of file
+                                  self.config.model_name)
+        # NOTE(security): allow_dangerous_deserialization opts in to loading
+        # pickled data, so folder_path must only point at a trusted index.
+        self.vec_store = FAISS.load_local(
+            folder_path=self.config.folder_path,
+            embeddings=self.emb,
+            allow_dangerous_deserialization=True # Required for LangChain >= 0.1.1
+        )
+
+    def retrieve(self, query:str):
+        """Return the 3 stored documents most similar to `query`.
+
+        Returns a tuple of (text with one "Source/Content" entry per
+        document separated by blank lines, list of the retrieved documents).
+        """
+        # similarity_search takes k directly; VectorStore.search() would also
+        # require a search_type argument.
+        retrieved_docs:List[Document] = self.vec_store.similarity_search(query, k=3)
+        serialized = "\n\n".join(
+            (f"Source: {doc.metadata}\nContent: {doc.page_content}")
+            for doc in retrieved_docs
+        )
+        return serialized, retrieved_docs
+
+
+if __name__ == "__main__":
+    config = tyro.cli(SimpleRagConfig)
+    rag = SimpleRag(config)
+    # Register the bound method so the tool's inputs are only `query`;
+    # decorating the method in the class body would expose `self` as an input.
+    mcp.add_tool(rag.retrieve)
+    mcp.run(transport="stdio")