Concept #43 · Hard · practical-coding

Design a robust LLM API client with retry logic and rate limiting.

#gen-ai #python #llm

Answer

Robust LLM API Client with Retry and Rate Limiting

python
import time
import asyncio
import logging
import threading
from typing import Optional
from collections import deque
from dataclasses import dataclass, field
from openai import OpenAI, AsyncOpenAI, RateLimitError, APIConnectionError, InternalServerError

# Module-level logger, named after this module (standard logging idiom).
logger = logging.getLogger(name=__name__)


@dataclass
class TokenBucket:
    '''Thread-safe rate limiter using the token bucket algorithm.

    The bucket holds at most ``capacity`` tokens and regains ``refill_rate``
    tokens per second, measured with ``time.monotonic``. Every unit of work
    spends tokens; when the bucket is short, callers wait for it to refill.

    Raises:
        ValueError: if ``capacity`` or ``refill_rate`` is not positive.
    '''
    capacity: int          # Max tokens in bucket (also the largest burst)
    refill_rate: float     # Tokens added per second
    _tokens: float = field(init=False)
    _last_refill: float = field(init=False)
    # Internal plumbing; keep the lock object out of repr().
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)

    def __post_init__(self):
        # Validate eagerly: a non-positive rate would otherwise surface later
        # as ZeroDivisionError or an endless wait inside acquire().
        if self.capacity <= 0:
            raise ValueError(f"capacity must be positive, got {self.capacity}")
        if self.refill_rate <= 0:
            raise ValueError(f"refill_rate must be positive, got {self.refill_rate}")
        self._tokens = float(self.capacity)  # start full: an initial burst is allowed
        self._last_refill = time.monotonic()

    def acquire(self, tokens: int = 1) -> float:
        '''Try to take ``tokens`` from the bucket without blocking.

        Returns:
            ``0.0`` if the tokens were taken. Otherwise nothing is taken and
            the result is the estimated number of seconds until enough tokens
            will have accumulated; the caller must call ``acquire`` again
            after waiting.
        '''
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_refill
            self._tokens = min(self.capacity, self._tokens + elapsed * self.refill_rate)
            self._last_refill = now

            if self._tokens >= tokens:
                self._tokens -= tokens
                return 0.0

            return (tokens - self._tokens) / self.refill_rate

    def wait_and_acquire(self, tokens: int = 1) -> None:
        '''Block until ``tokens`` have actually been taken from the bucket.

        Raises:
            ValueError: if ``tokens`` exceeds ``capacity`` (the request could
                never be satisfied and would wait forever).
        '''
        if tokens > self.capacity:
            raise ValueError(
                f"cannot acquire {tokens} tokens from a bucket of capacity {self.capacity}"
            )
        # Re-check after every sleep: sleeping by itself consumes nothing, and
        # another thread may have taken the refilled tokens in the meantime.
        while (wait := self.acquire(tokens)) > 0:
            logger.debug("Rate limit: waiting %.2fs", wait)
            time.sleep(wait)


