Design a robust LLM API client with retry logic and rate limiting.
#gen-ai #python #llm
Answer
Robust LLM API Client with Retry and Rate Limiting
import time
import asyncio
import logging
import threading
from typing import Optional
from collections import deque
from dataclasses import dataclass, field

from openai import (
    OpenAI,
    AsyncOpenAI,
    RateLimitError,
    APIConnectionError,
    InternalServerError,
)

logger = logging.getLogger(__name__)


@dataclass
class TokenBucket:
    """Rate limiter using the token bucket algorithm.

    The bucket starts full, refills continuously at ``refill_rate`` tokens
    per second and never holds more than ``capacity`` tokens. All state
    updates happen under an internal lock.

    Raises:
        ValueError: If ``capacity`` or ``refill_rate`` is not positive.
    """

    capacity: int  # Max tokens in bucket
    refill_rate: float  # Tokens added per second
    _tokens: float = field(init=False)
    _last_refill: float = field(init=False)
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False)

    def __post_init__(self):
        # A non-positive refill rate would divide by zero in acquire().
        if self.capacity <= 0:
            raise ValueError("capacity must be positive")
        if self.refill_rate <= 0:
            raise ValueError("refill_rate must be positive")
        self._tokens = float(self.capacity)
        self._last_refill = time.monotonic()

    def acquire(self, tokens: int = 1) -> float:
        """Try to take ``tokens`` from the bucket without blocking.

        Returns:
            ``0.0`` if the tokens were deducted, otherwise the number of
            seconds until enough tokens will have accumulated. Nothing is
            deducted in the second case; the caller must call again.

        Raises:
            ValueError: If ``tokens`` is negative or exceeds ``capacity``
                (such a request could never be satisfied).
        """
        if not 0 <= tokens <= self.capacity:
            raise ValueError(
                f"tokens must be between 0 and capacity ({self.capacity})"
            )
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_refill
            self._tokens = min(
                self.capacity, self._tokens + elapsed * self.refill_rate
            )
            self._last_refill = now
            if self._tokens >= tokens:
                self._tokens -= tokens
                return 0.0
            return (tokens - self._tokens) / self.refill_rate

    def wait_and_acquire(self, tokens: int = 1) -> None:
        """Block until ``tokens`` have actually been deducted from the bucket.

        Raises:
            ValueError: Propagated from :meth:`acquire` for invalid ``tokens``.
        """
        # Loop rather than sleep once: acquire() deducts nothing when it
        # returns a wait time, and another thread may take the refilled
        # tokens while this one is sleeping.
        while True:
            wait = self.acquire(tokens)
            if wait <= 0:
                return
            logger.debug("Rate limit: waiting %.2fs", wait)
            time.sleep(wait)


