diff --git a/lang_agent/rag/simple.py b/lang_agent/rag/simple.py index 3c2b364..713aabd 100644 --- a/lang_agent/rag/simple.py +++ b/lang_agent/rag/simple.py @@ -32,8 +32,8 @@ class SimpleRagConfig(ToolConfig): def __post_init__(self): if self.api_key == "wrong-key": # logger.info("wrong embedding key, using simple retrieval method") - key = os.environ.get("ALI_API_KEY") - if key is None: + self.api_key = os.environ.get("ALI_API_KEY") + if self.api_key is None: logger.error(f"no ALI_API_KEY provided for embedding") @@ -66,7 +66,8 @@ class SimpleRag(LangToolBase): 1. 用户询问“推荐一些辣味食物”,系统会检索并返回相关的辣味美食推荐文档。 2. 用户搜索“适合夏天的清爽饮品”,系统会检索并返回相关饮品推荐及其来源信息。 """ - retrieved_docs:List[Document] = self.vec_store.search(query, search_kwargs={"k":3}) + retrieved_docs:List[Document] = self.vec_store.similarity_search(query, + k=3) serialized = "\n\n".join( (f"Source: {doc.metadata}\nContent: {doc.page_content}") for doc in retrieved_docs @@ -77,8 +78,10 @@ class SimpleRag(LangToolBase): return [self.retrieve] -# if __name__ == "__main__": +if __name__ == "__main__": # # config = tyro.cli(SimpleRagConfig) -# config = SimpleRagConfig() -# rag:SimpleRag = config.setup() + config = SimpleRagConfig() + rag:SimpleRag = config.setup() + u = rag.retrieve("灯与尘") + print(u) # mcp.run(transport="stdio") \ No newline at end of file