Concept #148 Ā· Medium Ā· extended-ai-concepts

Difference between LLM training vs inference?

#training #inference #llm #gradients #kv-cache

Answer

LLM Training vs Inference

Training teaches an LLM by updating its weights. Inference uses the trained model to generate outputs. They are fundamentally different in compute requirements, memory usage, and optimization goals.


Side-by-Side Comparison

| Aspect | Training | Inference |
|---|---|---|
| Goal | Learn weights from data | Generate predictions from inputs |
| Passes | Forward + backward pass | Forward pass only |
| Weight updates | Yes (gradient descent) | No (weights are frozen) |
| Memory (VRAM) | ~3–10Ɨ weight size (weights + gradients + optimizer states + activations) | ~1–2Ɨ weight size (weights + KV cache) |
| Batch size | Large (32–512+ sequences) | Small for local use (1–8); production servers batch many concurrent requests |
| Speed | Slow (hours to weeks per run) | Fast (milliseconds to seconds per response) |
| Hardware | High-end GPUs (A100, H100) | Consumer GPUs, CPUs, edge devices |
| Precision | FP32 / BF16 / FP16 | INT8, INT4, FP16 / BF16 |
| Cost | Very high ($100s to $1M+ per run) | Low per query |

What Happens During Training

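```python
# Training loop (simplified)
# Assumes `model` is a Hugging Face causal LM and that `dataloader`,
# `optimizer` (e.g. torch.optim.AdamW) and `scheduler` are already set up.
import torch

model.train()  # enables training-mode behavior such as dropout

for batch in dataloader:
    optimizer.zero_grad()  # clear gradients left over from the previous step

    # 1. Forward pass: compute predictions and the loss
    outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
    loss = outputs.loss

    # 2. Backward pass: compute gradients
    loss.backward()

    # 3. Update weights, then advance the learning-rate schedule
    optimizer.step()
    scheduler.step()

    print(f"Loss: {loss.item():.4f}")
```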
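The same three steps can be seen without PyTorch in a minimal sketch: a two-weight linear model (y = w·x + b) trained with mean-squared-error loss and hand-derived gradients. The data, learning rate, and step count are made-up illustration values; an LLM runs the same forward/backward/update cycle over billions of parameters, with autograd computing the gradients.

```python
# Toy training loop: forward pass, backward pass (hand-derived gradients), update.
# Illustration only: data, learning rate and step count are arbitrary assumptions.

data = [(1.0, 3.0), (2.0, 5.0), (3.0, 7.0), (4.0, 9.0)]  # generated by y = 2x + 1
n = len(data)

w, b = 0.0, 0.0  # trainable weights
lr = 0.05        # learning rate

for step in range(2000):
    # 1. Forward pass: predictions and mean-squared-error loss
    preds = [w * x + b for x, _ in data]
    loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, data)) / n

    # 2. Backward pass: dLoss/dw and dLoss/db
    grad_w = sum(2 * (p - y) * x for p, (x, y) in zip(preds, data)) / n
    grad_b = sum(2 * (p - y) for p, (_, y) in zip(preds, data)) / n

    # 3. Update weights (plain gradient descent)
    w -= lr * grad_w
    b -= lr * grad_b

print(f"w={w:.3f} b={b:.3f} loss={loss:.6f}")


def predict(x):
    """Inference: forward pass only, the trained weights stay frozen."""
    return w * x + b
```

After training, `predict` is all that inference needs: no loss, no gradients, no updates.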

Memory During Training (3B model example)

```text
Model weights:     ~6 GB  (3B params Ɨ 2 bytes, BF16)
Gradients:         ~6 GB  (same size as weights)
Optimizer states:  ~12 GB (Adam: 2 states per param, assuming BF16 states)
Activations:       ~8 GB  (varies with batch/seq length)
─────────────────────────
Total:            ~32 GB  minimum
```
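The per-component figures follow from bytes-per-parameter arithmetic. A minimal estimator, assuming 2-byte (BF16) storage for weights, gradients, and both Adam moments as in the example; the second call shows a common mixed-precision variant with FP32 Adam states and an FP32 master copy of the weights. Activations are left out because they depend on batch size and sequence length, so the first result (24 GB) plus ~8 GB of activations gives the ~32 GB total above.

```python
# Back-of-envelope training memory, excluding activations.
# Bytes per parameter are assumptions passed in by the caller.

GB = 1e9  # decimal gigabytes, matching "3B params x 2 bytes = 6 GB"


def training_memory_gb(n_params, weight_bytes=2, grad_bytes=2,
                       optim_states=2, optim_bytes=2, master_bytes=0):
    """Return per-component memory in GB for full fine-tuning with Adam."""
    parts = {
        "weights": n_params * weight_bytes,
        "gradients": n_params * grad_bytes,
        "optimizer": n_params * optim_states * optim_bytes,  # Adam keeps 2 moments
        "master_weights": n_params * master_bytes,  # FP32 copy in mixed precision
    }
    parts = {name: size / GB for name, size in parts.items()}
    parts["total"] = sum(parts.values())
    return parts


# Assumptions of the 3B example: BF16 weights, gradients and Adam states
print(training_memory_gb(3e9))
# Common mixed-precision setup: FP32 Adam states plus an FP32 master copy
print(training_memory_gb(3e9, optim_bytes=4, master_bytes=4))
```

The FP32 variant totals 48 GB before activations, which is why the 32 GB figure is a lower bound.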

What Happens During Inference

```python
# Inference (no gradient computation)
# Assumes `model` and `tokenizer` are an already-loaded Hugging Face causal LM
# and its matching tokenizer.
import torch

model.eval()  # disables dropout

with torch.no_grad():  # saves memory: no autograd graph is built
    inputs = tokenizer("Explain attention", return_tensors="pt")

    # Forward passes only; no backward pass or weight update
    output_ids = model.generate(
        **inputs,  # input_ids plus attention_mask
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
    )

    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
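With `do_sample=True`, each generation step turns the model's output logits into a probability distribution and samples the next token from it. The sampling step can be sketched in plain Python as below; the logits are made-up numbers standing in for a model's output over a tiny hypothetical 4-token vocabulary, and real implementations also support extra filters such as top-k and top-p.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature=0.7, rng=random):
    """Sample one token id from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Hypothetical logits for a 4-token vocabulary
logits = [2.0, 1.0, 0.5, -1.0]
print(softmax_with_temperature(logits, 1.0))
print(softmax_with_temperature(logits, 0.7))  # lower temperature: sharper distribution
print(sample_next_token(logits, rng=random.Random(0)))
```

A temperature below 1 concentrates probability on the highest-scoring token; `generate` repeats this step once per new token.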

Memory During Inference (3B model example)

```text
Model weights:    ~6 GB  (FP16; INT4 quantized ā‰ˆ 2 GB)
KV cache:         ~2 GB  (grows with context length and batch size)
─────────────────────────
Total:            ~4–8 GB  (much lighter than training)
```
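The KV-cache line scales linearly with context length and batch size. A sketch of the usual estimate, 2 (keys and values) Ɨ layers Ɨ KV heads Ɨ head dimension Ɨ cached tokens Ɨ batch Ɨ bytes per element; the configuration values below are hypothetical, not taken from any specific 3B model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1,
                   bytes_per_elem=2):
    """Keys + values stored for every layer, KV head and cached token (FP16 default)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Hypothetical model configuration (assumed values for illustration)
cfg = dict(n_layers=28, n_kv_heads=8, head_dim=128)

per_token = kv_cache_bytes(**cfg, seq_len=1)
print(f"{per_token / 1024:.0f} KiB per token")
for tokens in (2048, 8192, 16384):
    gb = kv_cache_bytes(**cfg, seq_len=tokens) / 1e9
    print(f"{tokens:>6} tokens: {gb:.2f} GB")
```

Under these assumptions the cache reaches roughly 2 GB at about 16K tokens for a single sequence, and doubles again for every doubling of context or batch.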

Key Differences Explained

1. Gradient Computation

```python
# Training: gradients ARE computed
loss = model(inputs, labels=labels).loss  # forward pass records the autograd graph
loss.backward()  # walks the graph and stores a gradient in each parameter's .grad

# Inference: gradients NOT computed (no graph or saved activations, so less memory)
with torch.no_grad():
    output = model.generate(inputs)
```

2. KV Cache (Inference Only)

During inference, the KV (key-value) cache stores the attention keys and values of previous tokens so they do not have to be recomputed for each new token:

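```python
# Without KV cache: every step re-runs all previous tokens through the model,
#   so generating n tokens costs O(n²) token computations in total
# With KV cache: every step processes only the new token (O(n) in total),
#   attending over the cached keys/values of earlier tokens

# Hugging Face uses the KV cache automatically in generate()
output = model.generate(input_ids, use_cache=True)  # use_cache defaults to True
```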
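Why caching is safe can be checked with a toy single-head attention in plain Python that decodes the same sequence twice: once re-projecting every earlier token at each step, once appending each new token's key and value to a cache. The 2-dimensional embeddings and projection matrices are arbitrary assumptions; the point is that both paths produce identical outputs while the cached path projects each token only once.

```python
import math


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all keys/values so far."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


# Arbitrary toy parameters and token embeddings (assumptions for illustration)
Wq = [[1.0, 0.5], [0.0, 1.0]]
Wk = [[0.5, 1.0], [1.0, 0.0]]
Wv = [[1.0, 0.0], [0.5, 0.5]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]


def decode_without_cache(xs):
    outs, projected = [], 0
    for t in range(1, len(xs) + 1):
        # Recompute K and V for every token seen so far
        keys = [matvec(Wk, x) for x in xs[:t]]
        values = [matvec(Wv, x) for x in xs[:t]]
        projected += t
        outs.append(attend(matvec(Wq, xs[t - 1]), keys, values))
    return outs, projected


def decode_with_cache(xs):
    outs, projected = [], 0
    k_cache, v_cache = [], []
    for x in xs:
        # Project only the new token and append it to the cache
        k_cache.append(matvec(Wk, x))
        v_cache.append(matvec(Wv, x))
        projected += 1
        outs.append(attend(matvec(Wq, x), k_cache, v_cache))
    return outs, projected


slow, slow_count = decode_without_cache(tokens)
fast, fast_count = decode_with_cache(tokens)
print(slow_count, fast_count)  # tokens projected to K/V: 10 without cache, 4 with
```

The cost of the cache is memory: one stored key and value per token per layer, which is the KV-cache line in the inference memory breakdown.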

3. Precision Trade-offs

| Phase | Common precision | Reason |
|---|---|---|
| Training | BF16 + FP32 master weights | Stability for gradient updates |
| Inference | INT4 / INT8 / FP16 | Speed and memory efficiency |
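The INT8/INT4 entries rely on storing weights as small integers plus a scale factor. A minimal sketch of symmetric per-tensor INT8 quantization, which is only one simple scheme; production libraries typically use per-channel or per-group scales and more elaborate methods. The weight values are made up.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: w ~= scale * q, with q in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127 or 1.0  # avoid a zero scale
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q, scale):
    """Recover approximate floating-point weights."""
    return [v * scale for v in q]


weights = [0.42, -1.27, 0.003, 0.9, -0.35]  # hypothetical FP32 weights
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(q)  # one signed byte per weight instead of 4 bytes in FP32
print(f"scale={scale:.5f} max_err={max_err:.5f}")  # error is at most scale / 2
```

Storage drops by 4Ɨ versus FP32 (2Ɨ versus FP16) at the cost of a rounding error bounded by half the scale, which is acceptable for inference but too coarse for accumulating small gradient updates during training.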

Optimization Techniques

| Goal | Training | Inference |
|---|---|---|
| Reduce memory | Gradient checkpointing, QLoRA | Quantization (INT4/INT8) |
| Speed up | Mixed precision, Flash Attention | vLLM, continuous batching |
| Scale | DeepSpeed, FSDP, tensor parallelism | Horizontal scaling, load balancing |
| Cost | Spot instances, gradient accumulation | Smaller models, caching responses |
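Gradient accumulation lets a small GPU emulate a large batch: gradients from several micro-batches are averaged before a single optimizer step. A pure-Python sketch on a toy model y = w Ā· x with mean-squared-error loss (data and starting weight are made up) shows that, for equal-sized micro-batches, the accumulated gradient matches the full-batch gradient.

```python
# Toy model y = w * x with mean-squared-error loss (illustrative assumptions)
data = [(1.0, 2.1), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8), (5.0, 10.1), (6.0, 11.9)]
w = 0.5
lr = 0.01


def grad(batch, w):
    """dLoss/dw of the mean squared error over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


# Full batch: one gradient over all 6 examples at once
full = grad(data, w)

# Accumulation: 3 micro-batches of 2, each gradient scaled by 1/accum_steps
micro_size = 2
accum_steps = len(data) // micro_size
accumulated = 0.0
for i in range(0, len(data), micro_size):
    accumulated += grad(data[i:i + micro_size], w) / accum_steps

print(full, accumulated)  # equal up to floating-point rounding
w -= lr * accumulated  # a single optimizer step after all micro-batches
```

Only one micro-batch's activations need to be in memory at a time; the price is more forward/backward passes per optimizer step.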

Rule of thumb: full fine-tuning needs roughly 5–10Ɨ more VRAM than inference for the same model. A model you can serve on a 16 GB GPU may need an 80 GB GPU to fully fine-tune; parameter-efficient methods such as QLoRA narrow that gap.