Contrastive learning has revolutionized representation learning by teaching neural networks to distinguish between similar and dissimilar examples. This elegant approach has achieved remarkable success in computer vision, natural language processing, and multimodal AI. This comprehensive guide explores the principles, methods, and practical implementation of contrastive learning.

The Essence of Contrastive Learning

At its core, contrastive learning teaches models to recognize what makes things similar or different. Instead of predicting labels, the model learns to:

  1. Pull together representations of similar (positive) pairs
  2. Push apart representations of dissimilar (negative) pairs

This seemingly simple objective leads to powerful representations that transfer well to downstream tasks.

Why Contrastive Learning Works

Human learning often involves comparison. We understand “dog” partly by knowing how dogs differ from cats, wolves, and furniture. Contrastive learning mirrors this process:

  • Learning what features distinguish instances
  • Discovering invariances (what stays the same despite transformations)
  • Building representations that capture semantic meaning

The Contrastive Learning Framework

  1. Sample a batch of examples
  2. Create positive pairs (similar examples)
  3. Create negative pairs (dissimilar examples)
  4. Encode all examples
  5. Compute loss that brings positives closer and pushes negatives apart
  6. Update model parameters

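The six steps above can be sketched as a minimal PyTorch training loop. This is a sketch, not a specific framework's API: `encoder`, `augment`, and the optimizer are stand-ins, and the loss is the in-batch NT-Xent formulation covered later in this guide.

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    """In-batch contrastive loss: positives are the (i, i + B) view pairs."""
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    sim = z @ z.t() / temperature
    sim.fill_diagonal_(float('-inf'))          # a sample is never its own negative
    b = z1.size(0)
    labels = torch.cat([torch.arange(b, 2 * b), torch.arange(b)])
    return F.cross_entropy(sim, labels)

def training_step(encoder, augment, batch, optimizer):
    v1, v2 = augment(batch), augment(batch)    # steps 2-3: positive pairs, in-batch negatives
    z1, z2 = encoder(v1), encoder(v2)          # step 4: encode
    loss = nt_xent(z1, z2)                     # step 5: pull positives, push negatives
    optimizer.zero_grad()
    loss.backward()                            # step 6: update parameters
    optimizer.step()
    return loss.item()
```

The same skeleton accommodates any of the loss functions below by swapping out `nt_xent`.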

Mathematical Foundations

Contrastive Loss Functions

InfoNCE (Noise Contrastive Estimation)

The most widely used contrastive loss:

```python
# Shared imports for the code examples in this guide
import torch
import torch.nn as nn
import torch.nn.functional as F

def info_nce_loss(query, positive_key, negative_keys, temperature=0.07):
    """
    query:         [batch_size, dim]
    positive_key:  [batch_size, dim]
    negative_keys: [batch_size, num_negatives, dim]
    """
    # Positive similarity
    pos_sim = torch.sum(query * positive_key, dim=-1) / temperature
    # Negative similarities
    neg_sim = torch.bmm(
        query.unsqueeze(1),
        negative_keys.transpose(1, 2)
    ).squeeze(1) / temperature
    # Combine for softmax: the positive is always class 0
    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
    labels = torch.zeros(query.size(0), dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels)
```

NT-Xent (Normalized Temperature-scaled Cross Entropy)

Used in SimCLR, this loss treats the other samples in the batch as negatives:

```python
def nt_xent_loss(z1, z2, temperature=0.5):
    """
    z1, z2: [batch_size, dim] - two augmented views
    """
    batch_size = z1.size(0)
    # Normalize embeddings
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    # Gather all embeddings
    z = torch.cat([z1, z2], dim=0)
    # Similarity matrix
    sim_matrix = torch.mm(z, z.t()) / temperature
    # Mask out self-similarity
    sim_matrix.fill_diagonal_(-float('inf'))
    # Labels: positive pairs are (i, i+batch_size) and (i+batch_size, i)
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ], dim=0).to(z.device)
    loss = F.cross_entropy(sim_matrix, labels)
    return loss
```

Triplet Loss

Classic formulation using anchor, positive, and negative:

```python
def triplet_loss(anchor, positive, negative, margin=0.3):
    """
    anchor, positive, negative: [batch_size, dim]
    """
    pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
    neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
    loss = F.relu(pos_dist - neg_dist + margin)
    return loss.mean()
```

Contrastive Loss (Margin-based)

Original formulation from Siamese networks:

```python
def contrastive_loss(x1, x2, label, margin=1.0):
    """
    label: 1 for positive pairs, 0 for negative pairs
    """
    distance = F.pairwise_distance(x1, x2)
    pos_loss = label * distance ** 2
    neg_loss = (1 - label) * F.relu(margin - distance) ** 2
    return (pos_loss + neg_loss).mean()
```

The Role of Temperature

Temperature controls the sharpness of the similarity distribution:

  • High temperature: Softer distribution, more uniform attention to negatives
  • Low temperature: Sharper distribution, focus on hard negatives

```python
# Temperature effect demonstration
def similarity_distribution(similarities, temperatures):
    for temp in temperatures:
        probs = F.softmax(similarities / temp, dim=-1)
        print(f"Temperature {temp}: {probs}")

# With similarities [0.8, 0.3, 0.1, -0.2]:
# temp=1.0: [0.40, 0.25, 0.20, 0.15] (spread out)
# temp=0.1: [0.99, 0.01, 0.00, 0.00] (concentrated)
```

Creating Positive and Negative Pairs

Instance Discrimination

Treat each image as its own class:

  • Positive: Different augmentations of the same image
  • Negative: Any other image in the batch

```python
class InstanceDiscrimination(nn.Module):
    def __init__(self, encoder, temperature=0.1):
        super().__init__()
        self.encoder = encoder
        self.temperature = temperature
        # self.augment and self.nt_xent_loss are assumed here: an augmentation
        # pipeline and the NT-Xent loss defined earlier

    def forward(self, batch):
        # Create two augmented views
        view1 = self.augment(batch)
        view2 = self.augment(batch)
        # Encode
        z1 = F.normalize(self.encoder(view1), dim=1)
        z2 = F.normalize(self.encoder(view2), dim=1)
        # In-batch negatives
        loss = self.nt_xent_loss(z1, z2)
        return loss
```

Supervised Contrastive Learning

Use labels to define positive pairs:

  • Positive: Images of the same class
  • Negative: Images of different classes

```python
def supervised_contrastive_loss(features, labels, temperature=0.07):
    """
    features: [batch_size, dim]
    labels: [batch_size]
    """
    features = F.normalize(features, dim=1)
    # Compute similarity matrix
    sim_matrix = torch.mm(features, features.t()) / temperature
    # Mask for positive pairs (same class)
    labels = labels.view(-1, 1)
    mask = torch.eq(labels, labels.t()).float()
    # Remove self-similarity
    mask.fill_diagonal_(0)
    # For numerical stability
    logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
    logits = sim_matrix - logits_max.detach()
    # Compute log probabilities
    exp_logits = torch.exp(logits)
    # Mask out self in the denominator
    exp_logits.fill_diagonal_(0)
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
    # Mean of positive log probabilities
    mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6)
    loss = -mean_log_prob_pos.mean()
    return loss
```

Hard Negative Mining

Select challenging negatives for more effective learning:

```python
def hard_negative_mining(anchor, all_negatives, num_hard=10):
    """Select the most similar (hardest) negatives."""
    similarities = torch.mm(anchor, all_negatives.t())
    # Get top-k most similar negatives
    _, hard_indices = similarities.topk(num_hard, dim=1)
    hard_negatives = torch.gather(
        all_negatives.unsqueeze(0).expand(anchor.size(0), -1, -1),
        1,
        hard_indices.unsqueeze(-1).expand(-1, -1, all_negatives.size(-1))
    )
    return hard_negatives
```

Semi-Hard Negative Mining

Select negatives that are hard but not too hard:

```python
def semi_hard_mining(anchor, positive, all_negatives, margin=0.3):
    """
    Select negatives that are farther than the positive but within margin.
    anchor, positive: [batch_size, dim]; all_negatives: [num_negatives, dim]
    """
    pos_dist = torch.sum((anchor - positive) ** 2, dim=1, keepdim=True)
    neg_dists = torch.sum((anchor.unsqueeze(1) - all_negatives) ** 2, dim=2)
    # Semi-hard: neg_dist > pos_dist but neg_dist < pos_dist + margin
    mask = (neg_dists > pos_dist) & (neg_dists < pos_dist + margin)
    # Select one semi-hard negative per anchor (or fall back to the hardest)
    semi_hard_negatives = []
    for i in range(anchor.size(0)):
        valid = mask[i].nonzero(as_tuple=True)[0]
        if valid.numel() > 0:
            idx = valid[torch.randint(valid.numel(), (1,))].item()
        else:
            idx = neg_dists[i].argmin().item()
        semi_hard_negatives.append(all_negatives[idx])
    return torch.stack(semi_hard_negatives)
```

Data Augmentation Strategies

Image Augmentations

Strong, diverse augmentations are crucial for contrastive learning:

```python
from torchvision import transforms

class ContrastiveAugmentation:
    def __init__(self, img_size=224):
        self.transform = transforms.Compose([
            # Geometric transformations
            transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            # Color transformations
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Blur
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=23)
            ], p=0.5),
            # Solarization (for larger models)
            transforms.RandomSolarize(threshold=128, p=0.2),
            # Normalize
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)
```

Text Augmentations

For NLP contrastive learning:

```python
class TextAugmentation:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def word_dropout(self, tokens, p=0.1):
        """Randomly drop words."""
        mask = torch.rand(len(tokens)) > p
        return [t for t, m in zip(tokens, mask) if m]

    def word_shuffle(self, tokens, k=3):
        """Shuffle words within a window of size k."""
        positions = list(range(len(tokens)))
        for i in range(len(tokens)):
            j = min(len(tokens) - 1, i + int(torch.rand(1) * k))
            positions[i], positions[j] = positions[j], positions[i]
        return [tokens[p] for p in positions]

    def back_translation(self, text):
        """Translate to another language and back."""
        # Requires translation models
        pass

    def synonym_replacement(self, tokens, p=0.1):
        """Replace words with synonyms."""
        # Requires a synonym database like WordNet
        pass
```

Multimodal Augmentations

For image-text pairs:

```python
class MultimodalAugmentation:
    def __init__(self):
        self.image_aug = ContrastiveAugmentation()
        # Pass a real tokenizer in practice; None keeps the sketch self-contained
        self.text_aug = TextAugmentation(tokenizer=None)

    def __call__(self, image, text):
        # Multiple augmented views
        img_view1, img_view2 = self.image_aug(image)
        text_view1 = self.text_aug.word_dropout(text)
        text_view2 = self.text_aug.synonym_replacement(text)
        return (img_view1, text_view1), (img_view2, text_view2)
```

Key Contrastive Learning Methods

SimCLR (Simple Framework for Contrastive Learning)

```python
class SimCLR(nn.Module):
    def __init__(self, encoder, hidden_dim=2048, proj_dim=128, temperature=0.5):
        super().__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(encoder.output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim)
        )
        self.temperature = temperature

    def forward(self, x1, x2):
        # Encode and project
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        # NT-Xent loss
        loss = self.nt_xent(z1, z2)
        return loss, h1  # Return representations for downstream tasks

    def nt_xent(self, z1, z2):
        batch_size = z1.size(0)
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        z = torch.cat([z1, z2], dim=0)
        sim = torch.mm(z, z.t()) / self.temperature
        sim.fill_diagonal_(-1e9)
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(batch_size)
        ]).to(z.device)
        return F.cross_entropy(sim, labels)
```

Key insights:

  • Large batch sizes (4096+) provide more negatives
  • Projection head is crucial but discarded after training
  • Strong augmentation is essential
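The second insight in practice: after pretraining, the projector is dropped and downstream tasks consume the encoder output directly. A minimal illustration with toy stand-in modules (the shapes and names here are illustrative only):

```python
import torch
import torch.nn as nn

# Toy stand-ins for a pretrained SimCLR model
encoder = nn.Sequential(nn.Linear(32, 16), nn.ReLU())                     # kept
projector = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))  # pretraining only

images = torch.randn(2, 32)
h = encoder(images)   # downstream features: use these for linear probes, fine-tuning
z = projector(h)      # low-dimensional embedding: used only in the contrastive loss
```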

MoCo v3

Combines momentum contrast with a Vision Transformer backbone:

```python
class MoCoV3(nn.Module):
    """Simplified sketch: the real MoCo v3 computes keys with a separate
    momentum (EMA) encoder; here keys reuse the same encoder under no_grad."""
    def __init__(self, encoder, dim=256, mlp_dim=4096, temperature=0.2):
        super().__init__()
        # Encoder with projection
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(encoder.output_dim, mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, dim)
        )
        # Predictor (asymmetric)
        self.predictor = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, dim)
        )
        self.temperature = temperature

    def forward(self, x1, x2):
        # Query features
        q1 = self.predictor(self.projector(self.encoder(x1)))
        q2 = self.predictor(self.projector(self.encoder(x2)))
        # Key features (stop gradient)
        with torch.no_grad():
            k1 = self.projector(self.encoder(x1))
            k2 = self.projector(self.encoder(x2))
        # Symmetrized loss
        loss = self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)
        return loss / 2

    def contrastive_loss(self, q, k):
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)
        logits = torch.mm(q, k.t()) / self.temperature
        labels = torch.arange(q.size(0), device=q.device)
        return F.cross_entropy(logits, labels)
```

SwAV (Swapping Assignments between Views)

Uses online clustering instead of explicit negatives:

```python
class SwAV(nn.Module):
    def __init__(self, encoder, dim=256, num_prototypes=3000, temperature=0.1):
        super().__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(encoder.output_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, dim)
        )
        # Learnable prototypes (cluster centers)
        self.prototypes = nn.Linear(dim, num_prototypes, bias=False)
        self.temperature = temperature

    def forward(self, x1, x2):
        # Get features
        z1 = F.normalize(self.projector(self.encoder(x1)), dim=1)
        z2 = F.normalize(self.projector(self.encoder(x2)), dim=1)
        # Compute prototype assignments
        with torch.no_grad():
            # Sinkhorn-Knopp to get soft codes
            q1 = self.sinkhorn(self.prototypes(z1))
            q2 = self.sinkhorn(self.prototypes(z2))
        # Swapped prediction loss
        p1 = self.prototypes(z1) / self.temperature
        p2 = self.prototypes(z2) / self.temperature
        loss = -0.5 * (
            torch.sum(q1 * F.log_softmax(p2, dim=1), dim=1).mean() +
            torch.sum(q2 * F.log_softmax(p1, dim=1), dim=1).mean()
        )
        return loss

    def sinkhorn(self, scores, num_iters=3, epsilon=0.05):
        """Sinkhorn-Knopp algorithm for balanced assignments."""
        Q = torch.exp(scores / epsilon).t()
        Q /= Q.sum()
        K, B = Q.shape
        for _ in range(num_iters):
            Q /= Q.sum(dim=1, keepdim=True) * K
            Q /= Q.sum(dim=0, keepdim=True) * B
        return (Q / Q.sum(dim=0, keepdim=True)).t()
```

CLIP (Contrastive Language-Image Pre-training)

Multimodal contrastive learning between images and text:

```python
import numpy as np

class CLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, embed_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)
        # Learned temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, image):
        features = self.image_encoder(image)
        features = self.image_projection(features)
        return F.normalize(features, dim=-1)

    def encode_text(self, text):
        features = self.text_encoder(text)
        features = self.text_projection(features)
        return F.normalize(features, dim=-1)

    def forward(self, images, texts):
        image_features = self.encode_image(images)
        text_features = self.encode_text(texts)
        # Compute logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # Symmetric contrastive loss: matched pairs lie on the diagonal
        batch_size = images.size(0)
        labels = torch.arange(batch_size, device=images.device)
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        return (loss_i2t + loss_t2i) / 2
```

Training Considerations

Batch Size and Distributed Training

Large batch sizes are crucial for contrastive learning:

```python
import torch.distributed as dist

class DistributedContrastiveLoss(nn.Module):
    """Gather representations across GPUs for more negatives."""
    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        # Gather from all GPUs (gathered tensors carry no gradient;
        # gradients flow only through the local z1, z2)
        z1_all = self.gather_from_all(z1)
        z2_all = self.gather_from_all(z2)
        # Compute loss with all negatives (contrastive loss as defined earlier)
        return self.contrastive_loss(z1, z2, z1_all, z2_all)

    @torch.no_grad()
    def gather_from_all(self, tensor):
        """Gather tensors from all processes."""
        tensors_gather = [
            torch.ones_like(tensor)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(tensors_gather, tensor)
        return torch.cat(tensors_gather, dim=0)
```

Learning Rate Scheduling

Warmup followed by cosine decay:

```python
import math

def get_lr_schedule(optimizer, warmup_epochs, total_epochs, base_lr):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Linear warmup
            return epoch / warmup_epochs
        # Cosine decay
        progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        return 0.5 * (1 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

LARS/LAMB Optimizers

Layer-wise adaptive learning rates for large batch training:

```python
class LARS(torch.optim.Optimizer):
    """Layer-wise Adaptive Rate Scaling."""
    def __init__(self, params, lr, weight_decay=0, momentum=0.9, trust_coef=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay,
                        momentum=momentum, trust_coef=trust_coef)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                # Weight decay
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])
                # LARS scaling: trust ratio of parameter norm to gradient norm
                param_norm = torch.norm(p)
                grad_norm = torch.norm(grad)
                if param_norm > 0 and grad_norm > 0:
                    adaptive_lr = group['trust_coef'] * param_norm / grad_norm
                else:
                    adaptive_lr = 1.0
                # Momentum buffer
                buf = self.state[p].setdefault('momentum_buffer', torch.zeros_like(p))
                buf.mul_(group['momentum']).add_(grad, alpha=adaptive_lr)
                # Apply update
                p.add_(buf, alpha=-group['lr'])
```

Avoiding Representation Collapse

Several strategies prevent all representations from becoming identical:

  1. Negative samples: Explicitly push apart different samples
  2. Asymmetric architectures: Predictor on one branch (BYOL, SimSiam)
  3. Stop gradients: Prevent certain gradient paths
  4. Batch normalization: Implicit regularization

```python
# SimSiam: no negatives, uses stop-gradient
class SimSiam(nn.Module):
    def __init__(self, encoder, dim=2048, pred_dim=512):
        super().__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(encoder.output_dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim)
        )
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(),
            nn.Linear(pred_dim, dim)
        )

    def forward(self, x1, x2):
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        # Stop gradient on z
        loss = (
            self.cosine_loss(p1, z2.detach()) +
            self.cosine_loss(p2, z1.detach())
        ) / 2
        return loss

    def cosine_loss(self, p, z):
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        return -(p * z).sum(dim=1).mean()
```

Evaluation Protocols

Linear Probe

The standard evaluation protocol: freeze the encoder and train only a linear classifier on top:

```python
def linear_probe_eval(encoder, train_loader, test_loader, num_classes, epochs=100):
    # Freeze encoder
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False
    # Linear classifier
    classifier = nn.Linear(encoder.output_dim, num_classes)
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.3, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    # Training
    for epoch in range(epochs):
        for images, labels in train_loader:
            with torch.no_grad():
                features = encoder(images)
            logits = classifier(features)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
    # Evaluation
    correct = 0
    total = 0
    classifier.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            features = encoder(images)
            predictions = classifier(features).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total
```

k-NN Evaluation

Memory-based evaluation without training:

```python
def knn_eval(encoder, train_loader, test_loader, k=20):
    encoder.eval()
    # Extract features (extract_features is an assumed helper that runs the
    # encoder over a loader and returns stacked features and labels)
    train_features, train_labels = extract_features(encoder, train_loader)
    test_features, test_labels = extract_features(encoder, test_loader)
    # Normalize
    train_features = F.normalize(train_features, dim=1)
    test_features = F.normalize(test_features, dim=1)
    # Compute similarities
    sim = torch.mm(test_features, train_features.t())
    # Get k nearest neighbors
    _, indices = sim.topk(k, dim=1)
    retrieved_labels = train_labels[indices]
    # Majority vote
    predictions = torch.mode(retrieved_labels, dim=1).values
    accuracy = (predictions == test_labels).float().mean()
    return accuracy
```

Applications

Self-Supervised Pretraining

Pretrain on ImageNet, transfer to downstream tasks:

```python
# Pretrain (ResNet50, train_contrastive, and the datasets are placeholders)
encoder = ResNet50()
model = SimCLR(encoder)
train_contrastive(model, imagenet_unlabeled)

# Transfer: reuse the pretrained encoder with a new task head
encoder.load_state_dict(model.encoder.state_dict())
classifier = nn.Linear(encoder.output_dim, num_classes)
fine_tune(encoder, classifier, downstream_dataset)
```
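The `fine_tune` call above is left undefined; a minimal sketch (joint training of encoder and head with AdamW, all names and hyperparameters assumed, not from a specific library) could look like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fine_tune(encoder, classifier, loader, epochs=10, lr=1e-3):
    """Jointly train the pretrained encoder and a new classification head."""
    params = list(encoder.parameters()) + list(classifier.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr)
    encoder.train()
    classifier.train()
    for _ in range(epochs):
        for images, labels in loader:
            logits = classifier(encoder(images))
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

In practice the encoder often gets a smaller learning rate than the fresh head to avoid destroying the pretrained representations.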

Zero-Shot Classification (CLIP-style)

```python
def zero_shot_classify(clip_model, image, class_names):
    # Encode image
    image_features = clip_model.encode_image(image)
    # Encode class names as text (assumes encode_text handles tokenization)
    texts = [f"a photo of a {name}" for name in class_names]
    text_features = clip_model.encode_text(texts)
    # Compute similarities
    similarities = image_features @ text_features.t()
    # Return most similar class
    return class_names[similarities.argmax()]
```

Semantic Search

```python
class SemanticSearch:
    def __init__(self, clip_model):
        self.model = clip_model
        self.index = None
        self.images = []

    def index_images(self, image_paths):
        features = []
        for path in image_paths:
            image = load_image(path)  # assumed helper: load and preprocess
            feat = self.model.encode_image(image)
            features.append(feat)
            self.images.append(path)
        self.index = torch.cat(features, dim=0)
        self.index = F.normalize(self.index, dim=1)

    def search(self, query_text, top_k=10):
        text_features = self.model.encode_text(query_text)
        text_features = F.normalize(text_features, dim=1)
        similarities = text_features @ self.index.t()
        _, indices = similarities.topk(top_k)
        return [self.images[i] for i in indices[0]]
```

Conclusion

Contrastive learning has proven to be one of the most effective self-supervised learning paradigms. By learning to distinguish similar from dissimilar examples, models develop rich representations that transfer well to diverse downstream tasks.

Key takeaways:

  1. Core principle: Pull together positives, push apart negatives
  2. Augmentation is crucial: Strong, diverse augmentations create meaningful positive pairs
  3. Scale matters: Large batches provide more negatives
  4. Temperature controls hardness: Lower temperature focuses on harder examples
  5. Projection heads help: Non-linear projectors improve representation learning
  6. Multiple formulations: InfoNCE, triplet loss, and clustering-based approaches all work
  7. Multimodal extension: CLIP shows power of cross-modal contrastive learning

Whether you’re pretraining vision models, building semantic search systems, or developing multimodal AI, contrastive learning provides a powerful foundation for learning from unlabeled data.
