from langchain.embeddings.base import Embeddings
import dashscope
from dashscope import TextEmbedding
from typing import List, Optional
import asyncio
from loguru import logger
import time


class QwenEmbeddings(Embeddings):
    """Custom Qwen embeddings using the DashScope API."""

    def __init__(self,
                 api_key: str,
                 model: str = "text-embedding-v4",
                 max_workers: int = 5,
                 embedding_dimension: int = 512,
                 batch_size: int = 10,  # NOTE: DashScope supports up to 10 texts per batch
                 rate_limit_delay: float = 0.1):
        """
        Initialize Qwen embeddings.

        Args:
            api_key: DashScope API key
            model: Model name (e.g., text-embedding-v1 through text-embedding-v4)
            max_workers: Maximum number of concurrent workers for async operations
            embedding_dimension: Dimension of embedding vectors (adjust based on model)
            batch_size: Number of texts to process in one API call (max 10 for DashScope)
            rate_limit_delay: Delay in seconds between batches to respect rate limits
        """
        if api_key is None:
            logger.warning("No api_key provided!")
        dashscope.api_key = api_key

        self.MAX_BATCH_SIZE = 10
        self.model = model
        self.max_workers = max_workers
        self.embedding_dimension = embedding_dimension
        self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)  # enforce DashScope limit
        self.rate_limit_delay = rate_limit_delay

    def _get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a batch of texts using DashScope's native batch API."""
        try:
            # DashScope supports batch processing natively
            response = TextEmbedding.call(
                model=self.model,
                input=texts,  # pass the list directly
                dimension=self.embedding_dimension
            )

            if response.status_code == 200:
                embeddings = []
                for embedding_data in response.output['embeddings']:
                    embeddings.append(embedding_data['embedding'])
                return embeddings
            else:
                logger.error(f"Batch API error: {response.status_code}, {response.message}")
                # Return zero vectors as a fallback
                return [[0.0] * self.embedding_dimension for _ in texts]

        except Exception as e:
            logger.error(f"Error embedding batch of {len(texts)} texts: {e}")
            # Return zero vectors as a fallback
            return [[0.0] * self.embedding_dimension for _ in texts]

    def _get_single_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text (fallback method)."""
        return self._get_batch_embeddings([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using smart batching."""
        if not texts:
            return []

        all_embeddings = []

        # Process in batches
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            batch_num = i // self.batch_size + 1
            total_batches = (len(texts) + self.batch_size - 1) // self.batch_size

            logger.info(f"Processing batch {batch_num}/{total_batches} ({len(batch)} texts)")

            batch_embeddings = self._get_batch_embeddings(batch)
            all_embeddings.extend(batch_embeddings)

            # Delay between batches to respect rate limits (skipped after the last batch)
            if i + self.batch_size < len(texts) and self.rate_limit_delay > 0:
                time.sleep(self.rate_limit_delay)

        return all_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        return self._get_single_embedding(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents asynchronously with smart batching."""
        if not texts:
            return []

        loop = asyncio.get_running_loop()

        # Create batches
        batches = [texts[i:i + self.batch_size]
                   for i in range(0, len(texts), self.batch_size)]

        async def process_batch_with_delay(batch: List[str], batch_idx: int) -> List[List[float]]:
            """Process a single batch with rate limiting."""
            # Delay before processing (except for the first batch)
            if batch_idx > 0 and self.rate_limit_delay > 0:
                await asyncio.sleep(self.rate_limit_delay)

            # Run the blocking batch call in the default executor
            return await loop.run_in_executor(
                None,
                self._get_batch_embeddings,
                batch
            )

        # Bound concurrency with a semaphore
        semaphore = asyncio.Semaphore(self.max_workers)

        async def process_batch_limited(batch: List[str], batch_idx: int) -> List[List[float]]:
            async with semaphore:
                logger.info(f"Processing async batch {batch_idx + 1}/{len(batches)} ({len(batch)} texts)")
                return await process_batch_with_delay(batch, batch_idx)

        # Execute all batches
        tasks = [
            process_batch_limited(batch, idx)
            for idx, batch in enumerate(batches)
        ]

        batch_results = await asyncio.gather(*tasks, return_exceptions=True)

        # Flatten results and handle exceptions
        all_embeddings = []
        for i, batch_result in enumerate(batch_results):
            if isinstance(batch_result, Exception):
                logger.error(f"Error processing async batch {i}: {batch_result}")
                # Substitute zero vectors for the failed batch
                batch_size = len(batches[i])
                all_embeddings.extend([[0.0] * self.embedding_dimension] * batch_size)
            else:
                all_embeddings.extend(batch_result)

        return all_embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_single_embedding, text)

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embeddings."""
        return self.embedding_dimension

    def batch_embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """
        Embed documents in batches (legacy method - now just calls embed_documents).

        Args:
            texts: List of texts to embed
            batch_size: Batch size (if None, uses the instance default)
        """
        if batch_size is not None and batch_size != self.batch_size:
            # Temporarily override the batch size
            original_batch_size = self.batch_size
            self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)
            try:
                return self.embed_documents(texts)
            finally:
                self.batch_size = original_batch_size
        else:
            return self.embed_documents(texts)

    async def abatch_embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """
        Embed documents in batches asynchronously (legacy method - now just calls aembed_documents).

        Args:
            texts: List of texts to embed
            batch_size: Batch size (if None, uses the instance default)
        """
        if batch_size is not None and batch_size != self.batch_size:
            # Temporarily override the batch size
            original_batch_size = self.batch_size
            self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)
            try:
                return await self.aembed_documents(texts)
            finally:
                self.batch_size = original_batch_size
        else:
            return await self.aembed_documents(texts)

    def estimate_cost(self, texts: List[str], cost_per_1k_tokens: float = 0.0007) -> dict:
        """
        Estimate the cost of embedding the given texts.

        Args:
            texts: List of texts to estimate the cost for
            cost_per_1k_tokens: Cost per 1000 tokens (adjust based on current pricing)

        Returns:
            Dict with cost estimation details
        """
        # Rough estimation: ~1 token per 4 characters for mixed Chinese/English text
        total_chars = sum(len(text) for text in texts)
        estimated_tokens = total_chars / 4
        estimated_cost = (estimated_tokens / 1000) * cost_per_1k_tokens
        batches_needed = (len(texts) + self.batch_size - 1) // self.batch_size

        return {
            "total_texts": len(texts),
            "total_characters": total_chars,
            "estimated_tokens": int(estimated_tokens),
            "estimated_cost_usd": round(estimated_cost, 4),
            "batches_needed": batches_needed,
            # Counts only the inter-batch delays, not API latency
            "estimated_time_seconds": max(batches_needed - 1, 0) * self.rate_limit_delay
        }


if __name__ == "__main__":
    # Example usage
    embeddings = QwenEmbeddings(api_key="YOUR_API_KEY")
    vector = embeddings.embed_query("Qwen embeddings are powerful for bilingual tasks.")
    print(vector)