How would you handle imbalanced training data for a classification task?
#gen-ai #training #mlops
Answer
Handling Imbalanced Datasets in Gen AI
Class imbalance is common in classification tasks, NLP included (e.g. fraud detection: 99% legitimate, 1% fraud). Trained naively on such data, a model tends to predict the majority class almost always, because doing so already minimises the average loss.
Detecting Imbalance
```python
from collections import Counter

labels = ["positive"] * 9000 + ["negative"] * 800 + ["neutral"] * 200
counts = Counter(labels)
print(counts)  # Counter({'positive': 9000, 'negative': 800, 'neutral': 200})

# Imbalance ratio: each class's size relative to the majority class
majority = max(counts.values())
for label, count in counts.items():
    print(f"{label}: {count} ({count/majority:.2%} of majority)")
```
Technique 1: Class Weights (Most Common)
Weight the loss function to penalise mistakes on minority classes more heavily.
```python
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
from transformers import Trainer

# train_labels: integer label ids of the training set (placeholder)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.array([0, 1, 2]),
    y=train_labels,
)
weights = torch.tensor(class_weights, dtype=torch.float)

class WeightedTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers versions
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")  # pop so the model does not also compute its own unweighted loss
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = torch.nn.functional.cross_entropy(
            logits, labels, weight=weights.to(logits.device)  # keep weights on the logits' device
        )
        return (loss, outputs) if return_outputs else loss
```
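For reference, scikit-learn's `class_weight="balanced"` gives each class the weight `n_samples / (n_classes * class_count)`. A dependency-free sketch of that formula, applied to the counts from the detection example (the helper name `balanced_class_weights` is hypothetical):

```python
from collections import Counter

def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

labels = ["positive"] * 9000 + ["negative"] * 800 + ["neutral"] * 200
for label, w in balanced_class_weights(labels).items():
    print(f"{label}: {w:.3f}")
# positive: 0.370
# negative: 4.167
# neutral: 16.667
```

A mistake on a `neutral` example therefore costs 45× as much as one on a `positive` example (9000 / 200), which is what counteracts the pull toward the majority class.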
Technique 2: Oversampling with SMOTE (for embeddings)
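```python
from imblearn.over_sampling import SMOTE

# Use on embeddings / feature vectors, not raw text.
# X_embeddings: (n_samples, dim) array; y_labels: matching labels (placeholders).
# sampling_strategy="auto" oversamples every class except the majority;
# k_neighbors=5 needs at least 6 samples in each minority class.
smote = SMOTE(sampling_strategy="auto", k_neighbors=5)
X_resampled, y_resampled = smote.fit_resample(X_embeddings, y_labels)
```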
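SMOTE builds each synthetic point by picking a minority sample, choosing one of its `k` nearest minority-class neighbours, and interpolating at a random position on the segment between them. A simplified NumPy sketch of that core step, for intuition only (the function `smote_sketch` is hypothetical; the `imblearn` class above is what you would use in practice):

```python
import numpy as np

def smote_sketch(X_min, n_new, k=5, seed=0):
    """Generate n_new synthetic rows from minority-class matrix X_min (n, d). Needs n > k."""
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority samples
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)                 # a point is not its own neighbour
    neighbours = np.argsort(dists, axis=1)[:, :k]   # k nearest neighbours per sample

    base = rng.integers(0, len(X_min), size=n_new)          # random anchor samples
    nbr = neighbours[base, rng.integers(0, k, size=n_new)]  # one random neighbour each
    gap = rng.random((n_new, 1))                            # interpolation factor in [0, 1)
    return X_min[base] + gap * (X_min[nbr] - X_min[base])

X_min = np.random.default_rng(1).normal(size=(20, 8))  # e.g. 20 minority-class embeddings
X_new = smote_sketch(X_min, n_new=100)
print(X_new.shape)  # (100, 8)
```

Because every synthetic vector is an interpolation of two real embeddings, it has no corresponding text, which is why this only suits classifiers trained on feature vectors.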
Technique 3: Data Augmentation with LLMs
Generate synthetic minority-class examples with an LLM.
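```python
import json
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def generate_synthetic_examples(description: str, n: int = 10) -> list[str]:
    """Ask the model for n synthetic texts matching `description`."""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        # JSON mode guarantees syntactically valid JSON (an object, not a bare list)
        response_format={"type": "json_object"},
        messages=[{
            "role": "user",
            "content": (
                f"Generate {n} diverse examples of {description}. "
                'Return JSON of the form {"examples": ["...", "..."]}.'
            ),
        }],
    )
    # Parse with json.loads, never eval(): model output is untrusted text
    return json.loads(response.choices[0].message.content)["examples"]

# Request in batches; a single call rarely returns hundreds of usable items
synthetic_negatives = []
for _ in range(10):
    synthetic_negatives += generate_synthetic_examples("angry customer complaints", n=50)
```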
Technique 4: Threshold Tuning
Instead of using 0.5 as the decision boundary, choose the threshold that maximises your target metric on a held-out validation set.
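```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Find the threshold that maximises F1 for the minority class (assumed to be label 1).
# y_true: binary validation labels; y_proba: predicted probabilities, shape (n, 2)
precision, recall, thresholds = precision_recall_curve(y_true, y_proba[:, 1])
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
# precision/recall have one more entry than thresholds (the last point has no
# threshold), so drop it before taking the argmax
best_threshold = thresholds[np.argmax(f1_scores[:-1])]
y_pred = (y_proba[:, 1] >= best_threshold).astype(int)
```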
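The same search can be written without scikit-learn by evaluating F1 at every distinct predicted score and keeping the best cut. A minimal sketch, assuming binary labels with the minority class as `1` (the helper `best_f1_threshold` and the toy validation data are hypothetical):

```python
import numpy as np

def best_f1_threshold(y_true, scores):
    """Return (threshold, f1) maximising F1 for the positive class (label 1)."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    best_t, best_f1 = 0.5, -1.0
    for t in np.unique(scores):  # every distinct score is a candidate cut
        y_pred = scores >= t
        tp = np.sum(y_pred & (y_true == 1))
        fp = np.sum(y_pred & (y_true == 0))
        fn = np.sum(~y_pred & (y_true == 1))
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0  # F1 = 2TP / (2TP + FP + FN)
        if f1 > best_f1:
            best_t, best_f1 = float(t), f1
    return best_t, best_f1

y_val = [0, 0, 0, 0, 1, 1]
val_scores = [0.1, 0.2, 0.3, 0.6, 0.4, 0.9]
threshold, f1 = best_f1_threshold(y_val, val_scores)
print(threshold, round(f1, 3))  # 0.4 0.8
```

In this toy case the best cut (0.4) sits below 0.5: lowering the threshold recovers a positive that the default boundary would miss, at the cost of one false positive.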
Choosing the Right Metric
Accuracy is misleading on imbalanced data: on the 99:1 fraud example, a model that always predicts "legitimate" scores 99%. Use instead:
- F1 score (macro-averaged across classes)
- Precision/Recall for the minority class
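- AUROC (area under the ROC curve); under severe imbalance, the area under the precision-recall curve is often more informative
- Matthews Correlation Coefficient (MCC): a single score that uses all four cells of the confusion matrix, often recommended as the most robust single metric for imbalance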
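To see the difference in practice, here is a sketch that scores a degenerate "model" which always predicts the majority class on hypothetical 99:1 labels, using scikit-learn's metric functions:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, recall_score

y_true = np.array([0] * 990 + [1] * 10)  # 99% legitimate (0), 1% fraud (1)
y_pred = np.zeros_like(y_true)           # always predicts "legitimate"

acc = accuracy_score(y_true, y_pred)
# zero_division=0: class 1 is never predicted, so its precision is undefined
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
minority_recall = recall_score(y_true, y_pred, pos_label=1)
mcc = matthews_corrcoef(y_true, y_pred)

print(f"accuracy={acc:.3f} macro_f1={macro_f1:.3f} "
      f"minority_recall={minority_recall:.3f} mcc={mcc:.3f}")
# accuracy=0.990 macro_f1=0.497 minority_recall=0.000 mcc=0.000
```

Accuracy looks excellent, while macro F1, minority recall and MCC all expose that the model never catches a single fraud case.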
| Technique | When to Use | Implementation Effort |
|---|---|---|
| Class weights | Always try first | Very low |
| Oversample minority | When data is scarce | Low |
| Undersample majority | Large dataset | Low |
| Synthetic generation (LLM) | Very few minority samples | Medium |
| Threshold tuning | Post-training optimization | Low |
| Focal loss | Extreme imbalance (1:1000+) | Medium |
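Focal loss scales each example's cross-entropy by `(1 - p_t) ** gamma`, where `p_t` is the predicted probability of the true class, so easy, confidently correct examples contribute little and training concentrates on hard ones. A minimal PyTorch sketch that could stand in for the `cross_entropy` call inside `WeightedTrainer.compute_loss` (the function name and the default `gamma=2.0` are assumptions, not a library API):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2.0, weight=None):
    """Multi-class focal loss: mean over the batch of -w[y] * (1 - p_t)**gamma * log(p_t)."""
    log_probs = F.log_softmax(logits, dim=-1)
    # log p_t: log-probability assigned to each example's true class
    log_pt = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # Per-example (optionally class-weighted) cross-entropy: -w[y] * log p_t
    ce = F.nll_loss(log_probs, labels, weight=weight, reduction="none")
    # Down-weight easy examples (pt near 1); with gamma=0 and weight=None this equals F.cross_entropy
    return ((1 - pt) ** gamma * ce).mean()

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
labels = torch.tensor([0, 2])
print(focal_loss(logits, labels).item(), F.cross_entropy(logits, labels).item())
```

Passing the class weights from Technique 1 as `weight` combines both ideas; note that with weights the batch mean here is a plain average, unlike `F.cross_entropy`, which normalises by the sum of weights.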