From 2b6ffffafa2ab86af50ced36cbfd1b2dd565b5b8 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 13 Oct 2025 19:21:23 +0800 Subject: [PATCH] make rag into tool --- lang_agent/rag/simple.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lang_agent/rag/simple.py b/lang_agent/rag/simple.py index 2c7d429..d548745 100644 --- a/lang_agent/rag/simple.py +++ b/lang_agent/rag/simple.py @@ -1,22 +1,22 @@ from dataclasses import dataclass, field -from typing import Type +from typing import Type, List import tyro from mcp.server.fastmcp import FastMCP -from typing import List import tyro from langchain_community.vectorstores import FAISS from langchain_core.documents.base import Document from lang_agent.rag.emb import QwenEmbeddings -from lang_agent.config import InstantiateConfig +from lang_agent.config import ToolConfig +from lang_agent.base import LangToolBase mcp = FastMCP("Rag") @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class SimpleRagConfig(InstantiateConfig): +class SimpleRagConfig(ToolConfig): _target: Type = field(default_factory=lambda: SimpleRag) model_name:str = "text-embedding-v4" @@ -29,7 +29,7 @@ class SimpleRagConfig(InstantiateConfig): """path to local database""" -class SimpleRag: +class SimpleRag(LangToolBase): def __init__(self, config:SimpleRagConfig): self.config = config self.emb = QwenEmbeddings(self.config.api_key, @@ -64,6 +64,9 @@ class SimpleRag: for doc in retrieved_docs ) return serialized, retrieved_docs + + def get_tool_fnc(self): + return [self.retrieve] if __name__ == "__main__":