"""Batch Embedder with Error Handling."""
import time
import logging
from typing import Optional
from dataclasses import dataclass
import numpy as np
from openai import OpenAI, RateLimitError, APIConnectionError
# Module-level logger, named after this module so callers can tune its level.
logger = logging.getLogger(name=__name__)
@dataclass
class EmbeddingResult:
    """Outcome of embedding a single input text.

    ``embedding`` holds the vector when embedding succeeded and is ``None``
    when it failed; in the failure case ``error`` may describe what went wrong.
    """

    index: int  # position of the text in the list that was embedded
    text: str  # the input text itself
    embedding: Optional[list[float]]  # the vector, or None on failure
    error: Optional[str] = None  # failure message; None when not set

    @property
    def success(self) -> bool:
        """True when an embedding vector is present."""
        return not (self.embedding is None)
class BatchEmbedder:
    '''
    Embed large lists of texts in batches with:
    - Configurable batch size
    - Exponential backoff on rate limits and connection errors (a numeric
      server-provided retry-after hint is honoured when present)
    - Per-item error tracking (no silent failures): every failed item's
      ``EmbeddingResult.error`` includes the reason its batch failed
    - Progress reporting

    NOTE(review): the embeddings endpoint rejects some inputs outright (for
    example empty strings or over-long texts). Such an input is treated as an
    unretryable error and fails its whole batch; filter inputs beforehand if
    that matters.
    '''

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        batch_size: int = 100,
        max_retries: int = 5,
        base_delay: float = 1.0,
        show_progress: bool = True,
        client: Optional["OpenAI"] = None,
    ):
        '''
        Args:
            model: Embedding model name passed to the API.
            batch_size: Maximum number of texts sent per API call (>= 1).
            max_retries: Total number of attempts per batch (>= 1); the first
                call counts as an attempt.
            base_delay: Backoff base in seconds; after failed attempt ``k``
                (0-based) the wait is ``base_delay * 2**k`` unless the server
                supplies a retry-after hint.
            show_progress: Print a one-line progress indicator to stdout.
            client: Pre-configured OpenAI client; ``OpenAI()`` is created
                when omitted.

        Raises:
            ValueError: If ``batch_size`` or ``max_retries`` is < 1, or
                ``base_delay`` is negative.
        '''
        # Invalid values would otherwise divide by zero, make zero API calls,
        # or crash later inside time.sleep.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        if base_delay < 0:
            raise ValueError(f"base_delay must be >= 0, got {base_delay}")
        self.model = model
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.show_progress = show_progress
        # An injected client allows custom configuration (keys, base URL, timeouts).
        self.client = client if client is not None else OpenAI()

    def _retry_delay(self, error: Exception, attempt: int) -> float:
        '''Return seconds to wait after ``error`` on 0-based ``attempt``.

        Uses an explicit ``retry_after`` attribute if the error has a non-None
        one, otherwise a ``retry-after`` header on ``error.response``; falls
        back to exponential backoff when no numeric hint is available.
        '''
        backoff = self.base_delay * (2 ** attempt)
        hint = getattr(error, "retry_after", None)
        if hint is None:
            headers = getattr(getattr(error, "response", None), "headers", None)
            if headers is not None:
                hint = headers.get("retry-after")
        if hint is None:
            return backoff
        try:
            return max(0.0, float(hint))
        except (TypeError, ValueError):
            # e.g. an HTTP-date retry-after value; not worth parsing here.
            return backoff

    def _embed_batch_with_error(
        self, texts: list[str]
    ) -> tuple[list[Optional[list[float]]], Optional[str]]:
        '''Embed one batch with retry.

        Returns:
            ``(embeddings, error)``. On success: one vector per input, in input
            order, and ``None``. On failure: ``[None] * len(texts)`` and a
            short description of why the batch failed.
        '''
        if not texts:
            return [], None
        failed: list[Optional[list[float]]] = [None] * len(texts)
        last_error = "no attempts made"
        for attempt in range(self.max_retries):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=texts,
                )
            except RateLimitError as e:
                label, last_error = "Rate limited", f"rate limited: {e}"
                delay = self._retry_delay(e, attempt)
            except APIConnectionError as e:
                label, last_error = "Connection error", f"connection error: {e}"
                delay = self.base_delay * (2 ** attempt)
            except Exception as e:
                # Anything else (bad request, auth, ...) will not succeed on retry.
                logger.error("Unretryable embedding error: %s", e)
                return failed, f"unretryable error: {type(e).__name__}: {e}"
            else:
                # Maintain original order: each returned item carries its input index.
                ordered = sorted(response.data, key=lambda item: item.index)
                if len(ordered) != len(texts):
                    # A short/long response would misalign texts and vectors.
                    reason = (f"API returned {len(ordered)} embeddings "
                              f"for {len(texts)} inputs")
                    logger.error("%s", reason)
                    return failed, reason
                return [item.embedding for item in ordered], None
            # Only wait when another attempt follows; sleeping after the last
            # attempt would merely delay reporting the failure.
            if attempt + 1 < self.max_retries:
                logger.warning("%s. Waiting %.1fs (attempt %d/%d)",
                               label, delay, attempt + 1, self.max_retries)
                time.sleep(delay)
        logger.error("Batch failed after %d attempts: %s", self.max_retries, last_error)
        return failed, f"failed after {self.max_retries} attempts ({last_error})"

    def _embed_batch(self, texts: list[str]) -> list[Optional[list[float]]]:
        '''Embed one batch with retry. Returns None for failed items.'''
        embeddings, _ = self._embed_batch_with_error(texts)
        return embeddings

    def embed(self, texts: list[str]) -> list["EmbeddingResult"]:
        '''
        Embed all texts, returning results with success/error per item.

        Args:
            texts: List of strings to embed

        Returns:
            List of EmbeddingResult, one per input text, in original order.
            Failed items have ``embedding=None`` and an ``error`` message that
            includes the reason their batch failed.
        '''
        if not texts:
            return []
        total = len(texts)
        total_batches = (total + self.batch_size - 1) // self.batch_size
        results: list[EmbeddingResult] = []
        for batch_idx, start in enumerate(range(0, total, self.batch_size)):
            end = min(start + self.batch_size, total)
            batch_texts = texts[start:end]
            if self.show_progress:
                print(f"\rEmbedding batch {batch_idx + 1}/{total_batches} "
                      f"({end}/{total} texts)", end="", flush=True)
            embeddings, reason = self._embed_batch_with_error(batch_texts)
            # Batches are processed in order, so appending keeps results
            # aligned with the input (results[i] corresponds to texts[i]).
            for global_idx, (text, embedding) in enumerate(
                zip(batch_texts, embeddings), start=start
            ):
                if embedding is not None:
                    results.append(EmbeddingResult(global_idx, text, embedding))
                else:
                    results.append(EmbeddingResult(
                        global_idx, text, None,
                        error=(f"Failed to embed item {global_idx}: "
                               f"{reason or 'unknown error'}"),
                    ))
        if self.show_progress:
            print()  # Newline after progress
        successes = sum(1 for r in results if r.success)
        logger.info("Embedded %d/%d texts successfully", successes, total)
        return results

    def embed_as_matrix(self, texts: list[str]) -> tuple[np.ndarray, list[int]]:
        '''
        Embed texts, returning only successful embeddings as a numpy matrix.

        Returns:
            (matrix of shape (N_success, D), list of successful original indices).
            When nothing succeeds the matrix has shape (0, 0), so it is
            always 2-D.
        '''
        results = self.embed(texts)
        successful = [r for r in results if r.success]
        failed_indices = [r.index for r in results if not r.success]
        if failed_indices:
            logger.warning("%d embeddings failed: indices %s",
                           len(failed_indices), failed_indices)
        if not successful:
            # Keep the result 2-D; D is unknown without any successful vector.
            return np.empty((0, 0)), []
        matrix = np.array([r.embedding for r in successful])
        indices = [r.index for r in successful]
        return matrix, indices
# Usage
if __name__ == "__main__":
    embedder = BatchEmbedder(model="text-embedding-3-small", batch_size=100)
    texts = [f"Document {i}: This is sample text about topic {i}" for i in range(250)]
    results = embedder.embed(texts)
    successful = [r for r in results if r.success]
    failed = [r for r in results if not r.success]
    print(f"Success: {len(successful)}, Failed: {len(failed)}")
    # Guard: successful[0] would raise IndexError if every batch failed.
    if successful:
        print(f"First embedding shape: {len(successful[0].embedding)}")
    # Surface a few failure reasons instead of only a count.
    for r in failed[:5]:
        print(f"  {r.error}")
    # Get as numpy matrix (this embeds these texts again via a new API call)
    matrix, indices = embedder.embed_as_matrix(texts[:10])
    print(f"Matrix shape: {matrix.shape}")