diff --git a/scripts/make_rag_database.py b/scripts/make_rag_database.py index 8b7fad3..4c3c88a 100644 --- a/scripts/make_rag_database.py +++ b/scripts/make_rag_database.py @@ -6,8 +6,7 @@ from lang_agent.rag.emb import QwenEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain_community.vectorstores import FAISS -from langchain.chains import RetrievalQA -from langchain.llms import OpenAI +from langchain_openai import OpenAIEmbeddings from langchain.schema import Document def main(save_path = "assets/xiaozhan_emb"): @@ -33,13 +32,24 @@ def main(save_path = "assets/xiaozhan_emb"): texts = data embeddings = QwenEmbeddings( api_key=os.environ.get("ALI_API_KEY") - ) # Needs OPENAI_API_KEY + ) + # embeddings = OpenAIEmbeddings( + # model="text-embedding-v4", + # api_key=os.environ.get("ALI_API_KEY"), + # base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" + # ) + # embeddings = OpenAIEmbeddings() if not osp.exists(save_path): # --- STEP 2: Create vector store --- # vectorstore = FAISS.from_documents(texts, embeddings) - out_emb = embeddings.batch_embed_documents(texts) - vectorstore = FAISS.from_embeddings(zip(texts, out_emb), embeddings) + + if os.environ.get("ALI_API_KEY") is None or os.environ.get("ALI_API_KEY") == "SOMESHIT": + texts = [Document(e) for e in data] + vectorstore = FAISS.from_documents(texts, embeddings) + else: + out_emb = embeddings.batch_embed_documents(texts) + vectorstore = FAISS.from_embeddings(zip(texts, out_emb), embeddings) # --- STEP 3: SAVE the FAISS index to local files --- vectorstore.save_local(save_path)