change tool def
This commit is contained in:
@@ -51,7 +51,7 @@ class SimpleRag(LangToolBase):
|
|||||||
|
|
||||||
# self.retriever = self.vec_store.as_retriever(search_kwargs={"k":3})
|
# self.retriever = self.vec_store.as_retriever(search_kwargs={"k":3})
|
||||||
|
|
||||||
def retrieve(self, query:str):
|
def retrieve(self, query:str)->str:
|
||||||
"""
|
"""
|
||||||
检索与给定查询相关的文档,并将其序列化为字符串格式。
|
检索与给定查询相关的文档,并将其序列化为字符串格式。
|
||||||
参数:
|
参数:
|
||||||
@@ -72,7 +72,7 @@ class SimpleRag(LangToolBase):
|
|||||||
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
||||||
for doc in retrieved_docs
|
for doc in retrieved_docs
|
||||||
)
|
)
|
||||||
return serialized, retrieved_docs
|
return serialized #, retrieved_docs
|
||||||
|
|
||||||
def get_tool_fnc(self):
|
def get_tool_fnc(self):
|
||||||
return [self.retrieve]
|
return [self.retrieve]
|
||||||
@@ -82,6 +82,10 @@ if __name__ == "__main__":
|
|||||||
# # config = tyro.cli(SimpleRagConfig)
|
# # config = tyro.cli(SimpleRagConfig)
|
||||||
config = SimpleRagConfig()
|
config = SimpleRagConfig()
|
||||||
rag:SimpleRag = config.setup()
|
rag:SimpleRag = config.setup()
|
||||||
|
|
||||||
|
import time
|
||||||
|
st_time = time.time()
|
||||||
u = rag.retrieve("灯与尘")
|
u = rag.retrieve("灯与尘")
|
||||||
|
print(time.time() - st_time)
|
||||||
print(u)
|
print(u)
|
||||||
# mcp.run(transport="stdio")
|
# mcp.run(transport="stdio")
|
||||||
Reference in New Issue
Block a user