Concept #152 · Hard · system-design

Provide the complete architecture of how a chat LLM works, in detail.

#system-design #architecture #transformer #attention #rag #tokenization #embeddings #llm

Answer

Complete Architecture: How a Chat LLM Works

A chat LLM processes your message through a multi-stage pipeline: raw text is split into tokens, the tokens pass through dozens of transformer layers, and the response is generated autoregressively, one token at a time. Each stage is explained below.


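Complete Flow Diagram

User message → prompt assembly (system prompt + history, plus retrieved context if RAG is enabled) → tokenization → embeddings + positional encoding → N transformer blocks (self-attention + FFN) → LM head → sampling → next token, repeated until EOS → detokenization → response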


Phase 1: Tokenization

The user's raw text is broken into tokens (subwords, not words). Each token maps to an integer ID from the model's vocabulary.

python
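from transformers import AutoTokenizer

# Gated checkpoint: requires accepting the license on the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")

user_message = "How does attention work in transformers?"
tokens = tokenizer(user_message, return_tensors="pt")

print(tokens["input_ids"])
# A [1, seq_len] tensor of integer IDs, typically starting with the BOS token.
# The exact values depend on the tokenizer's vocabulary.

print(tokenizer.convert_ids_to_tokens(tokens["input_ids"][0]))
# One string per ID. Byte-level BPE tokenizers such as Llama 3's mark a
# leading space with 'Ġ'; SentencePiece tokenizers (LLaMA 1/2, Mistral)
# use '▁', e.g. ['▁How', '▁does', '▁attention', '▁work', '▁in', '▁transform', 'ers', '?']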

Key facts:

  • Vocabulary size: ~32K (LLaMA 1/2) to ~100K (GPT-4) and ~128K (Llama 3) tokens
  • Tokenizer type: BPE (Byte-Pair Encoding) or a unigram model, commonly implemented with SentencePiece or tiktoken
  • Rare or unseen words are split into subword pieces, e.g. "transformers" → ["transform", "ers"]
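
To make the subword idea concrete, here is a minimal sketch of how BPE merge rules can be learned and applied. It is a toy, character-level version written for illustration; the function names (`learn_bpe_merges`, `apply_merge`, `bpe_encode`) are hypothetical, and production tokenizers work on bytes, add pre-tokenization rules, and are far more optimized.

```python
from collections import Counter

def apply_merge(symbols, pair):
    """Fuse every adjacent occurrence of `pair` in a tuple of symbols."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return tuple(out)

def learn_bpe_merges(words, num_merges):
    """Learn merge rules by repeatedly fusing the most frequent adjacent pair."""
    vocab = Counter(tuple(w) for w in words)   # word (as symbols) -> frequency
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)       # ties resolve to the first pair seen
        merges.append(best)
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            new_vocab[apply_merge(symbols, best)] += freq
        vocab = new_vocab
    return merges

def bpe_encode(word, merges):
    """Split a word into characters, then replay the learned merges in order."""
    symbols = tuple(word)
    for pair in merges:
        symbols = apply_merge(symbols, pair)
    return list(symbols)

merges = learn_bpe_merges(["low", "low", "lower", "lowest"], num_merges=2)
pieces = bpe_encode("lowly", merges)   # an unseen word still gets encoded
```

A word that never appeared in the training list is still representable, because anything the merges do not cover falls back to smaller pieces.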

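Phase 2: Token Embeddings + Positional Encoding

Each token ID is looked up in an embedding matrix to get a dense vector. Since attention has no inherent sense of order, position information must be injected, either by adding positional encodings to the embeddings or, as with RoPE, inside the attention computation.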

python
import torch
import torch.nn as nn

vocab_size = len(tokenizer)   # must cover every ID the tokenizer can emit (32,000 for LLaMA 2, ~128K for Llama 3)
d_model = 4096   # embedding dimension (LLaMA 2 7B; Llama 3.2 3B uses 3072)

# Embedding lookup table: vocab_size × d_model
embedding = nn.Embedding(vocab_size, d_model)

token_ids = tokens["input_ids"]           # shape: [batch, seq_len]
x = embedding(token_ids)                  # shape: [batch, seq_len, 4096]

# Positional encoding (RoPE, Rotary Position Embedding, used in LLaMA)
# Encodes relative position by rotating Q and K vectors
# No separate addition needed: it is applied inside attention

Encoding Type | Model | How it works
Absolute (sinusoidal) | Original Transformer | Add fixed sin/cos vectors by position
Learned absolute | GPT-2, BERT | Trainable position embeddings
RoPE | LLaMA, Mistral | Rotate Q and K in attention by a position-dependent angle
ALiBi | MPT, BLOOM | Add a distance-proportional bias directly to attention scores
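
The RoPE row can be illustrated with a small pure-Python sketch. It assumes the pairing used in the RoPE paper (adjacent dimensions rotated together) and a base of 10000; real implementations are vectorized and some pair dimensions differently. `rope_rotate` and `dot` are hypothetical helper names.

```python
import math

def rope_rotate(vec, position, base=10000.0):
    """Rotate consecutive pairs (x0, x1), (x2, x3), ... by position-dependent angles."""
    d = len(vec)
    if d % 2 != 0:
        raise ValueError("RoPE needs an even dimension")
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)   # rotation frequency falls with pair index
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 0.5, -0.3, 2.0]
k = [0.2, -1.0, 0.7, 0.4]

# Same offset (2 positions apart) at two different absolute positions.
score_near_start = dot(rope_rotate(q, 5), rope_rotate(k, 3))
score_further_on = dot(rope_rotate(q, 12), rope_rotate(k, 10))
```

The two scores agree up to floating-point error, which is the point of the scheme: the query-key dot product depends only on the distance between the two positions, not on where they sit in the sequence.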

Phase 3: Conversation History & Prompt Assembly

Before entering the transformer, the full context window is assembled. This includes the system prompt, the conversation history, and the new user message.

python
messages = [
    {"role": "system",    "content": "You are a helpful AI assistant."},
    {"role": "user",      "content": "What is attention?"},
    {"role": "assistant", "content": "Attention allows the model to focus on relevant tokens..."},
    {"role": "user",      "content": "How does attention work in transformers?"},
]

# The tokenizer's chat template flattens this into a single token sequence.
# Use an instruction-tuned checkpoint (e.g. meta-llama/Llama-3.2-3B-Instruct);
# base checkpoints may not define a chat template.
input_ids = tokenizer.apply_chat_template(
    messages,
    return_tensors="pt",
    add_generation_prompt=True
)
# Schematically (the actual special tokens are model-specific):
# <|system|> You are a helpful AI... <|user|> What is attention? <|assistant|> Attention... <|user|> How does attention... <|assistant|>

Context window limits (these vary by model version): GPT-4 Turbo = 128K tokens, LLaMA 3.1 = 128K tokens, Claude 3.5 = 200K tokens
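
Because the window is finite, long conversations eventually have to be cut down before they are sent to the model. The sketch below shows one simple policy, assumed here for illustration: always keep the system prompt and drop the oldest turns first. `count_tokens` is a hypothetical stand-in that counts whitespace-separated words; a real application would count with the model's tokenizer.

```python
def count_tokens(text):
    """Hypothetical stand-in for a tokenizer: one token per whitespace-separated word."""
    return len(text.split())

def trim_history(messages, max_tokens, count=count_tokens):
    """Keep the system prompt plus the most recent turns that fit in max_tokens."""
    system = [m for m in messages if m["role"] == "system"]
    turns = [m for m in messages if m["role"] != "system"]
    budget = max_tokens - sum(count(m["content"]) for m in system)
    kept = []
    for m in reversed(turns):          # newest turns are considered first
        cost = count(m["content"])
        if cost > budget:
            break                      # everything older is dropped as well
        kept.append(m)
        budget -= cost
    return system + kept[::-1]         # restore chronological order

history = [
    {"role": "system",    "content": "You are helpful."},
    {"role": "user",      "content": "a b c d"},
    {"role": "assistant", "content": "e f g"},
    {"role": "user",      "content": "h i"},
]
trimmed = trim_history(history, max_tokens=8)
```

Other policies, such as summarizing old turns instead of dropping them, are also common; this sketch only covers the simplest one.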


Phase 4: RAG Pipeline (if enabled)

When the application uses RAG (retrieval-augmented generation), external knowledge is retrieved and injected into the prompt before the LLM sees it.

Step 4a: Document Chunking

python
# Older LangChain releases expose this as langchain.text_splitter
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,        # characters per chunk
    chunk_overlap=64,      # overlap between chunks for continuity
    separators=["\n\n", "\n", ".", " "]
)

document = "Large document text here..."
chunks = splitter.split_text(document)
# ["chunk 1 text...", "chunk 2 text...", ...]
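
To show what `chunk_size` and `chunk_overlap` mean mechanically, here is a minimal fixed-window chunker in plain Python. It is a simplified sketch, not what the LangChain splitter does internally: it ignores separators and just slides a character window. `chunk_text` is a hypothetical name.

```python
def chunk_text(text, chunk_size=512, overlap=64):
    """Split text into character windows that overlap by `overlap` characters."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    step = chunk_size - overlap        # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break                      # this window already reaches the end
    return chunks

small_chunks = chunk_text("abcdefghij", chunk_size=4, overlap=1)
```

Each chunk repeats the last `overlap` characters of the previous one, so a sentence cut at a boundary still appears intact in at least one chunk when the overlap is large enough.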

Step 4b: Chunk Embedding & Storage

python
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")  # 1536-dim vectors

# Each chunk → 1536-dimensional float vector
vectorstore = Chroma.from_texts(
    texts=chunks,
    embedding=embeddings,
    persist_directory="./chroma_db"
)

Step 4c: Query Retrieval at Runtime

python
query = "How does attention work in transformers?"

# Embed the query with the SAME model used for the chunks.
# Shown for illustration: similarity_search() below embeds the query itself.
query_vector = embeddings.embed_query(query)       # list of 1536 floats

# Nearest-neighbor search → top-K most relevant chunks
# (Chroma defaults to L2 distance; for unit-length embeddings this ranks like cosine)
results = vectorstore.similarity_search(query, k=3)

# Inject retrieved chunks into the prompt as context
context = "\n\n".join(doc.page_content for doc in results)
augmented_prompt = f"Context:\n{context}\n\nQuestion: {query}"
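
What the vector store does during retrieval can be sketched without any library. The example below is a brute-force version under the assumption that cosine similarity is the ranking metric; real vector databases use approximate nearest-neighbor indexes instead of scanning every chunk. `cosine_similarity` and `top_k` are hypothetical names, and the 2-dimensional vectors are made-up toy data.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query_vec, chunk_vecs, k=3):
    """Return (index, score) pairs for the k chunks most similar to the query."""
    scored = [(i, cosine_similarity(query_vec, v)) for i, v in enumerate(chunk_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]

toy_chunk_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
toy_query_vector = [1.0, 0.1]
best = top_k(toy_query_vector, toy_chunk_vectors, k=2)
```

The indexes in `best` point back into the chunk list, which is how the retrieved text is recovered for the prompt.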

Phase 5: Inside the Transformer: Multi-Head Self-Attention

This is the core of the LLM. Each of the N transformer blocks runs the same set of operations.

Self-Attention (Single Head)

python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]                               # key dimension

    # Step 1: compute raw attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))   # [..., seq, seq]
    scores = scores / math.sqrt(d_k)                # scale so softmax does not saturate for large d_k

    # Step 2: apply causal mask (decoder: can't look at future tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: softmax → attention weights (each row sums to 1)
    weights = F.softmax(scores, dim=-1)             # [..., seq, seq]

    # Step 4: weighted sum of values
    output = torch.matmul(weights, V)               # [..., seq, d_v]
    return output, weights
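
The function above accepts a `mask` but does not show how one is built. As a small illustration in plain Python (so the numbers are easy to follow), the sketch below constructs a causal mask and applies a masked softmax to a toy score matrix. `causal_mask` and `masked_softmax` are hypothetical helpers; in PyTorch the same mask is usually created with `torch.tril(torch.ones(seq_len, seq_len))`.

```python
import math

def causal_mask(seq_len):
    """Lower-triangular mask: row i has 1s for positions 0..i and 0s for the future."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, mask):
    """Row-wise softmax in which masked-out (0) positions receive zero weight."""
    weights = []
    for row, mask_row in zip(scores, mask):
        visible = [s for s, m in zip(row, mask_row) if m]
        peak = max(visible)            # subtract the max for numerical stability
        exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(row, mask_row)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

mask = causal_mask(3)
uniform_scores = [[0.0, 0.0, 0.0] for _ in range(3)]
attn = masked_softmax(uniform_scores, mask)
```

With equal scores, the first token can only attend to itself, the second splits its weight over two positions, and the third over three, which shows that no position receives weight from tokens that come after it.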

Multi-Head Attention

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=4096, num_heads=32):
        super().__init__()
        self.d_k = d_model // num_heads   # 128 per head

        self.W_q = nn.Linear(d_model, d_model)   # query projection
        self.W_k = nn.Linear(d_model, d_model)   # key projection
        self.W_v = nn.Linear(d_model, d_model)   # value projection
        self.W_o = nn.Linear(d_model, d_model)   # output projection
        self.num_heads = num_heads

    def forward(self, x, mask=None):
        B, S, D = x.shape

        # Project and reshape into multiple heads
        Q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)

        # Each head attends independently
        attn_out, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate all heads and project
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, D)
        return self.W_o(attn_out)

Why multiple heads? Each head can learn to attend to different aspects of the input: one head may focus on syntax, another on semantics, another on coreference.


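Phase 6: Feed-Forward Network (FFN)

After attention, each position passes independently through a feed-forward network. The classic design is a 2-layer MLP; LLaMA-style models use a gated variant (SwiGLU) with three weight matrices:

python
class FeedForward(nn.Module):
    def __init__(self, d_model=4096, d_ff=11008):   # LLaMA 2 7B: ~2.7x expansion (2/3 of the classic 4x, rounded up)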
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up   = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        # SwiGLU activation (used in LLaMA, Mistral)
        return self.down(F.silu(self.gate(x)) * self.up(x))
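
The value 11008 looks arbitrary but follows from a simple sizing rule. The sketch below reproduces that rule as I understand it from the original LLaMA reference code: start from the classic 4x expansion, scale by 2/3 to offset the extra gate matrix, then round up to a multiple of 256. Treat the function name and the `multiple_of` default as assumptions for illustration.

```python
def swiglu_hidden_dim(d_model, multiple_of=256):
    """FFN width for a SwiGLU block, following a LLaMA-style sizing rule."""
    hidden = 4 * d_model              # classic Transformer expansion
    hidden = int(2 * hidden / 3)      # shrink by 2/3 to pay for the gate projection
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

d_ff = swiglu_hidden_dim(4096)
expansion = d_ff / 4096
```

With three matrices of this width, the gated FFN ends up with roughly the same parameter count as a two-matrix FFN at 4x expansion.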

Full Transformer Block

python
class TransformerBlock(nn.Module):
    def __init__(self, d_model=4096, num_heads=32):
        super().__init__()
        self.attn   = MultiHeadAttention(d_model, num_heads)
        self.ffn    = FeedForward(d_model)
        self.norm1  = nn.RMSNorm(d_model)   # RMSNorm used in LLaMA (nn.RMSNorm requires PyTorch 2.4+)
        self.norm2  = nn.RMSNorm(d_model)

    def forward(self, x, mask=None):
        # Attention sub-layer with residual connection
        x = x + self.attn(self.norm1(x), mask)
        # FFN sub-layer with residual connection
        x = x + self.ffn(self.norm2(x))
        return x

# Stack N blocks (Llama 3.2 3B has 28; LLaMA 2 7B, whose layer sizes are used above, has 32)
class TransformerStack(nn.Module):
    def __init__(self, n_layers=28):
        super().__init__()
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(n_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x
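
A quick way to sanity-check an architecture like this is to count its parameters. The arithmetic below is a sketch that assumes bias-free projections (as in LLaMA; the `nn.Linear` attention layers above would add bias terms), a final RMSNorm before the LM head, and untied input/output embeddings. The function names are hypothetical.

```python
def block_params(d_model, d_ff):
    """Parameters in one transformer block, ignoring biases."""
    attention = 4 * d_model * d_model   # W_q, W_k, W_v, W_o
    ffn = 3 * d_model * d_ff            # gate, up, down
    norms = 2 * d_model                 # two RMSNorm weight vectors
    return attention + ffn + norms

def model_params(vocab_size, d_model, d_ff, n_layers, tied_embeddings=False):
    """Total parameters: embeddings + blocks + final norm + LM head."""
    embed = vocab_size * d_model
    head = 0 if tied_embeddings else vocab_size * d_model
    final_norm = d_model
    return embed + n_layers * block_params(d_model, d_ff) + final_norm + head

# LLaMA 2 7B-style sizes: 32K vocabulary, d_model=4096, d_ff=11008, 32 layers
total = model_params(vocab_size=32000, d_model=4096, d_ff=11008, n_layers=32)
billions = round(total / 1e9, 2)
```

With these sizes the count comes out at about 6.74 billion, and the blocks account for the vast majority of it; the embedding table and LM head together contribute only a few percent.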

Phase 7: Token Generation (Autoregressive Decoding)

The LLM generates one token at a time. After the full transformer stack, the hidden state at the last position is projected to vocabulary logits.

python
class LMHead(nn.Module):
    def __init__(self, d_model=4096, vocab_size=32000):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        return self.linear(x)   # shape: [batch, seq_len, 32000]

# Full generation loop (no KV cache: the whole sequence is re-run every step)
def generate(model, tokenizer, prompt, max_new_tokens=200, temperature=0.7, top_p=0.9):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated = input_ids.clone()

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(generated).logits          # [1, seq, vocab_size]

        next_token_logits = logits[:, -1, :]          # last position only

        # Apply temperature (higher = more random)
        next_token_logits = next_token_logits / temperature

        # Top-p (nucleus) sampling: keep the smallest set of tokens whose
        # cumulative probability reaches top_p
        sorted_logits, sorted_ids = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumulative_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()    # shift right: keep the token that crosses top_p
        remove[..., 0] = False                        # always keep the most likely token
        sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))

        probs = F.softmax(sorted_logits, dim=-1)
        choice = torch.multinomial(probs, 1)          # index into the sorted order, [1, 1]
        next_token = sorted_ids.gather(-1, choice)    # map back to a vocabulary ID, [1, 1]

        generated = torch.cat([generated, next_token], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated[0], skip_special_tokens=True)

Sampling Strategies Compared

Strategy | How it works | When to use
Greedy | Always pick highest probability token | Fast, deterministic, can be repetitive
Temperature | Scale logits before softmax (low=focused, high=creative) | General use
Top-K | Sample only from top K tokens | Reduces incoherence
Top-p (nucleus) | Sample from smallest set summing to p | Best quality/diversity balance
Beam search | Explore multiple sequences in parallel | Translation, structured output
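
The temperature and top-K rows can be tried out with a few lines of plain Python. This is a minimal sketch on a made-up three-token logit vector; `softmax` and `sample_top_k` are hypothetical helpers, and real decoders operate on tensors over the full vocabulary.

```python
import math
import random

def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)                         # subtract the max for stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_top_k(logits, k, temperature=1.0, rng=random):
    """Sample a token index from the k highest-scoring logits only."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top], temperature)
    return rng.choices(top, weights=probs, k=1)[0]

toy_logits = [2.0, 1.0, 0.1]
p_cold = softmax(toy_logits, temperature=0.5)[0]   # sharper distribution
p_base = softmax(toy_logits, temperature=1.0)[0]
p_hot = softmax(toy_logits, temperature=2.0)[0]    # flatter distribution
greedy_like = sample_top_k(toy_logits, k=1)        # k=1 reduces to greedy decoding
```

Lower temperatures push probability mass onto the top token, higher ones spread it out, and top-K with `k=1` is equivalent to greedy decoding.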

Phase 8: KV Cache (Inference Optimization)

Without a KV cache, every new token would recompute attention over the entire context from scratch, which costs O(n²) attention work per step. A KV cache stores the K and V tensors already computed for earlier tokens, so each step only processes the new token and attends over the cached entries.

python
# Without KV cache: recompute all tokens every step → very slow
# With KV cache: only compute new token, reuse stored K/V → O(n) per step

outputs = model.generate(
    input_ids,
    max_new_tokens=200,
    use_cache=True,          # KV cache enabled by default in HuggingFace
    temperature=0.7,
    do_sample=True,
)
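
The caching idea can be sketched for a single attention head in plain Python. To keep it short, the example assumes the query, key and value vectors for each token are already given (no projection matrices) and uses made-up 2-dimensional vectors; `attend` and `KVCache` are hypothetical names, not the Hugging Face implementation.

```python
import math

def attend(query, keys, values):
    """Attention for one query over a list of key and value vectors."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e / total * v[d] for e, v in zip(exps, values)) for d in range(dim)]

class KVCache:
    """Keeps the keys and values of every token processed so far."""
    def __init__(self):
        self.keys, self.values = [], []

    def step(self, q, k, v):
        """Store the new token's K/V, then attend from its query over the cache."""
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)

token_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

cache = KVCache()
for vec in token_vectors:                 # one decoding step per token
    cached_out = cache.step(vec, vec, vec)

# Recomputing from scratch for the last token gives the same result.
full_out = attend(token_vectors[-1], token_vectors, token_vectors)
```

Because of the causal mask, the newest token only ever needs the keys and values of earlier tokens, so storing them is enough; the cost is memory that grows linearly with sequence length.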

Phase 9: Detokenization

The generated token ID sequence is decoded back to human-readable text.

python
# Generated token IDs from model.generate(), with the prompt tokens sliced off
output_ids = outputs[0][input_ids.shape[-1]:]

response = tokenizer.decode(
    output_ids,
    skip_special_tokens=True,    # remove <s>, </s>, <pad> etc.
    clean_up_tokenization_spaces=True
)
print(response)
# e.g. "Attention works by computing similarity scores between all tokens..."

Complete Architecture Summary

Stage | What Happens | Key Component
Tokenization | Text → token IDs | BPE / SentencePiece tokenizer
Embedding | IDs → dense vectors | Embedding matrix (vocab × d_model)
Positional encoding | Add position info | RoPE / ALiBi / sinusoidal
RAG (optional) | Retrieve relevant chunks | Vector DB + cosine similarity
Prompt assembly | Build full context | System + history + query
Self-attention | Tokens attend to each other | Q, K, V projections + softmax
FFN | Per-token transformation | 2-layer MLP + activation
N transformer blocks | Repeat attention + FFN | Stack of 28–96 layers
LM head | Hidden state → vocab logits | Linear projection
Sampling | Pick next token | Temperature + top-p
Autoregressive loop | Repeat until EOS | KV cache for efficiency
Detokenization | Token IDs → text | Tokenizer decode

Key insight: The LLM never "understands" language the way humans do. It learns statistical patterns from billions to trillions of tokens and predicts a probability distribution over the next token, from which the continuation is chosen. The magic is entirely in the learned weight matrices inside those transformer blocks.