make rag into tool
This commit is contained in:
@@ -1,22 +1,22 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Type
|
from typing import Type, List
|
||||||
import tyro
|
import tyro
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
from typing import List
|
|
||||||
import tyro
|
import tyro
|
||||||
|
|
||||||
from langchain_community.vectorstores import FAISS
|
from langchain_community.vectorstores import FAISS
|
||||||
from langchain_core.documents.base import Document
|
from langchain_core.documents.base import Document
|
||||||
|
|
||||||
from lang_agent.rag.emb import QwenEmbeddings
|
from lang_agent.rag.emb import QwenEmbeddings
|
||||||
from lang_agent.config import InstantiateConfig
|
from lang_agent.config import ToolConfig
|
||||||
|
from lang_agent.base import LangToolBase
|
||||||
|
|
||||||
|
|
||||||
mcp = FastMCP("Rag")
|
mcp = FastMCP("Rag")
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimpleRagConfig(InstantiateConfig):
|
class SimpleRagConfig(ToolConfig):
|
||||||
_target: Type = field(default_factory=lambda: SimpleRag)
|
_target: Type = field(default_factory=lambda: SimpleRag)
|
||||||
|
|
||||||
model_name:str = "text-embedding-v4"
|
model_name:str = "text-embedding-v4"
|
||||||
@@ -29,7 +29,7 @@ class SimpleRagConfig(InstantiateConfig):
|
|||||||
"""path to local database"""
|
"""path to local database"""
|
||||||
|
|
||||||
|
|
||||||
class SimpleRag:
|
class SimpleRag(LangToolBase):
|
||||||
def __init__(self, config:SimpleRagConfig):
|
def __init__(self, config:SimpleRagConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.emb = QwenEmbeddings(self.config.api_key,
|
self.emb = QwenEmbeddings(self.config.api_key,
|
||||||
@@ -64,6 +64,9 @@ class SimpleRag:
|
|||||||
for doc in retrieved_docs
|
for doc in retrieved_docs
|
||||||
)
|
)
|
||||||
return serialized, retrieved_docs
|
return serialized, retrieved_docs
|
||||||
|
|
||||||
|
def get_tool_fnc(self):
|
||||||
|
return [self.retrieve]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user