class LLMClient:
    '''
    Production LLM client with:
    - Exponential backoff retry (a server-sent Retry-After takes precedence)
    - Token bucket rate limiting, applied to every attempt including retries
    - Request timeout
    - Comprehensive logging

    NOTE(review): the metrics counters are plain ints updated without a lock,
    so they may be approximate if one instance is shared across threads.
    '''

    RETRYABLE_EXCEPTIONS = (RateLimitError, APIConnectionError, InternalServerError)

    def __init__(
        self,
        model: str = "gpt-4o",
        max_retries: int = 5,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        requests_per_minute: int = 60,
        timeout: float = 30.0,
    ):
        '''Configure retries, rate limiting and the underlying OpenAI client.

        Args:
            model: Model name sent with every request.
            max_retries: Maximum attempts per ``complete`` call, counting the
                first one; values below 1 behave like 1.
            base_delay: Backoff in seconds after the first failed attempt;
                doubles after each further failure.
            max_delay: Cap (seconds) on the computed backoff delay.
            requests_per_minute: Sustained request rate; also the burst size.
            timeout: Per-request timeout in seconds, passed to ``OpenAI``.
        '''
        self.model = model
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.timeout = timeout
        # This class owns retrying, so turn off the SDK's built-in retries.
        # Otherwise each attempt below could silently become several HTTP
        # requests that bypass the rate limiter and multiply the retry budget.
        self.client = OpenAI(timeout=timeout, max_retries=0)

        # Rate limiter: RPM converted to tokens/second
        self.rate_limiter = TokenBucket(
            capacity=requests_per_minute,
            refill_rate=requests_per_minute / 60.0,
        )

        # Metrics
        self._total_requests = 0   # API calls attempted
        self._total_retries = 0    # failed attempts that were followed by a retry
        self._total_errors = 0     # complete() calls that ended by raising

    def complete(
        self,
        messages: list[dict],
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        '''Make a chat completion with retry and rate limiting.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            temperature: Sampling temperature.
            max_tokens: Completion-length cap; left out of the request when None.
            **kwargs: Extra arguments forwarded to ``chat.completions.create``.

        Returns:
            ``response.choices[0].message.content``. NOTE(review): the SDK
            types this as optional (e.g. tool-call-only replies), despite the
            ``str`` annotation.

        Raises:
            The last retryable exception once all attempts are exhausted, or
            any other exception immediately.
        '''
        request = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
        }
        # Omit max_tokens entirely instead of sending an explicit null.
        if max_tokens is not None:
            request["max_tokens"] = max_tokens

        attempts = max(1, self.max_retries)
        last_exception = None
        for attempt in range(attempts):
            # Every attempt is a real API call, so each one spends limiter budget.
            self.rate_limiter.wait_and_acquire()
            try:
                self._total_requests += 1
                response = self.client.chat.completions.create(**request, **kwargs)
                # The SDK types usage as optional; a missing count must not turn
                # a successful call into an error.
                usage = response.usage
                logger.info(
                    "llm_request_success",
                    extra={
                        "attempt": attempt + 1,
                        "model": self.model,
                        "tokens": usage.total_tokens if usage is not None else None,
                    }
                )
                return response.choices[0].message.content

            except self.RETRYABLE_EXCEPTIONS as e:
                last_exception = e
                if attempt == attempts - 1:
                    break  # out of attempts: not a retry

                self._total_retries += 1
                delay = self._retry_delay(attempt, e)
                logger.warning(
                    "LLM request failed (attempt %d/%d): %s. Retrying in %.1fs...",
                    attempt + 1,
                    attempts,
                    type(e).__name__,
                    delay,
                )
                time.sleep(delay)

            except Exception as e:
                self._total_errors += 1
                logger.error("Non-retryable LLM error: %s: %s", type(e).__name__, e)
                raise

        self._total_errors += 1
        raise last_exception

    def _retry_delay(self, attempt: int, exc: BaseException) -> float:
        '''Seconds to sleep before the retry that follows 0-based ``attempt``.

        A server-requested delay wins; otherwise capped exponential backoff
        ``min(base_delay * 2**attempt, max_delay)``.
        '''
        retry_after = self._retry_after_seconds(exc)
        if retry_after is not None:
            return retry_after
        return min(self.base_delay * (2 ** attempt), self.max_delay)

    @staticmethod
    def _retry_after_seconds(exc: BaseException) -> Optional[float]:
        '''Return the server-requested delay carried by ``exc``, or None.

        Looks at an explicit ``retry_after`` attribute, then the
        ``retry-after-ms`` and ``retry-after`` headers of ``exc.response``
        when present. Unparseable values (e.g. an HTTP-date Retry-After),
        zero, negative and infinite values are ignored.
        '''
        candidates = [(getattr(exc, "retry_after", None), 1.0)]
        headers = getattr(getattr(exc, "response", None), "headers", None)
        if headers is not None:
            candidates.append((headers.get("retry-after-ms"), 0.001))
            candidates.append((headers.get("retry-after"), 1.0))

        for raw, scale in candidates:
            if raw is None:
                continue
            try:
                seconds = float(raw) * scale
            except (TypeError, ValueError):
                continue
            if 0 < seconds < float("inf"):
                return seconds
        return None

    @property
    def stats(self) -> dict:
        '''Snapshot of the request/retry/error counters and the retry rate.'''
        return {
            "total_requests": self._total_requests,
            "total_retries": self._total_retries,
            "total_errors": self._total_errors,
            "retry_rate": self._total_retries / max(self._total_requests, 1),
        }


# Usage
if __name__ == "__main__":
    # Demo: one chat round-trip through the client, then dump its counters.
    demo_client = LLMClient(model="gpt-4o", max_retries=5, requests_per_minute=50)

    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is RAG?"},
    ]
    reply = demo_client.complete(demo_messages)

    print(reply)
    print("Stats:", demo_client.stats)