class LLMClient:
    """Production LLM client.

    Features:
      - Exponential backoff retry
      - Token bucket rate limiting
      - Request timeout
      - Comprehensive logging
    """

    RETRYABLE_EXCEPTIONS = (RateLimitError, APIConnectionError, InternalServerError)

    def __init__(
        self,
        model: str = "gpt-4o",
        max_retries: int = 5,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        requests_per_minute: int = 60,
        timeout: float = 30.0,
    ):
        """Configure the client.

        Args:
            model: Model name sent with every request.
            max_retries: Maximum number of attempts per ``complete()`` call
                (the first try counts as attempt 1).
            base_delay: Backoff delay in seconds before the first retry.
            max_delay: Upper bound in seconds for the computed backoff.
            requests_per_minute: Rate limit applied to every attempt.
            timeout: Per-request timeout in seconds passed to the SDK client.

        Raises:
            ValueError: If ``max_retries`` is less than 1, or if
                ``requests_per_minute`` is not positive.
        """
        if max_retries < 1:
            # With zero attempts complete() would have nothing to return
            # and nothing to raise.
            raise ValueError("max_retries must be at least 1")

        self.model = model
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.timeout = timeout

        # Disable the SDK's own retry layer; otherwise every attempt made
        # here would be silently multiplied by the SDK's internal retries.
        self.client = OpenAI(timeout=timeout, max_retries=0)

        # Rate limiter: RPM converted to tokens/second
        self.rate_limiter = TokenBucket(
            capacity=requests_per_minute,
            refill_rate=requests_per_minute / 60.0,
        )

        # Metrics
        self._total_requests = 0
        self._total_retries = 0
        self._total_errors = 0

    @staticmethod
    def _retry_after_seconds(exc: Exception) -> Optional[float]:
        """Return the server-requested retry delay in seconds, if any."""
        raw = getattr(exc, "retry_after", None)
        if raw is None:
            # Status errors carry the HTTP response; connection errors do not.
            response = getattr(exc, "response", None)
            headers = getattr(response, "headers", None)
            if headers is not None:
                raw = headers.get("retry-after")
        if raw is None:
            return None
        try:
            seconds = float(raw)
        except (TypeError, ValueError):
            # Retry-After may also be an HTTP-date; fall back to backoff.
            return None
        return seconds if seconds > 0 else None

    def _retry_delay(self, attempt: int, exc: Exception) -> float:
        """Return how long to sleep before the retry following ``attempt``."""
        retry_after = self._retry_after_seconds(exc)
        if retry_after is not None:
            return retry_after
        return min(self.base_delay * (2 ** attempt), self.max_delay)

    def complete(
        self,
        messages: list[dict],
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Make a chat completion with retry and rate limiting.

        Args:
            messages: Chat messages forwarded to the API.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Optional completion token limit forwarded to the API.
            **kwargs: Extra keyword arguments forwarded to
                ``chat.completions.create``.

        Returns:
            The content of the first choice's message.

        Raises:
            RateLimitError, APIConnectionError, InternalServerError: When the
                final attempt still fails with a retryable error.
            Exception: Any non-retryable error is re-raised immediately.
        """
        for attempt in range(self.max_retries):
            # Throttle every attempt, not only the first one: each retry is
            # a real request against the same quota.
            self.rate_limiter.wait_and_acquire()
            self._total_requests += 1
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs,
                )
            except self.RETRYABLE_EXCEPTIONS as e:
                if attempt == self.max_retries - 1:
                    self._total_errors += 1
                    logger.error(
                        "LLM request failed after %d attempts: %s",
                        self.max_retries,
                        type(e).__name__,
                    )
                    raise
                # Count a retry only when one is actually going to happen.
                self._total_retries += 1
                delay = self._retry_delay(attempt, e)
                logger.warning(
                    "LLM request failed (attempt %d/%d): %s. "
                    "Retrying in %.1fs...",
                    attempt + 1,
                    self.max_retries,
                    type(e).__name__,
                    delay,
                )
                time.sleep(delay)
                continue
            except Exception as e:
                self._total_errors += 1
                logger.error(
                    "Non-retryable LLM error: %s: %s", type(e).__name__, e
                )
                raise

            # A missing usage block must not turn a successful call into
            # a failure.
            usage = getattr(response, "usage", None)
            logger.info(
                "llm_request_success",
                extra={
                    "attempt": attempt + 1,
                    "model": self.model,
                    "tokens": getattr(usage, "total_tokens", None),
                },
            )
            return response.choices[0].message.content

        # Only reachable if max_retries was lowered below 1 after construction.
        raise RuntimeError("max_retries must be at least 1")

    @property
    def stats(self) -> dict:
        """Return request, retry and error counters plus the retry rate."""
        return {
            "total_requests": self._total_requests,
            "total_retries": self._total_retries,
            "total_errors": self._total_errors,
            "retry_rate": self._total_retries / max(self._total_requests, 1),
        }


# Usage
if __name__ == "__main__":
    client = LLMClient(
        model="gpt-4o",
        max_retries=5,
        requests_per_minute=50,
    )
    answer = client.complete([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is RAG?"},
    ])
    print(answer)
    print("Stats:", client.stats)