Concept #33 · Medium · python-for-gen-ai

How would you implement caching for expensive LLM API calls?

#gen-ai #python #mlops

Answer

Caching Expensive LLM API Calls

LLM API calls are slow (1–30 s) and expensive (roughly $0.01–$0.10 per request). Caching identical or semantically similar queries can dramatically reduce both latency and cost.

Level 1: Exact Match Cache (Simple & Fast)

python
import hashlib
import json
import functools
from typing import Any

def make_cache_key(messages: list[dict], model: str, **kwargs) -> str:
    '''Create a deterministic cache key from the request inputs.

    The messages, model name and any extra request parameters (temperature,
    max_tokens, ...) are serialised to JSON with sorted keys and hashed with
    SHA-256, so the same request always maps to the same key regardless of
    keyword-argument order.

    Raises:
        TypeError: if any value is not JSON-serialisable.
    '''
    payload = {"messages": messages, "model": model, **kwargs}
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()

# In-memory LRU cache (process lifetime, not persistent).
# Bounded so a long-running process cannot grow it without limit. Plain dicts
# preserve insertion order, so the first key is always the least recently used.
_CACHE_MAXSIZE = 1024
_cache: dict[str, str] = {}

def cached_completion(messages: list[dict], model: str = "gpt-4o", **kwargs) -> str:
    '''Return a chat completion, serving repeated identical requests from memory.

    Args:
        messages: Chat messages forwarded to the OpenAI API.
        model: Model name; part of the cache key.
        **kwargs: Extra request parameters; forwarded to the API and included
            in the cache key, so e.g. different temperatures never share an entry.

    Returns:
        The content of the first choice's message (cached or freshly fetched).
    '''
    key = make_cache_key(messages, model, **kwargs)
    if key in _cache:
        # Re-insert the hit so it moves to the most-recently-used end.
        response = _cache.pop(key)
        _cache[key] = response
        return response

    from openai import OpenAI
    response = OpenAI().chat.completions.create(
        model=model, messages=messages, **kwargs
    ).choices[0].message.content

    _cache[key] = response
    # Evict least-recently-used entries once over capacity.
    while len(_cache) > _CACHE_MAXSIZE:
        del _cache[next(iter(_cache))]
    return response

Level 2: Redis Cache (Persistent, Shared Across Processes)

python
import redis
import hashlib
import json
from typing import Optional

class LLMCache:
    '''Exact-match LLM response cache stored in Redis with a per-entry TTL.

    Because entries live in Redis, every process pointing at the same Redis
    server shares them; each entry expires ``ttl_seconds`` after it is written.
    '''

    def __init__(self, ttl_seconds: int = 3600, host: str = "localhost",
                 port: int = 6379, client: Optional["redis.Redis"] = None):
        '''Connect to Redis.

        Args:
            ttl_seconds: Lifetime of each cached entry (default 1 hour).
            host: Redis host, used only when ``client`` is not given.
            port: Redis port, used only when ``client`` is not given.
            client: An existing Redis client to reuse. It should be created
                with ``decode_responses=True`` so ``get`` returns ``str``.
        '''
        if client is None:
            client = redis.Redis(host=host, port=port, decode_responses=True)
        self.redis = client
        self.ttl = ttl_seconds

    def _key(self, prompt: str, model: str) -> str:
        '''Build the Redis key for a (prompt, model) pair.

        Model and prompt are JSON-encoded as a list before hashing, so a ':'
        inside a model name (e.g. fine-tuned model IDs) cannot make two
        different pairs produce the same key.
        '''
        raw = json.dumps([model, prompt])
        return f"llm:cache:{hashlib.sha256(raw.encode()).hexdigest()}"

    def get(self, prompt: str, model: str) -> Optional[str]:
        '''Return the cached response, or ``None`` on a miss.'''
        return self.redis.get(self._key(prompt, model))

    def set(self, prompt: str, model: str, response: str) -> None:
        '''Store ``response`` with the configured TTL.'''
        self.redis.setex(self._key(prompt, model), self.ttl, response)

    def call_with_cache(self, prompt: str, model: str = "gpt-4o") -> tuple[str, bool]:
        '''Answer ``prompt``, using the cache when possible.

        Returns:
            ``(response, was_cached)``.
        '''
        cached = self.get(prompt, model)
        # Compare with None so a cached empty string still counts as a hit.
        if cached is not None:
            return cached, True

        from openai import OpenAI
        response = OpenAI().chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,  # Important: cache only deterministic responses
        ).choices[0].message.content

        # Message content can be None; Redis cannot store None, so skip caching it.
        if response is not None:
            self.set(prompt, model, response)
        return response, False

