What is the difference between LLM training and inference?
#training#inference#llm#gradients#kv-cache
Answer
LLM Training vs Inference
Training teaches an LLM by updating its weights. Inference uses the trained model to generate outputs. They are fundamentally different in compute requirements, memory usage, and optimization goals.
Side-by-Side Comparison
| Aspect | Training | Inference |
|---|---|---|
| Goal | Learn weights from data | Generate predictions from inputs |
| Passes | Forward + backward pass | Forward pass only |
| Weight updates | Yes (gradient descent) | No (weights are frozen) |
| Memory (VRAM) | 3-10× model size (weights + gradients + optimizer states + activations) | ~1-2× model size (weights + KV cache) |
| Batch size | Large (32-512+ sequences) | Small (1-8 typically) |
| Speed | Slow (hours to weeks per run) | Fast (milliseconds to seconds per request) |
| Hardware | High-end GPUs (A100, H100) | Datacenter GPUs for large-scale serving; consumer GPUs, CPUs, edge devices for smaller or quantized models |
| Precision | FP32 / BF16 / FP16 | INT8, INT4, FP16 |
| Cost | Very high ($1M+ for large pretraining runs) | Low per query |
What Happens During Training
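```python
# Training loop (simplified); assumes `model`, `dataloader`, `optimizer`,
# and `scheduler` are already set up (e.g. a Hugging Face causal LM + AdamW)
import torch

model.train()  # enables training-mode behavior such as dropout

for batch in dataloader:
    optimizer.zero_grad()  # clear gradients left over from the previous step

    # 1. Forward pass: compute predictions and the loss
    outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
    loss = outputs.loss

    # 2. Backward pass: compute gradients of the loss w.r.t. every weight
    loss.backward()

    # 3. Update weights, then advance the learning-rate schedule
    optimizer.step()
    scheduler.step()

    print(f"Loss: {loss.item():.4f}")
```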
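To see the three steps without a framework, here is a minimal pure-Python sketch, assuming a toy one-weight model `y = w * x` with mean-squared-error loss and a hand-derived gradient. The data, learning rate, and function names are hypothetical stand-ins for `model`, `loss.backward()`, and `optimizer.step()`; PyTorch computes gradients with autograd rather than by hand.

```python
# Toy stand-in for the training loop: one weight, MSE loss, plain SGD.
# Hypothetical data generated by y = 3x, so w should approach 3.0.
data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0), (4.0, 12.0)]

def forward(w, xs):
    """Forward pass: predictions for a batch of inputs."""
    return [w * x for x in xs]

def mse_loss(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

def backward(w, xs, targets):
    """Backward pass: d(loss)/dw for MSE, derived by hand."""
    n = len(xs)
    return sum(2 * (w * x - t) * x for x, t in zip(xs, targets)) / n

def train(w=0.0, lr=0.01, epochs=200):
    xs = [x for x, _ in data]
    ys = [y for _, y in data]
    for _ in range(epochs):
        preds = forward(w, xs)          # 1. forward pass
        loss = mse_loss(preds, ys)
        grad = backward(w, xs, ys)      # 2. backward pass
        w -= lr * grad                  # 3. weight update (SGD)
    return w, loss

w, final_loss = train()
print(f"w = {w:.4f}, loss = {final_loss:.6f}")

# Inference: forward pass only, the learned weight stays frozen.
print(forward(w, [5.0]))
```

Training touches the weight on every step; inference just reuses the final `w`.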
Memory During Training (3B model example)
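```text
Model weights:     ~6 GB   (3B params × 2 bytes, BF16)
Gradients:         ~6 GB   (same size as weights)
Optimizer states: ~12 GB   (Adam: 2 states per param, stored in BF16)
Activations:       ~8 GB   (varies with batch size / sequence length)
─────────────────────────
Total:            ~32 GB minimum
```

With FP32 Adam states (~24 GB) plus FP32 master weights (~12 GB), as in the mixed-precision setup under Precision Trade-offs, the total climbs to roughly 56 GB.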
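The breakdown above is bytes-per-parameter arithmetic. A minimal sketch, assuming the same accounting as the table (BF16 weights and gradients, two Adam states per parameter at 2 bytes each) and taking activations as a caller-supplied guess, since they depend on batch size and sequence length. The function name is hypothetical.

```python
def training_memory_gb(n_params, weight_bytes=2, grad_bytes=2,
                       optim_states=2, state_bytes=2, activations_gb=8.0):
    """Rough training VRAM estimate in GB (1 GB = 1e9 bytes).

    Defaults mirror the 3B example: BF16 weights and gradients, Adam's two
    moment estimates stored in BF16, plus a caller-supplied activation
    estimate. Ignores FP32 master weights, framework overhead, fragmentation.
    """
    weights = n_params * weight_bytes / 1e9
    grads = n_params * grad_bytes / 1e9
    optimizer = n_params * optim_states * state_bytes / 1e9
    return {
        "weights": weights,
        "gradients": grads,
        "optimizer": optimizer,
        "activations": activations_gb,
        "total": weights + grads + optimizer + activations_gb,
    }

est = training_memory_gb(3e9)
for part, gb in est.items():
    print(f"{part:>12}: {gb:5.1f} GB")
```

Passing `state_bytes=4` models FP32 Adam states and doubles the optimizer line to ~24 GB.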
What Happens During Inference
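```python
# Inference (no gradient computation)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is only an example checkpoint; swap in any causal LM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model.eval()  # disables dropout

with torch.no_grad():  # saves memory: no autograd graph is built
    inputs = tokenizer("Explain attention", return_tensors="pt")

    # Only forward passes, one new token per step
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```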
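`temperature=0.7` and `do_sample=True` control how each next token is drawn from the model's output logits. A minimal pure-Python sketch of temperature sampling, assuming a made-up four-token vocabulary and logits (a real model produces these from its final layer). The function names are hypothetical, and libraries may apply further filters such as top-k or top-p that are not shown here.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Scale logits by 1/temperature, then softmax into probabilities."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)                      # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next_token(logits, temperature=0.7, rng=random):
    """Draw one token id; lower temperature concentrates probability mass."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

# Hypothetical logits for a 4-token vocabulary.
logits = [2.0, 1.0, 0.5, -1.0]
for t in (0.2, 0.7, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])

rng = random.Random(0)
print(sample_next_token(logits, temperature=0.7, rng=rng))
```

Low temperatures push almost all probability onto the top logit; higher ones flatten the distribution and make output more varied.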
Memory During Inference (3B model example)
```text
Model weights:  ~6 GB   (FP16; INT4-quantized ≈ 2 GB)
KV cache:       ~2 GB   (grows with context length)
─────────────────────────
Total:          ~4-8 GB (much lighter than training)
```
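The KV-cache line scales linearly with context length and batch size. A sketch of the standard per-token arithmetic (one key and one value vector per layer per KV head), assuming a hypothetical 3B-class configuration of 32 layers and 32 KV heads of dimension 80, stored in FP16 (2 bytes). Real models differ, and grouped-query attention shrinks the cache by using fewer KV heads.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """KV cache size: K and V tensors for every layer, head, and token."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_value
    return per_token * seq_len * batch_size

# Hypothetical 3B-class config; read real values from your model's config.
cfg = dict(n_layers=32, n_kv_heads=32, head_dim=80)

for seq_len in (1024, 4096, 8192):
    gb = kv_cache_bytes(**cfg, seq_len=seq_len) / 1e9
    print(f"{seq_len:>5} tokens: {gb:.2f} GB")
```

Under these assumptions the cache costs about 320 KiB per token, reaching the ~2 GB in the table at roughly 6,000 tokens for a single sequence.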
Key Differences Explained
1. Gradient Computation
```python
# Training: gradients ARE computed
loss = model(inputs, labels=labels).loss  # forward pass records the autograd graph
loss.backward()  # walks the graph, filling each parameter's .grad

# Inference: gradients NOT computed (no graph, no saved activations, no .grad buffers)
with torch.no_grad():
    output = model.generate(inputs)
```
2. KV Cache (Inference Only)
During inference, the KV (key-value) cache stores the attention keys and values of previous tokens so they are not recomputed for each new token:
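```python
# Without KV cache: every step re-runs the full prefix, so K/V for all n tokens
#   are recomputed and each step's attention costs O(n²)
# With KV cache: each step computes K/V only for the new token and attends
#   over the n cached keys, so attention costs O(n) per step

# Hugging Face generate() uses the KV cache by default
output = model.generate(input_ids, use_cache=True)  # use_cache defaults to True
```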
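The saving can be checked on a toy single-head attention. A minimal pure-Python sketch, assuming hypothetical 2-dimensional projection weights and a fixed input sequence (tokens are fed in rather than sampled, to keep the comparison deterministic). It counts K/V projections and confirms that caching changes the cost, not the result.

```python
import math

# Hypothetical toy projection weights (d_model = 2) for a single attention head.
W_K = [[0.5, -0.2], [0.1, 0.9]]
W_V = [[1.0, 0.3], [-0.4, 0.7]]
W_Q = [[0.8, 0.0], [0.2, 0.6]]

def matvec(m, x):
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def attend(q, keys, values):
    """Scaled dot-product attention for one query over all keys/values."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]

def decode_without_cache(embeddings):
    """Each step re-projects K and V for the whole prefix."""
    outputs, projections = [], 0
    for t in range(len(embeddings)):
        prefix = embeddings[: t + 1]
        keys = [matvec(W_K, x) for x in prefix]
        values = [matvec(W_V, x) for x in prefix]
        projections += len(prefix)
        q = matvec(W_Q, embeddings[t])
        outputs.append(attend(q, keys, values))
    return outputs, projections

def decode_with_cache(embeddings):
    """Each step projects only the new token and appends it to the cache."""
    k_cache, v_cache, outputs, projections = [], [], [], 0
    for x in embeddings:
        k_cache.append(matvec(W_K, x))
        v_cache.append(matvec(W_V, x))
        projections += 1
        q = matvec(W_Q, x)
        outputs.append(attend(q, k_cache, v_cache))
    return outputs, projections

tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, -1.0], [0.2, 0.8]]
out_a, n_a = decode_without_cache(tokens)
out_b, n_b = decode_with_cache(tokens)
print(n_a, n_b)  # K/V projection counts: 15 vs 5
```

For n tokens the uncached version performs n(n+1)/2 projections versus n with the cache, while the attention outputs are identical.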
3. Precision Trade-offs
| Phase | Common Precision | Reason |
|---|---|---|
| Training | BF16 + FP32 master weights | Stability for gradient updates |
| Inference | INT4 / INT8 / FP16 | Speed and memory efficiency |
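To make the INT8 row concrete, here is a minimal sketch of symmetric per-tensor INT8 quantization: one scale for the whole tensor, values mapped to [-127, 127]. This is one common scheme chosen for illustration; production libraries typically use per-channel or per-group scales, and INT4 formats add further tricks. The weights are made up.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: w ≈ scale * q, q in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize_int8(q, scale):
    return [v * scale for v in q]

weights = [0.12, -0.53, 0.98, -1.27, 0.004]   # hypothetical FP weights
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(q, f"scale={scale:.4f}", f"max error={max_err:.4f}")
```

Each weight now needs 1 byte instead of 2 (FP16) or 4 (FP32), at the cost of a rounding error of at most half the scale per weight.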
Optimization Techniques
| Goal | Training | Inference |
|---|---|---|
| Reduce memory | Gradient checkpointing, QLoRA | Quantization (INT4/INT8) |
| Speed up | Mixed precision, Flash Attention | vLLM, continuous batching |
| Scale | DeepSpeed, FSDP, tensor parallelism | Horizontal scaling, load balancing |
| Cost | Spot instances, gradient accumulation | Smaller models, caching responses |
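Gradient accumulation (in the Cost row) trades time for memory: gradients from several small micro-batches are combined before a single optimizer step, so the update matches that of a larger batch. A minimal pure-Python sketch, assuming a hypothetical one-weight MSE model `y = w * x` and equal-sized micro-batches; the data and names are made up.

```python
def grad_mse(w, batch):
    """d/dw of mean squared error for the model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def accumulated_grad(w, batch, micro_batch_size):
    """Average micro-batch gradients, as if each micro-batch loss were
    divided by the number of accumulation steps before backward()."""
    micro = [batch[i:i + micro_batch_size]
             for i in range(0, len(batch), micro_batch_size)]
    total = 0.0
    for mb in micro:
        total += grad_mse(w, mb) / len(micro)   # accumulate, no step yet
    return total                                 # one optimizer step uses this

batch = [(1.0, 2.1), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8),
         (5.0, 10.1), (6.0, 12.0), (7.0, 13.8), (8.0, 16.3)]
w = 0.5
full = grad_mse(w, batch)
accum = accumulated_grad(w, batch, micro_batch_size=2)
print(full, accum)
```

With unequal micro-batch sizes the plain average no longer matches exactly; weighting each micro-batch by its size restores equivalence.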
Rule of thumb: Training needs ~5-10× more VRAM than inference for the same model. A model you can serve on a 16 GB GPU may need an 80 GB GPU to fully fine-tune (parameter-efficient methods such as QLoRA narrow the gap).