make rag into tool
This commit is contained in:
@@ -1,22 +1,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
from typing import Type, List
|
||||
import tyro
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from typing import List
|
||||
import tyro
|
||||
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_core.documents.base import Document
|
||||
|
||||
from lang_agent.rag.emb import QwenEmbeddings
|
||||
from lang_agent.config import InstantiateConfig
|
||||
from lang_agent.config import ToolConfig
|
||||
from lang_agent.base import LangToolBase
|
||||
|
||||
|
||||
mcp = FastMCP("Rag")
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class SimpleRagConfig(InstantiateConfig):
|
||||
class SimpleRagConfig(ToolConfig):
|
||||
_target: Type = field(default_factory=lambda: SimpleRag)
|
||||
|
||||
model_name:str = "text-embedding-v4"
|
||||
@@ -29,7 +29,7 @@ class SimpleRagConfig(InstantiateConfig):
|
||||
"""path to local database"""
|
||||
|
||||
|
||||
class SimpleRag:
|
||||
class SimpleRag(LangToolBase):
|
||||
def __init__(self, config:SimpleRagConfig):
|
||||
self.config = config
|
||||
self.emb = QwenEmbeddings(self.config.api_key,
|
||||
@@ -64,6 +64,9 @@ class SimpleRag:
|
||||
for doc in retrieved_docs
|
||||
)
|
||||
return serialized, retrieved_docs
|
||||
|
||||
def get_tool_fnc(self):
|
||||
return [self.retrieve]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user