Concept #41 · Hard · practical-coding

Implement cosine similarity search.

#gen-ai #vector-db #python

Answer

Cosine Similarity Search

python
import numpy as np
from typing import NamedTuple


class SearchResult(NamedTuple):
    '''One ranked hit produced by the similarity-search helpers in this module.'''

    # Position of the matched document in the corpus that was searched.
    index: int
    # Cosine similarity between the query and the matched document.
    score: float
    # Text of the matched document, taken from the ``documents`` list.
    text: str


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    '''
    Return the cosine of the angle between two vectors.

    Args:
        a: First vector
        b: Second vector

    Returns:
        Dot product of ``a`` and ``b`` divided by the product of their
        Euclidean norms, or 0.0 when either vector has zero norm
    '''
    magnitude_a, magnitude_b = np.linalg.norm(a), np.linalg.norm(b)
    # Only divide when both vectors have a direction; a zero vector
    # is reported as "no similarity" instead of producing NaN.
    if magnitude_a != 0 and magnitude_b != 0:
        return float(np.dot(a, b) / (magnitude_a * magnitude_b))
    return 0.0


def cosine_similarity_matrix(
    query: np.ndarray,
    corpus: np.ndarray,
) -> np.ndarray:
    '''
    Compute cosine similarity between a query vector and all corpus vectors.
    Vectorised — much faster than looping for large corpora.

    Vectors are divided by their exact Euclidean norm. A zero-norm query or
    corpus row yields a score of 0.0 (no additive epsilon is used, so
    small-norm vectors are not biased towards zero).

    Args:
        query: Shape (D,) — single query embedding
        corpus: Shape (N, D) — N document embeddings

    Returns:
        Shape (N,) — similarity scores
    '''
    query = np.asarray(query)
    corpus = np.asarray(corpus)

    query_len = np.linalg.norm(query)
    corpus_lens = np.linalg.norm(corpus, axis=1, keepdims=True)

    # Substitute 1 for zero norms: a zero vector divided by 1 stays zero,
    # so its dot product (and therefore its score) is exactly 0.0.
    query_unit = query / (query_len if query_len != 0 else 1.0)
    corpus_unit = corpus / np.where(corpus_lens == 0, 1.0, corpus_lens)

    return corpus_unit @ query_unit  # (N,) dot products of unit vectors


def similarity_search(
    query_embedding: np.ndarray,
    document_embeddings: np.ndarray,
    documents: list[str],
    top_k: int = 5,
    min_score: float = 0.0,
) -> list[SearchResult]:
    '''
    Find top-k most similar documents to query.

    Args:
        query_embedding: Query vector, shape (D,)
        document_embeddings: Document matrix, shape (N, D)
        documents: List of document texts, aligned with the rows of
            ``document_embeddings``
        top_k: Maximum number of results to return; values <= 0 yield
            an empty list
        min_score: Minimum similarity threshold; hits scoring below it
            are dropped, so fewer than ``top_k`` results may be returned

    Returns:
        List of SearchResult sorted by score descending
    '''
    # Guard non-positive top_k explicitly: with k == 0 the slice ``[-0:]``
    # used below is ``[0:]`` and would select every document.
    if top_k <= 0 or len(document_embeddings) == 0:
        return []

    scores = cosine_similarity_matrix(query_embedding, document_embeddings)

    # Select the k best candidates without fully sorting all N scores;
    # argpartition leaves them in arbitrary order.
    k = min(top_k, len(scores))
    candidates = np.argpartition(scores, -k)[-k:]

    # Order just the k candidates by score, highest first.
    ranked = candidates[np.argsort(scores[candidates])[::-1]]

    results = []
    for idx in ranked:
        score = float(scores[idx])
        if score >= min_score:
            results.append(SearchResult(
                index=int(idx),
                score=score,
                text=documents[int(idx)],
            ))

    return results


def batch_similarity_search(
    query_embeddings: np.ndarray,
    document_embeddings: np.ndarray,
    documents: list[str],
    top_k: int = 5,
) -> list[list[SearchResult]]:
    '''
    Search for multiple queries at once.

    Args:
        query_embeddings: Query matrix, shape (Q, D)
        document_embeddings: Document matrix, shape (N, D)
        documents: List of document texts, aligned with the rows of
            ``document_embeddings``
        top_k: Maximum number of results per query; values <= 0 yield
            empty result lists

    Returns:
        One list per query, each holding up to ``top_k`` SearchResult
        entries sorted by score descending. Zero-norm vectors score 0.0.
    '''
    # Nothing to rank: return one empty result list per query. This also
    # keeps a negative top_k from turning into a "drop the last |k|" slice.
    if top_k <= 0 or len(document_embeddings) == 0:
        return [[] for _ in range(len(query_embeddings))]

    # Normalise all vectors upfront. Zero norms are replaced by 1 so that
    # zero vectors stay zero (score 0.0) without an epsilon skewing
    # the scores of small-norm vectors.
    q_norms = np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    d_norms = np.linalg.norm(document_embeddings, axis=1, keepdims=True)
    q_unit = query_embeddings / np.where(q_norms == 0, 1.0, q_norms)
    d_unit = document_embeddings / np.where(d_norms == 0, 1.0, d_norms)

    # All pairwise similarities: (Q, N)
    all_scores = q_unit @ d_unit.T

    k = min(top_k, all_scores.shape[1])
    results = []
    for scores in all_scores:
        top_indices = np.argsort(scores)[::-1][:k]
        results.append([
            SearchResult(int(idx), float(scores[idx]), documents[int(idx)])
            for idx in top_indices
        ])
    return results


# Demo: embed a tiny corpus and print the three passages closest to a question
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    question = "How does retrieval augmented generation work?"
    passages = [
        "RAG retrieves relevant documents to ground LLM responses",
        "Fine-tuning updates model weights on domain-specific data",
        "Transformers use attention mechanisms to process sequences",
        "The weather in Paris is often overcast in winter",
        "Vector databases store high-dimensional embeddings for fast search",
    ]

    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    question_vec = encoder.encode(question)
    passage_vecs = encoder.encode(passages)

    # Rank the passages against the question and show the best three
    for hit in similarity_search(question_vec, passage_vecs, passages, top_k=3):
        print(f"Score: {hit.score:.4f} | {hit.text}")