# Keep entries for one day.
cache = LLMCache(ttl_seconds=24 * 60 * 60)
answer, from_cache = cache.call_with_cache("What is RAG?")
print(f"Response: {answer}")
print(f"Cache hit: {from_cache}")

Level 3: Semantic Cache (Similar Queries → Same Response)

python
from typing import Optional
import json

import numpy as np
import redis
from sentence_transformers import SentenceTransformer

class SemanticCache:
    '''In-memory cache that reuses responses for semantically similar prompts.

    Prompts are embedded with a SentenceTransformer using normalised
    embeddings, so the dot product of two embeddings is their cosine
    similarity. A stored response is reused when the best match produced by
    the same LLM model reaches ``similarity_threshold``.
    '''

    def __init__(self, similarity_threshold: float = 0.92):
        self.model = SentenceTransformer("all-MiniLM-L6-v2")  # embedding model
        self.threshold = similarity_threshold
        self.cache: list[dict] = []  # [{embedding, query, model, response}]

    def _find_similar(self, query_embedding: np.ndarray,
                      model: Optional[str] = None) -> Optional[str]:
        '''Return the cached response most similar to ``query_embedding``.

        Args:
            query_embedding: Normalised 1-D embedding of the new prompt.
            model: If given, only entries produced by this LLM model are
                considered; ``None`` searches every entry.

        Returns:
            The best-matching response if its similarity is at least
            ``self.threshold``, otherwise ``None``.
        '''
        candidates = [
            entry for entry in self.cache
            if model is None or entry.get("model") == model
        ]
        if not candidates:
            return None

        stored_embeddings = np.array([e["embedding"] for e in candidates])
        similarities = stored_embeddings @ query_embedding  # Cosine (normalised)

        best_idx = int(np.argmax(similarities))
        if similarities[best_idx] >= self.threshold:
            return candidates[best_idx]["response"]
        return None

    def query(self, prompt: str, model: str = "gpt-4o") -> tuple[str, bool]:
        '''Answer ``prompt``, reusing a cached response for a similar prompt.

        Returns:
            ``(response, was_cached)``.
        '''
        embedding = self.model.encode([prompt], normalize_embeddings=True)[0]

        cached = self._find_similar(embedding, model)
        # Compare with None so a cached empty string still counts as a hit.
        if cached is not None:
            return cached, True

        from openai import OpenAI
        response = OpenAI().chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,  # Cache only (near-)deterministic responses
        ).choices[0].message.content

        # Message content can be None; don't store an entry that can never hit.
        if response is not None:
            self.cache.append({
                "embedding": embedding,
                "query": prompt,
                "model": model,
                "response": response,
            })
        return response, False

sem_cache = SemanticCache(similarity_threshold=0.92)

# The second query is served from the cache only if its embedding's cosine
# similarity to the first reaches the 0.92 threshold. Paraphrases that swap in
# an acronym ("RAG" for "Retrieval-Augmented Generation") may fall below it,
# so check the returned flag rather than assuming a hit.
r1, cached1 = sem_cache.query("What is Retrieval-Augmented Generation?")
r2, cached2 = sem_cache.query("Can you explain RAG in AI?")  # cached2 may be False

Caching Strategy by Use Case

| Use Case | Cache Type | TTL |
|---|---|---|
| FAQ answers | Exact match + Redis | 24–72 hours |
| Document summaries | Exact match + Redis | Until document changes |
| Query reformulation | Semantic | 1 hour |
| Creative generation | Don't cache | — |
| Dynamic data (prices, weather) | Don't cache | — |

Important: Only cache responses generated with `temperature=0`. Caching probabilistic (sampled) outputs freezes one random answer and can return stale, inconsistent answers. Also consider caching the entire chain output (including the retrieved context), not just the LLM call, so that repeated queries skip the embedding and retrieval costs too.