import asyncio
import time
from typing import List, Optional

import dashscope
from dashscope import TextEmbedding
from langchain.embeddings.base import Embeddings
from loguru import logger


class QwenEmbeddings(Embeddings):
    """Custom Qwen embeddings backed by the DashScope API."""

    def __init__(self,
                 api_key: str,
                 model: str = "text-embedding-v4",
                 max_workers: int = 5,
                 embedding_dimension: int = 512,  # one of [2048, 1536, 1024, 768, 512, 256, 128, 64]
                 batch_size: int = 10,  # NOTE: DashScope supports at most 10 texts per batch
                 rate_limit_delay: float = 0.00001):
        """
        Initialize Qwen embeddings.

        Args:
            api_key: DashScope API key
            model: Model name (text-embedding-v4 by default; older versions may
                not support configurable dimensions)
            max_workers: Maximum number of concurrent batches for async operations
            embedding_dimension: Dimension of the embedding vectors (must be a
                size the chosen model supports)
            batch_size: Number of texts to send per API call (max 10 for DashScope)
            rate_limit_delay: Delay in seconds between batches; the default is
                effectively zero, so raise it if you hit rate limits
        """
        if api_key is None:
            logger.warning("No API key provided; DashScope calls will fail.")
        dashscope.api_key = api_key

        self.MAX_BATCH_SIZE = 10  # hard DashScope per-request limit
        self.model = model
        self.max_workers = max_workers
        self.embedding_dimension = embedding_dimension
        self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)
        self.rate_limit_delay = rate_limit_delay

    def _get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a batch of texts using the DashScope native batch API."""
        try:
            # DashScope supports batch input natively: pass the list directly.
            response = TextEmbedding.call(
                model=self.model,
                input=texts,
                dimension=self.embedding_dimension
            )
            if response.status_code == 200:
                return [item['embedding'] for item in response.output['embeddings']]
            logger.error(f"Batch API error: {response.status_code}, {response.message}")
            # Fall back to zero vectors so callers still get one vector per text.
            return [[0.0] * self.embedding_dimension for _ in texts]
        except Exception as e:
            logger.error(f"Error embedding batch of {len(texts)} texts: {e}")
            # Same zero-vector fallback on unexpected failures.
            return [[0.0] * self.embedding_dimension for _ in texts]

    def _get_single_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text (fallback method)."""
        return self._get_batch_embeddings([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using batched API calls."""
        if not texts:
            return []
        all_embeddings = []
        total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            batch_num = i // self.batch_size + 1
            logger.info(f"Processing batch {batch_num}/{total_batches} ({len(batch)} texts)")
            all_embeddings.extend(self._get_batch_embeddings(batch))
            # Delay between batches to respect rate limits (skipped after the last batch).
            if i + self.batch_size < len(texts) and self.rate_limit_delay > 0:
                time.sleep(self.rate_limit_delay)
        return all_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        return self._get_single_embedding(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents asynchronously with batching."""
        if not texts:
            return []
        # get_event_loop() is deprecated inside coroutines; use the running loop.
        loop = asyncio.get_running_loop()
        batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]

        async def process_batch_with_delay(batch: List[str], batch_idx: int) -> List[List[float]]:
            """Process a single batch with rate limiting."""
            # Delay before processing (except for the first batch).
            if batch_idx > 0 and self.rate_limit_delay > 0:
                await asyncio.sleep(self.rate_limit_delay)
            # Run the blocking batch call in the default executor.
            return await loop.run_in_executor(None, self._get_batch_embeddings, batch)

        # Bound the number of in-flight batches.
        semaphore = asyncio.Semaphore(self.max_workers)

        async def process_batch_limited(batch: List[str], batch_idx: int) -> List[List[float]]:
            async with semaphore:
                logger.info(f"Processing async batch {batch_idx + 1}/{len(batches)} ({len(batch)} texts)")
                return await process_batch_with_delay(batch, batch_idx)

        tasks = [process_batch_limited(batch, idx) for idx, batch in enumerate(batches)]
        batch_results = await asyncio.gather(*tasks, return_exceptions=True)

        # Flatten results, substituting zero vectors for failed batches.
        all_embeddings = []
        for i, batch_result in enumerate(batch_results):
            if isinstance(batch_result, Exception):
                logger.error(f"Error processing async batch {i}: {batch_result}")
                # One fresh zero vector per text (avoid sharing a single list object).
                all_embeddings.extend([0.0] * self.embedding_dimension for _ in batches[i])
            else:
                all_embeddings.extend(batch_result)
        return all_embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_single_embedding, text)

    def get_embedding_dimension(self) -> int:
        """Return the dimension of the embedding vectors."""
        return self.embedding_dimension

    def batch_embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """
        Embed documents in batches (legacy method; delegates to embed_documents).

        Args:
            texts: List of texts to embed
            batch_size: Batch size override (if None, the instance default is used)
        """
        if batch_size is not None and batch_size != self.batch_size:
            # Temporarily override the batch size, restoring it afterwards.
            original_batch_size = self.batch_size
            self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)
            try:
                return self.embed_documents(texts)
            finally:
                self.batch_size = original_batch_size
        return self.embed_documents(texts)

    async def abatch_embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """
        Embed documents in batches asynchronously (legacy method; delegates to aembed_documents).

        Args:
            texts: List of texts to embed
            batch_size: Batch size override (if None, the instance default is used)
        """
        if batch_size is not None and batch_size != self.batch_size:
            # Temporarily override the batch size, restoring it afterwards.
            original_batch_size = self.batch_size
            self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)
            try:
                return await self.aembed_documents(texts)
            finally:
                self.batch_size = original_batch_size
        return await self.aembed_documents(texts)

    def estimate_cost(self, texts: List[str], cost_per_1k_tokens: float = 0.0007) -> dict:
        """
        Estimate the cost of embedding the given texts.

        Args:
            texts: List of texts to estimate the cost for
            cost_per_1k_tokens: Cost per 1000 tokens (adjust to current pricing)

        Returns:
            Dict with cost estimation details
        """
        # Rough heuristic: ~1 token per 4 characters for mixed Chinese/English text.
        total_chars = sum(len(text) for text in texts)
        estimated_tokens = total_chars / 4
        estimated_cost = (estimated_tokens / 1000) * cost_per_1k_tokens
        batches_needed = (len(texts) + self.batch_size - 1) // self.batch_size
        return {
            "total_texts": len(texts),
            "total_characters": total_chars,
            "estimated_tokens": int(estimated_tokens),
            "estimated_cost_usd": round(estimated_cost, 4),
            "batches_needed": batches_needed,
            # Counts only the inter-batch delays, not API latency.
            "estimated_time_seconds": batches_needed * self.rate_limit_delay
        }
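

# Hypothetical downstream usage (sketch only): the class drops into any
# LangChain component that accepts an `Embeddings` instance. The FAISS wrapper
# and the `faiss-cpu` package are assumptions, not part of this module:
#
#   from langchain.vectorstores import FAISS
#   store = FAISS.from_texts(["doc one", "doc two"],
#                            embedding=QwenEmbeddings(api_key="..."))
#   hits = store.similarity_search("query", k=2)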


if __name__ == "__main__":
    # Example usage (replace the placeholder with a real DashScope key).
    embeddings = QwenEmbeddings(api_key="YOUR_API_KEY")
    vector = embeddings.embed_query("Qwen embeddings are powerful for bilingual tasks.")
    print(vector)
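
    # A minimal batch + async sketch under the same (assumed valid) API key;
    # the sample texts below are illustrative only.
    docs = [
        "Qwen embeddings handle both Chinese and English text.",
        "DashScope accepts up to 10 inputs per embedding call.",
    ]
    print(embeddings.estimate_cost(docs))           # rough pre-flight cost check
    doc_vectors = embeddings.embed_documents(docs)  # synchronous batched call
    print(f"embedded {len(doc_vectors)} documents")

    async def _demo() -> None:
        """Async variant: batches run concurrently under the semaphore."""
        vectors = await embeddings.aembed_documents(docs)
        print(f"async: embedded {len(vectors)} docs, dim={len(vectors[0])}")

    asyncio.run(_demo())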