make rag into tool

This commit is contained in:
2025-10-13 19:21:23 +08:00
parent 1f9f8cb76c
commit 2b6ffffafa

View File

@@ -1,22 +1,22 @@
from dataclasses import dataclass, field
from typing import Type
from typing import Type, List
import tyro
from mcp.server.fastmcp import FastMCP
from typing import List
import tyro
from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document
from lang_agent.rag.emb import QwenEmbeddings
from lang_agent.config import InstantiateConfig
from lang_agent.config import ToolConfig
from lang_agent.base import LangToolBase
mcp = FastMCP("Rag")
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class SimpleRagConfig(InstantiateConfig):
class SimpleRagConfig(ToolConfig):
_target: Type = field(default_factory=lambda: SimpleRag)
model_name:str = "text-embedding-v4"
@@ -29,7 +29,7 @@ class SimpleRagConfig(InstantiateConfig):
"""path to local database"""
class SimpleRag:
class SimpleRag(LangToolBase):
def __init__(self, config:SimpleRagConfig):
self.config = config
self.emb = QwenEmbeddings(self.config.api_key,
@@ -64,6 +64,9 @@ class SimpleRag:
for doc in retrieved_docs
)
return serialized, retrieved_docs
def get_tool_fnc(self):
return [self.retrieve]
if __name__ == "__main__":