Concept #14 | Medium | gen-ai-fundamentals

How would you handle imbalanced training data for a classification task?

#gen-ai #training #mlops

Answer

Handling Imbalanced Datasets in Gen AI

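Class imbalance is common in classification tasks, including NLP (e.g. fraud detection, where 99% of examples are legitimate and 1% are fraud). A model trained naively on such data can minimise its loss by predicting the majority class almost every time, so it reports high accuracy while missing most minority examples.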

Detecting Imbalance

python
from collections import Counter

labels = ["positive"] * 9000 + ["negative"] * 800 + ["neutral"] * 200
counts = Counter(labels)
print(counts)  # Counter({'positive': 9000, 'negative': 800, 'neutral': 200})

# Imbalance ratio
majority = max(counts.values())
for label, count in counts.items():
    print(f"{label}: {count} ({count/majority:.2%} of majority)")

Technique 1: Class Weights (Most Common)

Weight the loss function to penalise mistakes on minority classes more heavily.

python
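from sklearn.utils.class_weight import compute_class_weight
from transformers import Trainer
import torch
import numpy as np

# Assumes train_labels is a 1-D array of integer class ids (0, 1, 2)
# taken from the training split only.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.array([0, 1, 2]),
    y=train_labels,
)
weights = torch.tensor(class_weights, dtype=torch.float)

class WeightedTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that
    # newer transformers versions pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pop the labels so the model does not also compute its own unweighted loss
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Keep the class weights on the same device as the logits
        loss = torch.nn.functional.cross_entropy(
            logits, labels, weight=weights.to(logits.device)
        )
        return (loss, outputs) if return_outputs else loss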
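
To make the "balanced" setting less of a black box, here is a minimal pure-Python sketch of the formula scikit-learn documents for it, n_samples / (n_classes * class_count), applied to the 9000 / 800 / 200 split from the detection example. The helper name balanced_class_weights is hypothetical and exists only for illustration.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Return {class: n_samples / (n_classes * class_count)}."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}

example_labels = ["positive"] * 9000 + ["negative"] * 800 + ["neutral"] * 200
weights_by_class = balanced_class_weights(example_labels)

for cls, weight in weights_by_class.items():
    print(f"{cls}: {weight:.3f}")
# positive: 0.370
# negative: 4.167
# neutral: 16.667
```

The rarest class gets a weight about 45 times larger than the majority class, which is the inverse of the 9000:200 count ratio.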

Technique 2: Oversampling with SMOTE (for embeddings)

python
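from imblearn.over_sampling import SMOTE

# Use on embeddings / feature vectors, not raw text.
# Resample the training split only; validation and test data must keep
# the real class distribution. Each minority class needs more than
# k_neighbors samples.
smote = SMOTE(sampling_strategy="auto", k_neighbors=5, random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_embeddings, y_labels)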
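
For intuition about what SMOTE does to the feature vectors, the following is a simplified sketch of its core interpolation step, not the imblearn implementation. It picks a minority sample, picks one of its k nearest minority neighbours, and creates a new point somewhere on the line segment between them. The function name smote_like_sample and the toy 2-D points are assumptions made for illustration.

```python
import math
import random

def smote_like_sample(minority, k=2, rng=None):
    """Create one synthetic point by interpolating between a minority
    sample and one of its k nearest minority-class neighbours."""
    rng = rng or random.Random()
    base = rng.choice(minority)
    # Rank the other minority points by Euclidean distance to the base point
    others = [p for p in minority if p is not base]
    others.sort(key=lambda p: math.dist(base, p))
    neighbour = rng.choice(others[:k])
    gap = rng.random()  # interpolation factor in [0, 1)
    return [b + gap * (n - b) for b, n in zip(base, neighbour)]

# Toy 2-D "embeddings" for a minority class (illustrative values only)
minority_points = [[1.0, 1.0], [1.2, 0.9], [0.8, 1.1], [1.1, 1.3]]
rng = random.Random(0)
synthetic = [smote_like_sample(minority_points, k=2, rng=rng) for _ in range(5)]
print(synthetic)
```

Because every synthetic point lies between two real minority points, SMOTE only makes sense in a continuous feature space such as embeddings, which is why it is not applied to raw text.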

Technique 3: Data Augmentation with LLMs

Generate synthetic minority class examples using an LLM.

python
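import json
from openai import OpenAI

client = OpenAI()

def generate_synthetic_examples(label: str, n: int = 10) -> list[str]:
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{
            "role": "user",
            "content": (
                f"Generate {n} diverse examples of {label} customer reviews. "
                "Return only a JSON list of strings, with no other text."
            ),
        }],
    )
    # Parse with json.loads; never eval() model output, which could run arbitrary code.
    # Raises json.JSONDecodeError if the model adds text around the list.
    examples = json.loads(response.choices[0].message.content)
    return [str(example) for example in examples]

# Request small batches; one call for hundreds of examples risks truncated or repetitive output.
synthetic_negatives = []
for _ in range(50):
    synthetic_negatives.extend(generate_synthetic_examples("angry customer complaint", n=10))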

Technique 4: Threshold Tuning

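Instead of using 0.5 as the decision boundary, choose the threshold that optimises your target metric on a held-out validation set. The example below covers the binary case, with class 1 as the minority class.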

python
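from sklearn.metrics import precision_recall_curve
import numpy as np

# Find the threshold that maximises F1 for the minority (positive) class.
# y_true and y_proba should come from a validation set, not the test set.
precision, recall, thresholds = precision_recall_curve(y_true, y_proba[:, 1])
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
# precision and recall have one more entry than thresholds, so drop the last point
best_threshold = thresholds[np.argmax(f1_scores[:-1])]
# Apply the tuned threshold instead of the default 0.5
y_pred = (y_proba[:, 1] >= best_threshold).astype(int)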

Choosing the Right Metric

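Accuracy is misleading on imbalanced data: a model that always predicts the majority class scores 99% on a 99:1 split. Prefer:

  • F1 score (macro-averaged across classes)
  • Precision/Recall for the minority class
  • AUROC (area under ROC curve); under extreme imbalance, also check the area under the precision-recall curve, since AUROC can look optimistic
  • Matthews Correlation Coefficient (MCC): a robust single-number summary under imbalance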
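
The gap between accuracy and the other metrics is easy to see with a small worked example. This sketch scores a classifier that always predicts the majority class on a 990:10 split, using only the standard library. The helper binary_metrics is hypothetical; in practice scikit-learn's metric functions would do the same job.

```python
import math

def binary_metrics(y_true, y_pred):
    """Compute accuracy, recall, F1 and MCC for binary labels (1 = minority)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention MCC is reported as 0 when the denominator is 0
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"accuracy": accuracy, "recall": recall, "f1": f1, "mcc": mcc}

y_true_demo = [0] * 990 + [1] * 10   # 1% minority class
y_pred_demo = [0] * 1000             # always predict the majority class
print(binary_metrics(y_true_demo, y_pred_demo))
# {'accuracy': 0.99, 'recall': 0.0, 'f1': 0.0, 'mcc': 0.0}
```

Accuracy reports 99% for a model that never finds a single minority example, while recall, F1 and MCC all report zero.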

Technique                  | When to Use                 | Implementation Effort
Class weights              | Always try first            | Very low
Oversample minority        | When data is scarce         | Low
Undersample majority       | Large dataset               | Low
Synthetic generation (LLM) | Very few minority samples   | Medium
Threshold tuning           | Post-training optimization  | Low
Focal loss                 | Extreme imbalance (1:1000+) | Medium
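
Focal loss is listed in the table without an example, so here is a rough per-example sketch of the usual formulation, FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The function names and the choices gamma=2.0 and alpha=1.0 are assumptions for illustration; a real training setup would apply the same idea to batched logits in the training framework.

```python
import math

def cross_entropy(p_true_class):
    """Standard cross-entropy for one example, given the true-class probability."""
    return -math.log(p_true_class)

def focal_loss(p_true_class, gamma=2.0, alpha=1.0):
    """Focal loss for one example: down-weights examples the model already gets right."""
    return -alpha * (1.0 - p_true_class) ** gamma * math.log(p_true_class)

# Compare the two losses for an easy, a moderate, and a hard example
for p in (0.95, 0.6, 0.1):
    ce = cross_entropy(p)
    fl = focal_loss(p)
    print(f"p_t={p:.2f}  cross-entropy={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.4f}")
```

With gamma=2, an easy example at p_t=0.95 keeps only 0.25% of its cross-entropy loss, while a hard example at p_t=0.1 keeps 81%, so the many easy majority examples stop dominating the gradient. Setting gamma=0 recovers plain (alpha-weighted) cross-entropy.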