Self-supervised learning has emerged as one of the most promising paradigms in artificial intelligence, fundamentally changing how we think about training machine learning models. By learning from unlabeled data, self-supervised methods have achieved remarkable results in natural language processing, computer vision, and beyond. This comprehensive guide explores the principles, techniques, and applications of self-supervised learning.

The Labeling Problem

Traditional supervised learning requires labeled datasets—examples paired with their correct answers. Creating these labels is expensive, time-consuming, and often requires domain expertise:

  • Medical imaging: Expert radiologists must annotate each scan
  • Natural language: Human annotators label sentiment, entities, or translations
  • Autonomous driving: Thousands of hours of video need pixel-level annotations

Meanwhile, unlabeled data is abundant and cheap:

  • Billions of images on the internet
  • Trillions of words in text corpora
  • Endless video footage and audio recordings

Self-supervised learning bridges this gap by extracting learning signals from the data itself.

What Is Self-Supervised Learning?

Self-supervised learning creates supervisory signals from the data’s own structure. The model solves “pretext tasks”—artificial tasks that don’t require human labels—to learn useful representations.

The key insight: solving these pretext tasks requires understanding the data’s underlying structure, which transfers to downstream tasks.

The Self-Supervised Learning Pipeline

  1. Pretext Task Design: Create a task that requires meaningful understanding
  2. Pretraining: Train a model on massive unlabeled data using the pretext task
  3. Transfer Learning: Fine-tune on downstream tasks with limited labels
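
As a toy end-to-end illustration of these three stages (synthetic data and an invented mask-one-feature pretext task, not any particular published method):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
unlabeled = torch.randn(256, 8)  # "unlabeled" synthetic data

# 1-2. Pretext task + pretraining: zero out one feature, reconstruct the original
encoder = nn.Sequential(nn.Linear(8, 32), nn.ReLU())
pretext_head = nn.Linear(32, 8)
opt = torch.optim.Adam(
    list(encoder.parameters()) + list(pretext_head.parameters()), lr=1e-2
)
for _ in range(50):
    masked = unlabeled.clone()
    idx = torch.randint(0, 8, (unlabeled.size(0),))
    masked[torch.arange(unlabeled.size(0)), idx] = 0.0  # "mask" one feature
    loss = F.mse_loss(pretext_head(encoder(masked)), unlabeled)
    opt.zero_grad()
    loss.backward()
    opt.step()

# 3. Transfer: discard the pretext head, fine-tune a task head on few labels
task_head = nn.Linear(32, 2)
logits = task_head(encoder(unlabeled[:16]))
```

The pretext head is thrown away; only the encoder's representation carries over to the downstream task.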

Comparison with Other Paradigms

| Paradigm | Label Source | Data Required |
|----------|--------------|---------------|
| Supervised | Human annotations | Labeled data |
| Unsupervised | None (clustering, density) | Unlabeled data |
| Self-supervised | Data itself (automatic) | Unlabeled data |
| Semi-supervised | Few labels + unlabeled | Both |

Self-Supervised Learning in NLP

Natural language processing pioneered many self-supervised techniques, leveraging the sequential nature of text.

Language Modeling

The classic pretext task: predict the next word given previous words.

```python
# Causal language modeling example
text = "The cat sat on the"
target = "mat"
# Model learns: P(mat | The cat sat on the)
```

GPT and similar models use this approach:

```python
import torch
import torch.nn as nn

class CausalLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True),
            num_layers
        )
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # Causal mask: position i may not attend to positions > i
        seq_len = x.size(1)
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        # Embed and process
        x = self.embedding(x)
        x = self.transformer(x, mask=mask)
        return self.lm_head(x)
```
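
To make the objective concrete, here is the shift-by-one target construction and the loss, with random logits standing in for the model output:

```python
import torch
import torch.nn.functional as F

# Token ids for "The cat sat on the mat" (illustrative ids)
tokens = torch.tensor([[5, 12, 9, 31, 5, 44]])

# Next-token prediction: input is all but the last token,
# target is the same sequence shifted left by one
inputs, targets = tokens[:, :-1], tokens[:, 1:]

# The model would return logits of shape (batch, seq_len, vocab_size);
# random logits stand in for model(inputs) here
vocab_size = 50
logits = torch.randn(1, inputs.size(1), vocab_size)

loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
```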

Masked Language Modeling (MLM)

BERT introduced masked language modeling: randomly mask tokens and predict them.

```python
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    """Prepare masked tokens for masked language modeling."""
    labels = inputs.clone()

    # Sample which tokens to mask
    probability_matrix = torch.full(labels.shape, mlm_probability)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # Don't compute loss for unmasked tokens

    # 80% of the time, replace with [MASK]
    indices_replaced = torch.bernoulli(
        torch.full(labels.shape, 0.8)
    ).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.mask_token_id

    # 10% of the time, replace with a random token
    # (0.5 of the remaining 20% = 10% overall)
    indices_random = torch.bernoulli(
        torch.full(labels.shape, 0.5)
    ).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape)
    inputs[indices_random] = random_words[indices_random]

    # Remaining 10% of the time, keep the token unchanged
    return inputs, labels
```

Next Sentence Prediction (NSP)

BERT also trained on predicting whether two sentences are consecutive:

```python
# Positive example (consecutive sentences)
sentence_a = "The dog is happy."
sentence_b = "It wags its tail."
label = 1

# Negative example (random sentences)
sentence_a = "The dog is happy."
sentence_b = "The stock market crashed."
label = 0
```

Sentence Order Prediction

ALBERT replaced NSP with sentence order prediction—determining if two sentences are in the correct order:

```python
# Correct order
sentence_a = "First, preheat the oven."
sentence_b = "Then, place the dish inside."
label = 1

# Swapped order
sentence_a = "Then, place the dish inside."
sentence_b = "First, preheat the oven."
label = 0
```

Span Corruption (T5)

T5 uses span corruption: mask contiguous spans and reconstruct them:

```python
# Original:  "The quick brown fox jumps over the lazy dog"
# Corrupted: "The <X> fox jumps <Y> dog"
# Target:    "<X> quick brown <Y> over the lazy <Z>"
# Each masked span is replaced by a sentinel token (<X>, <Y>, ...);
# the target reconstructs the spans, delimited by the same sentinels.
```
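
A minimal runnable sketch of this corruption (span positions are chosen by hand here, whereas T5 samples them randomly; `<extra_id_N>` is T5's sentinel naming):

```python
def corrupt_spans(tokens, span_starts, span_len=2):
    """Replace each chosen span with a sentinel; collect the spans as the target."""
    corrupted, target = [], []
    i, sentinel = 0, 0
    starts = set(span_starts)
    while i < len(tokens):
        if i in starts:
            corrupted.append(f"<extra_id_{sentinel}>")
            target.append(f"<extra_id_{sentinel}>")
            target.extend(tokens[i:i + span_len])
            sentinel += 1
            i += span_len
        else:
            corrupted.append(tokens[i])
            i += 1
    target.append(f"<extra_id_{sentinel}>")  # final sentinel terminates the target
    return corrupted, target

tokens = "The quick brown fox jumps over the lazy dog".split()
corrupted, target = corrupt_spans(tokens, span_starts=[1, 5], span_len=2)
# corrupted: ['The', '<extra_id_0>', 'fox', 'jumps', '<extra_id_1>', 'lazy', 'dog']
```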

Self-Supervised Learning in Computer Vision

Visual self-supervision has developed unique approaches suited to image data.

Pretext Tasks for Vision

Rotation Prediction: Predict which rotation (0°, 90°, 180°, 270°) was applied:

```python
class RotationPredictor(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(encoder.output_dim, 4)

    def apply_rotations(self, x, rotation_labels):
        # Rotate each image by label * 90 degrees
        return torch.stack([
            torch.rot90(img, k=int(k), dims=(-2, -1))
            for img, k in zip(x, rotation_labels)
        ])

    def forward(self, x, rotation_labels):
        # Rotate images, then predict which rotation was applied
        rotated = self.apply_rotations(x, rotation_labels)
        features = self.encoder(rotated)
        return self.classifier(features)
```

Jigsaw Puzzle: Divide image into patches, shuffle, predict permutation:

```python
def split_into_patches(image, grid_size):
    """Split a (C, H, W) image into grid_size**2 patches of shape (C, h, w)."""
    c, h, w = image.shape
    ph, pw = h // grid_size, w // grid_size
    patches = image.unfold(1, ph, ph).unfold(2, pw, pw)
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, c, ph, pw)

def create_jigsaw_puzzle(image, num_patches=9):
    """Split image into patches and shuffle."""
    patches = split_into_patches(image, int(num_patches ** 0.5))
    perm = torch.randperm(num_patches)  # random permutation to predict
    return patches[perm], perm
```

Colorization: Convert to grayscale, predict colors:

```python
class ColorizationNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ResNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, grayscale_image):
        features = self.encoder(grayscale_image)
        ab_channels = self.decoder(features)  # Predict a, b in Lab color space
        return ab_channels
```

Inpainting: Mask regions, predict missing content:

```python
def create_inpainting_task(image, mask_ratio=0.25):
    """Create a masked image for inpainting."""
    mask = create_random_mask(image.shape, mask_ratio)
    masked_image = image * (1 - mask)
    return masked_image, image, mask
```

Contrastive Learning

Contrastive learning has become the dominant paradigm for visual self-supervision. The core idea: learn representations where similar samples are close and dissimilar samples are far apart.

The Contrastive Learning Framework

  1. Create two augmented views of each image
  2. Encode both views
  3. Pull together representations of the same image (positives)
  4. Push apart representations of different images (negatives)

SimCLR

SimCLR (Simple Framework for Contrastive Learning) established a strong baseline:

```python
class SimCLR(nn.Module):
    def __init__(self, encoder, projection_dim=128, temperature=0.5):
        super().__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, projection_dim)
        )
        self.temperature = temperature

    def forward(self, x1, x2):
        # Encode both augmented views
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)
        # Project to the contrastive space
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        return z1, z2

    def contrastive_loss(self, z1, z2):
        batch_size = z1.size(0)
        # Normalize
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        # Concatenate all projections: rows 0..N-1 are view 1, N..2N-1 are view 2
        z = torch.cat([z1, z2], dim=0)
        # Compute similarity matrix
        sim = torch.mm(z, z.t()) / self.temperature
        # Label for row i is the index of its positive (the other view)
        labels = torch.arange(batch_size, device=z.device)
        labels = torch.cat([labels + batch_size, labels])
        # Mask self-similarity so a sample can't match itself
        mask = torch.eye(2 * batch_size, device=z.device).bool()
        sim.masked_fill_(mask, -float('inf'))
        # NT-Xent loss
        loss = F.cross_entropy(sim, labels)
        return loss
```

Data Augmentation in Contrastive Learning

Strong augmentations are crucial. SimCLR uses:

```python
contrastive_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
```

MoCo (Momentum Contrast)

MoCo addresses the need for a large pool of negative samples by maintaining a queue of past keys and a momentum-updated key encoder:

```python
class MoCo(nn.Module):
    def __init__(self, encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K  # Queue size
        self.m = m  # Momentum coefficient
        self.T = T  # Temperature

        # Query encoder
        self.encoder_q = encoder
        self.projector_q = nn.Sequential(
            nn.Linear(encoder.output_dim, dim)
        )

        # Key encoder (momentum updated)
        self.encoder_k = copy.deepcopy(encoder)
        self.projector_k = copy.deepcopy(self.projector_q)

        # Freeze key encoder
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projector_k.parameters():
            param.requires_grad = False

        # Queue of negative samples
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def momentum_update(self):
        """Update key encoder and projector with momentum."""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)
        for param_q, param_k in zip(self.projector_q.parameters(),
                                    self.projector_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)

    @torch.no_grad()
    def update_queue(self, keys):
        """Enqueue the newest keys, overwriting the oldest (assumes K % batch == 0)."""
        batch_size = keys.size(0)
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + batch_size] = keys.t()
        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, x_q, x_k):
        # Query features
        q = self.projector_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)

        # Key features (no gradient)
        with torch.no_grad():
            self.momentum_update()
            k = self.projector_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)

        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        # Logits: Nx(1+K); the positive is always at index 0
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        # Update queue with the new keys
        self.update_queue(k)
        return F.cross_entropy(logits, labels)
```

BYOL (Bootstrap Your Own Latent)

BYOL shows that negative samples aren't strictly necessary:

```python
class BYOL(nn.Module):
    def __init__(self, encoder, projection_dim=256, hidden_dim=4096):
        super().__init__()
        # Online network
        self.online_encoder = encoder
        self.online_projector = nn.Sequential(
            nn.Linear(encoder.output_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
        self.predictor = nn.Sequential(
            nn.Linear(projection_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

        # Target network (momentum updated, no gradients)
        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def momentum_update(self, tau=0.99):
        """Move target weights toward online weights; call after each optimizer step."""
        for online, target in zip(self.online_encoder.parameters(),
                                  self.target_encoder.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data
        for online, target in zip(self.online_projector.parameters(),
                                  self.target_projector.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data

    def forward(self, x1, x2):
        # Online predictions
        online_proj_1 = self.online_projector(self.online_encoder(x1))
        online_proj_2 = self.online_projector(self.online_encoder(x2))
        online_pred_1 = self.predictor(online_proj_1)
        online_pred_2 = self.predictor(online_proj_2)

        # Target projections (no gradient)
        with torch.no_grad():
            target_proj_1 = self.target_projector(self.target_encoder(x1))
            target_proj_2 = self.target_projector(self.target_encoder(x2))

        # Symmetrized loss: each online prediction regresses the other view's target
        loss = self.regression_loss(online_pred_1, target_proj_2.detach())
        loss += self.regression_loss(online_pred_2, target_proj_1.detach())
        return loss.mean()

    def regression_loss(self, x, y):
        # Equivalent to 2 * (1 - cosine similarity)
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return 2 - 2 * (x * y).sum(dim=-1)
```

Masked Image Modeling

Inspired by BERT's success, masked image modeling has emerged as a powerful approach.

Masked Autoencoders (MAE)

MAE masks random patches and reconstructs the pixels:

```python
class MAE(nn.Module):
    def __init__(self, encoder, decoder, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mask_ratio = mask_ratio

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        # Random shuffle via argsort of uniform noise
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        # Binary mask: 0 = kept, 1 = masked
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def forward(self, x):
        # Patchify and embed
        patches = self.patchify(x)
        x = self.patch_embed(patches)
        # Random masking
        x_masked, mask, ids_restore = self.random_masking(x, self.mask_ratio)
        # Encode visible patches only
        latent = self.encoder(x_masked)
        # Decode the full set of patches
        pred = self.decoder(latent, ids_restore)
        # Loss computed only on masked patches
        loss = self.reconstruction_loss(patches, pred, mask)
        return loss, pred, mask
```

BEiT (BERT Pre-Training of Image Transformers)

BEiT discretizes image patches into tokens using a learned vocabulary:

```python
class BEiT(nn.Module):
    def __init__(self, encoder, tokenizer, vocab_size=8192):
        super().__init__()
        self.encoder = encoder
        self.tokenizer = tokenizer  # dVAE that converts patches to discrete tokens
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(encoder.output_dim, vocab_size)

    def forward(self, x, mask):
        # Get discrete tokens for all patches
        with torch.no_grad():
            tokens = self.tokenizer.tokenize(x)
        # Mask some patches
        x_masked = self.apply_mask(x, mask)
        # Encode
        hidden = self.encoder(x_masked)
        # Predict tokens for masked patches
        predictions = self.lm_head(hidden)
        # Cross-entropy loss on masked positions only
        loss = F.cross_entropy(
            predictions[mask].view(-1, self.vocab_size),
            tokens[mask].view(-1)
        )
        return loss
```

Self-Supervised Learning for Other Modalities

Audio

Wav2Vec 2.0: Contrastive learning on speech

  • Mask portions of audio
  • Predict masked segments from context
  • Quantize audio into discrete units
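
The span-masking step can be sketched as follows (a simplified illustration: `mask_audio_frames` and its defaults are invented for this post, and real wav2vec 2.0 replaces masked latent frames with a learned embedding rather than zeros):

```python
import torch

def mask_audio_frames(frames, mask_prob=0.065, span=10):
    """Mask contiguous spans of latent audio frames, wav2vec 2.0-style.

    frames: (batch, time, dim). Returns masked frames and the boolean mask.
    """
    b, t, _ = frames.shape
    starts = torch.rand(b, t) < mask_prob       # sampled span start positions
    mask = torch.zeros(b, t, dtype=torch.bool)
    for offset in range(span):
        # Extend each start position into a span of `span` frames
        mask[:, offset:] |= starts[:, :t - offset]
    masked = frames.clone()
    masked[mask] = 0.0   # stand-in for the learned mask embedding
    return masked, mask

frames = torch.randn(2, 100, 64)
masked, mask = mask_audio_frames(frames)
```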

Audio MAE: Masked autoencoding for spectrograms

Video

VideoMAE: Mask and reconstruct video frames

TimeSformer: Self-attention across space and time

Contrastive Video Representation Learning: Multiple views of the same video

Multimodal

CLIP: Contrastive learning between images and text

```python
class CLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, embed_dim):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)
        # Learnable temperature, stored on a log scale
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, texts):
        # Encode and project
        image_features = self.image_projection(self.image_encoder(images))
        text_features = self.text_projection(self.text_encoder(texts))
        # Normalize
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        # Symmetric contrastive loss over the image-text similarity matrix
        logits = torch.matmul(image_features, text_features.t()) * self.temperature.exp()
        labels = torch.arange(len(images), device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2
```

Evaluation and Transfer

Linear Probing

Freeze the pretrained encoder, train only a linear classifier:

```python
def linear_probe(pretrained_encoder, train_data, num_classes):
    # Freeze encoder
    for param in pretrained_encoder.parameters():
        param.requires_grad = False

    # Add a linear classifier
    classifier = nn.Linear(pretrained_encoder.output_dim, num_classes)
    optimizer = optim.Adam(classifier.parameters())

    # Train only the classifier
    for images, labels in train_data:
        with torch.no_grad():
            features = pretrained_encoder(images)
        logits = classifier(features)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Fine-Tuning

Unfreeze and train the entire network on downstream task:

```python
def fine_tune(pretrained_encoder, train_data, num_classes):
    # Add a classifier head
    model = nn.Sequential(
        pretrained_encoder,
        nn.Linear(pretrained_encoder.output_dim, num_classes)
    )

    # Lower learning rate for pretrained layers
    optimizer = optim.Adam([
        {'params': pretrained_encoder.parameters(), 'lr': 1e-5},
        {'params': model[-1].parameters(), 'lr': 1e-3}
    ])

    # Train the whole network
    for images, labels in train_data:
        logits = model(images)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Few-Shot Evaluation

Test with very limited labeled data (1-10 examples per class):

```python
def few_shot_evaluate(encoder, support_set, query_set):
    """Nearest-neighbor classification for few-shot evaluation."""
    with torch.no_grad():
        # Encode support and query sets
        support_features = F.normalize(encoder(support_set['images']), dim=1)
        query_features = F.normalize(encoder(query_set['images']), dim=1)
        # Predict each query's label from its nearest support example
        similarity = torch.mm(query_features, support_features.t())
        predictions = support_set['labels'][similarity.argmax(dim=1)]
        accuracy = (predictions == query_set['labels']).float().mean()
    return accuracy
```

Best Practices

Training Tips

  1. Large batch sizes: Contrastive learning benefits from more negatives
  2. Strong augmentation: Critical for contrastive methods
  3. Long training: Self-supervised methods often need 100-400 epochs
  4. Proper momentum: For momentum-based methods, 0.99-0.999 works well
  5. Learning rate warmup: Gradual warmup stabilizes training
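
The warmup in tip 5 is usually paired with cosine decay over the long schedules from tip 3; a minimal sketch (the function name and defaults are illustrative):

```python
import math

def lr_at_step(step, base_lr=1e-3, warmup_steps=1000, total_steps=100_000):
    """Linear warmup to base_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))

lr_at_step(0)        # 0.0
lr_at_step(1000)     # peak: 1e-3
lr_at_step(100_000)  # ~0.0
```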

Architecture Choices

  1. Projection heads: Use MLP projectors, discard after pretraining
  2. Batch normalization: Important for avoiding collapse
  3. Large encoders: Self-supervision scales well with model size

Common Pitfalls

Representation Collapse: All representations become identical

  • Solution: Use asymmetric architectures, stop gradients, or negatives

Shortcut Learning: Model learns trivial features

  • Solution: Strong augmentations, careful pretext task design

Overfitting to Pretext Task: Representations don’t transfer

  • Solution: Evaluate on downstream tasks, adjust pretext complexity
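
For the collapse pitfall, one cheap diagnostic is the spread of normalized embeddings (a heuristic sketch; `collapse_score` is an illustrative name, not a standard API):

```python
import torch
import torch.nn.functional as F

def collapse_score(embeddings):
    """Mean per-dimension std of L2-normalized embeddings.

    Near zero means every input maps to the same point (collapse);
    healthy d-dimensional features score roughly 1/sqrt(d).
    """
    z = F.normalize(embeddings, dim=1)
    return z.std(dim=0).mean().item()

healthy = torch.randn(512, 128)     # spread out: score near 1/sqrt(128) ≈ 0.09
collapsed = torch.ones(512, 128)    # identical rows: score near 0.0
```

Tracking this quantity during pretraining gives an early warning well before downstream evaluation.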

Conclusion

Self-supervised learning has transformed how we train neural networks, enabling models to learn from the vast amounts of unlabeled data available. From language modeling in NLP to contrastive learning and masked autoencoders in vision, these techniques have achieved remarkable results.

Key takeaways:

  1. Learning from data structure: Self-supervision creates labels from data itself
  2. Pretext tasks matter: Task design significantly affects representation quality
  3. Contrastive learning: Learn by comparing similar and dissimilar samples
  4. Masked prediction: Reconstruct masked portions of input
  5. Transfer is key: Success measured by downstream task performance
  6. Scale is important: Benefits from large models and extensive pretraining

As data continues to grow faster than our ability to label it, self-supervised learning will only become more important. Understanding these techniques is essential for anyone working in modern AI.
