Few-shot learning represents one of the most practical and challenging frontiers in machine learning. While deep learning has achieved remarkable success, it typically requires thousands or millions of labeled examples. Few-shot learning tackles the realistic scenario where only a handful of examples are available for new classes. This comprehensive guide explores the techniques, algorithms, and practical applications of few-shot learning.

The Few-Shot Learning Problem

Definition and Motivation

Few-shot learning aims to learn new concepts from very few examples:

  • 1-shot learning: Learn from a single example per class
  • 5-shot learning: Learn from five examples per class
  • Zero-shot learning: Learn without any examples (using auxiliary information)

This mirrors human cognitive abilities—we can recognize a new animal after seeing just one picture, or understand a new word from a single definition.
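The zero-shot case relies entirely on such auxiliary information. As a toy illustration, suppose each unseen class is described by a hand-made attribute vector (the class names, attributes, and function below are invented for this sketch, not part of any standard API); an image is classified by matching its predicted attributes to the closest class description:

```python
import numpy as np

# Hypothetical attribute descriptions for unseen classes (illustrative only)
# Attribute order: [striped, four-legged, flies]
class_attributes = {
    "zebra": np.array([1.0, 1.0, 0.0]),
    "eagle": np.array([0.0, 0.0, 1.0]),
    "horse": np.array([0.0, 1.0, 0.0]),
}

def zero_shot_classify(image_attributes):
    """Pick the class whose attribute vector has highest cosine
    similarity to the attributes predicted from the image."""
    names = list(class_attributes)
    A = np.stack([class_attributes[n] for n in names])
    A = A / np.linalg.norm(A, axis=1, keepdims=True)
    v = image_attributes / np.linalg.norm(image_attributes)
    return names[int(np.argmax(A @ v))]
```

In a real system the attribute predictor is itself learned, and the class descriptions come from side information such as text embeddings or expert-annotated attributes.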

Why Traditional Deep Learning Fails

Deep neural networks struggle with limited data for several reasons:

  1. High capacity: Modern networks have millions of parameters that overfit easily
  2. No prior structure: They learn everything from scratch without leveraging prior knowledge
  3. Gradient-based optimization: Requires many iterations to converge
  4. Data augmentation limits: Can only stretch limited data so far

```python
import torch
import torch.nn.functional as F

# Demonstration of overfitting with limited data
def train_with_limited_data(model, few_examples):
    optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(1000):
        outputs = model(few_examples['images'])
        loss = F.cross_entropy(outputs, few_examples['labels'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Training accuracy quickly reaches 100%,
    # but test accuracy remains poor - overfitting!
```

The N-way K-shot Setting

Few-shot learning is typically evaluated in the N-way K-shot setting:

  • N-way: N new classes to distinguish
  • K-shot: K examples per class in the support set
  • Query set: Examples to classify after seeing the support set

```python
# 5-way 1-shot example
n_way = 5     # 5 classes to distinguish
k_shot = 1    # 1 example per class
n_query = 15  # 15 query examples per class

# Total: 5 support examples, 75 query examples
support_set_size = n_way * k_shot  # 5
query_set_size = n_way * n_query   # 75
```

Transfer Learning Approaches

Pretrain and Fine-tune

The simplest approach: pretrain on large dataset, fine-tune on few examples.

```python
import torch.nn as nn

class PretrainFinetune:
    def __init__(self, pretrained_model, num_new_classes):
        self.encoder = pretrained_model.encoder
        # Replace classifier head
        self.classifier = nn.Linear(
            self.encoder.output_dim,
            num_new_classes
        )
        # Freeze encoder initially
        for param in self.encoder.parameters():
            param.requires_grad = False

    def finetune(self, support_set, epochs=100):
        images, labels = support_set
        # Unfreeze last few layers
        for param in self.encoder.layer4.parameters():
            param.requires_grad = True
        optimizer = torch.optim.Adam([
            {'params': self.encoder.layer4.parameters(), 'lr': 1e-5},
            {'params': self.classifier.parameters(), 'lr': 1e-3}
        ])
        for epoch in range(epochs):
            features = self.encoder(images)
            outputs = self.classifier(features)
            loss = F.cross_entropy(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

Feature Extraction and Nearest Neighbor

Use pretrained features directly with simple classifier:

```python
class FeatureNN:
    def __init__(self, pretrained_encoder):
        self.encoder = pretrained_encoder
        self.encoder.eval()

    def predict(self, support_set, query_images):
        support_images, support_labels = support_set
        with torch.no_grad():
            # Extract features
            support_features = self.encoder(support_images)
            query_features = self.encoder(query_images)
            # Normalize
            support_features = F.normalize(support_features, dim=1)
            query_features = F.normalize(query_features, dim=1)
            # Compute cosine similarity
            similarity = query_features @ support_features.t()
            # Predict nearest neighbor's label
            nn_indices = similarity.argmax(dim=1)
            predictions = support_labels[nn_indices]
        return predictions
```

Linear Probing

Train only a linear classifier on frozen features:

```python
class LinearProbe:
    def __init__(self, encoder, num_classes):
        self.encoder = encoder
        self.encoder.eval()
        self.classifier = nn.Linear(encoder.output_dim, num_classes)

    def fit(self, support_set, epochs=100):
        images, labels = support_set
        # Extract features once
        with torch.no_grad():
            features = self.encoder(images)

        # Train linear classifier
        optimizer = torch.optim.LBFGS(self.classifier.parameters())

        def closure():
            optimizer.zero_grad()
            outputs = self.classifier(features)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            return loss

        for _ in range(epochs):
            optimizer.step(closure)

    def predict(self, query_images):
        with torch.no_grad():
            features = self.encoder(query_images)
            return self.classifier(features).argmax(dim=1)
```

Metric Learning for Few-Shot

Prototypical Networks

Create class prototypes from support examples:

```python
class PrototypicalNetwork(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def compute_prototypes(self, support_features, support_labels):
        n_way = int(support_labels.max()) + 1
        prototypes = torch.zeros(
            n_way, support_features.size(-1),
            device=support_features.device
        )
        for c in range(n_way):
            mask = (support_labels == c)
            prototypes[c] = support_features[mask].mean(dim=0)
        return prototypes

    def forward(self, support_images, support_labels, query_images):
        # Encode
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        # Compute prototypes
        prototypes = self.compute_prototypes(support_features, support_labels)
        # Compute distances
        dists = torch.cdist(query_features, prototypes)
        # Return negative distances as logits
        return -dists

    def predict(self, support_set, query_images):
        support_images, support_labels = support_set
        logits = self.forward(support_images, support_labels, query_images)
        return logits.argmax(dim=1)
```

Matching Networks with Full Context Embeddings

Use attention over support set with context-aware embeddings:

```python
class MatchingNetwork(nn.Module):
    def __init__(self, encoder, hidden_dim):
        super().__init__()
        self.encoder = encoder
        # Bidirectional LSTMs output 2 * hidden_dim features, so choose
        # hidden_dim = encoder.output_dim // 2 to preserve the feature size
        self.support_lstm = nn.LSTM(
            encoder.output_dim, hidden_dim,
            bidirectional=True, batch_first=True
        )
        # The query LSTM reads [query_feature; attention read]
        self.query_lstm = nn.LSTM(
            encoder.output_dim * 2, hidden_dim,
            bidirectional=True, batch_first=True
        )

    def full_context_embed_support(self, support_features):
        """Use a BiLSTM to create context-aware support embeddings."""
        # support_features: [N*K, D], treated as a sequence of length N*K
        output, _ = self.support_lstm(support_features.unsqueeze(0))
        return output.squeeze(0)

    def full_context_embed_query(self, query_features, support_context):
        """Attentive embedding of queries with support context."""
        hidden = None
        for step in range(3):  # Multiple attention steps
            # Attention over support
            attn = F.softmax(
                torch.mm(query_features, support_context.t()),
                dim=1
            )
            read = torch.mm(attn, support_context)
            # LSTM step with residual connection
            lstm_input = torch.cat([query_features, read], dim=1)
            output, hidden = self.query_lstm(
                lstm_input.unsqueeze(1), hidden
            )
            query_features = output.squeeze(1) + query_features
        return query_features

    def forward(self, support_images, support_labels, query_images):
        # Initial embeddings
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        # Full context embeddings
        support_context = self.full_context_embed_support(support_features)
        query_context = self.full_context_embed_query(
            query_features, support_context
        )
        # Attention-based classification
        attn = F.softmax(
            torch.mm(query_context, support_context.t()),
            dim=1
        )
        # Weighted sum of one-hot labels
        support_onehot = F.one_hot(support_labels).float()
        predictions = torch.mm(attn, support_onehot)
        return predictions
```

Induction Networks

Learn class-level representations through induction:

```python
class InductionNetwork(nn.Module):
    def __init__(self, encoder, relation_dim):
        super().__init__()
        self.encoder = encoder
        # Induction module: aggregate class features
        self.induction = nn.Sequential(
            nn.Linear(encoder.output_dim, relation_dim),
            nn.ReLU(),
            nn.Linear(relation_dim, encoder.output_dim)
        )
        # Relation module
        self.relation = nn.Sequential(
            nn.Linear(encoder.output_dim * 2, relation_dim),
            nn.ReLU(),
            nn.Linear(relation_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, support_images, support_labels, query_images):
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        n_way = int(support_labels.max()) + 1
        n_query = query_images.size(0)
        # Induce class representations
        class_vectors = []
        for c in range(n_way):
            mask = (support_labels == c)
            class_features = support_features[mask]
            # Mean pooling here; the full method uses dynamic routing
            induced = self.induction(class_features.mean(dim=0))
            class_vectors.append(induced)
        class_vectors = torch.stack(class_vectors)  # [N, D]
        # Compute relations
        relations = torch.zeros(n_query, n_way, device=query_features.device)
        for i, query in enumerate(query_features):
            for c, class_vec in enumerate(class_vectors):
                pair = torch.cat([query, class_vec])
                relations[i, c] = self.relation(pair)
        return relations
```

Data Augmentation for Few-Shot

Traditional Augmentation

Apply heavy augmentation to stretch limited data:

```python
from torchvision import transforms

# e.g. ImageNet channel statistics
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

few_shot_augmentation = transforms.Compose([
    transforms.RandomResizedCrop(84, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(
        brightness=0.4, contrast=0.4,
        saturation=0.4, hue=0.1
    ),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

def augment_support_set(support_set, augmentation, n_augment=10):
    """Create augmented versions of support examples."""
    images, labels = support_set  # images: PIL images (ToTensor runs last)
    augmented_images = []
    augmented_labels = []
    for img, label in zip(images, labels):
        for _ in range(n_augment):
            aug_img = augmentation(img)
            augmented_images.append(aug_img)
            augmented_labels.append(label)
    return torch.stack(augmented_images), torch.tensor(augmented_labels)
```

Learned Augmentation

Learn task-specific augmentations:

```python
class MetaAugmentation(nn.Module):
    """Learn to generate useful augmented examples."""

    def __init__(self, feature_dim, augment_dim):
        super().__init__()
        self.augment_dim = augment_dim
        # Transformation network
        self.transform_net = nn.Sequential(
            nn.Linear(feature_dim + augment_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim)
        )

    def forward(self, features, n_augment=5):
        batch_size = features.size(0)
        augmented = [features]
        for _ in range(n_augment):
            # Random augmentation code
            z = torch.randn(batch_size, self.augment_dim, device=features.device)
            # Generate augmented features
            input_aug = torch.cat([features, z], dim=1)
            aug_features = self.transform_net(input_aug)
            augmented.append(aug_features)
        return torch.cat(augmented, dim=0)
```

Hallucination Networks

Generate synthetic examples from support set:

```python
class HallucinationNetwork(nn.Module):
    """Generate synthetic examples for each class."""

    def __init__(self, encoder, generator):
        super().__init__()
        self.encoder = encoder
        self.generator = generator

    def hallucinate(self, support_images, support_labels, n_hallucinate=10):
        support_features = self.encoder(support_images)
        n_way = int(support_labels.max()) + 1
        hallucinated_features = []
        hallucinated_labels = []
        for c in range(n_way):
            mask = (support_labels == c)
            class_features = support_features[mask]
            # Generate synthetic features from the class mean plus noise
            for _ in range(n_hallucinate):
                noise = torch.randn_like(class_features[0])
                synthetic = self.generator(class_features.mean(0), noise)
                hallucinated_features.append(synthetic)
                hallucinated_labels.append(c)
        return (
            torch.stack(hallucinated_features),
            torch.tensor(hallucinated_labels)
        )
```

Transductive Few-Shot Learning

Use query set statistics during inference:

```python
class TransductiveFewShot(nn.Module):
    """Leverage unlabeled query examples for better predictions."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, support_images, support_labels, query_images,
                n_iterations=10):
        # Encode all examples
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        n_way = int(support_labels.max()) + 1
        # Initial prototypes from support set
        prototypes = self.compute_prototypes(
            support_features, support_labels, n_way
        )
        # Iteratively refine with query set (soft labels)
        for iteration in range(n_iterations):
            # Compute soft assignments for queries
            dists = torch.cdist(query_features, prototypes)
            soft_labels = F.softmax(-dists, dim=1)
            # Recompute prototypes including soft-labeled queries
            new_prototypes = []
            for c in range(n_way):
                # Support contribution
                support_mask = (support_labels == c)
                support_contrib = support_features[support_mask].sum(0)
                support_count = support_mask.sum()
                # Query contribution (weighted)
                query_weights = soft_labels[:, c]
                query_contrib = (query_features * query_weights.unsqueeze(1)).sum(0)
                query_count = query_weights.sum()
                # Combined prototype
                new_proto = (support_contrib + query_contrib) / (
                    support_count + query_count
                )
                new_prototypes.append(new_proto)
            prototypes = torch.stack(new_prototypes)
        # Final predictions
        dists = torch.cdist(query_features, prototypes)
        return -dists

    def compute_prototypes(self, features, labels, n_way):
        prototypes = []
        for c in range(n_way):
            mask = (labels == c)
            prototypes.append(features[mask].mean(0))
        return torch.stack(prototypes)
```

Label Propagation

Propagate labels through feature similarity graph:

```python
class LabelPropagation:
    def __init__(self, encoder, alpha=0.5, n_iterations=20):
        self.encoder = encoder
        self.alpha = alpha
        self.n_iterations = n_iterations

    def predict(self, support_set, query_images):
        support_images, support_labels = support_set
        n_support = support_images.size(0)
        n_query = query_images.size(0)
        n_way = int(support_labels.max()) + 1
        # Encode all
        all_images = torch.cat([support_images, query_images])
        all_features = self.encoder(all_images)
        all_features = F.normalize(all_features, dim=1)
        # Build affinity matrix
        W = all_features @ all_features.t()
        W = F.softmax(W / 0.1, dim=1)  # Temperature-scaled softmax
        # Initialize labels (one-hot for support, zeros for query)
        Y = torch.zeros(n_support + n_query, n_way)
        for i, label in enumerate(support_labels):
            Y[i, label] = 1.0
        # Label propagation iterations
        for _ in range(self.n_iterations):
            Y_new = self.alpha * W @ Y + (1 - self.alpha) * Y
            # Clamp support labels
            for i, label in enumerate(support_labels):
                Y_new[i] = 0
                Y_new[i, label] = 1.0
            Y = Y_new
        # Extract query predictions
        query_preds = Y[n_support:].argmax(dim=1)
        return query_preds
```

Cross-Domain Few-Shot Learning

Handle domain shift between training and testing:

```python
class CrossDomainFewShot(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def adapt_features(self, features, domain_stats):
        """Adapt features to target domain statistics."""
        # Compute current stats
        mean = features.mean(dim=0)
        std = features.std(dim=0)
        # Normalize
        normalized = (features - mean) / (std + 1e-5)
        # Apply target domain stats
        adapted = normalized * domain_stats['std'] + domain_stats['mean']
        return adapted

    def compute_prototypes(self, features, labels):
        n_way = int(labels.max()) + 1
        return torch.stack([
            features[labels == c].mean(0) for c in range(n_way)
        ])

    def forward(self, support_images, support_labels, query_images):
        # Encode
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        # Estimate target domain statistics from support + query
        all_features = torch.cat([support_features, query_features])
        domain_stats = {
            'mean': all_features.mean(dim=0),
            'std': all_features.std(dim=0)
        }
        # Domain adaptation
        adapted_support = self.adapt_features(support_features, domain_stats)
        adapted_query = self.adapt_features(query_features, domain_stats)
        # Prototypical classification
        prototypes = self.compute_prototypes(adapted_support, support_labels)
        dists = torch.cdist(adapted_query, prototypes)
        return -dists
```

Practical Implementation

Training Pipeline

```python
import numpy as np

def train_few_shot_model(model, train_dataset, val_dataset, config):
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=0.5
    )
    best_accuracy = 0
    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0
        for episode in range(config['episodes_per_epoch']):
            # Sample task
            task = sample_task(
                train_dataset,
                config['n_way'],
                config['k_shot'],
                config['n_query']
            )
            # Forward pass
            logits = model(
                task.support_images,
                task.support_labels,
                task.query_images
            )
            loss = F.cross_entropy(logits, task.query_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        scheduler.step()
        # Validation
        val_accuracy = evaluate(
            model, val_dataset,
            config['n_way'], config['k_shot'],
            num_episodes=600
        )
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pt')
        print(f"Epoch {epoch}: Loss={train_loss / config['episodes_per_epoch']:.4f}, "
              f"Val Acc={val_accuracy:.2f}%")

def evaluate(model, dataset, n_way, k_shot, num_episodes=1000):
    model.eval()
    accuracies = []
    with torch.no_grad():
        for _ in range(num_episodes):
            task = sample_task(dataset, n_way, k_shot, n_query=15)
            predictions = model.predict(
                (task.support_images, task.support_labels),
                task.query_images
            )
            accuracy = (predictions == task.query_labels).float().mean()
            accuracies.append(accuracy.item())
    return np.mean(accuracies) * 100
```
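The pipeline above relies on a `sample_task` helper for episodic sampling. A minimal sketch, assuming the dataset is simply a mapping from class id to a tensor of that class's images (the `Task` container and dict layout are illustrative assumptions, not part of the code above):

```python
import random
from dataclasses import dataclass

import torch


@dataclass
class Task:
    support_images: torch.Tensor
    support_labels: torch.Tensor
    query_images: torch.Tensor
    query_labels: torch.Tensor


def sample_task(dataset, n_way, k_shot, n_query):
    """Sample an N-way K-shot episode from a {class_id: images} mapping."""
    classes = random.sample(list(dataset.keys()), n_way)
    support_images, support_labels = [], []
    query_images, query_labels = [], []
    # Relabel sampled classes 0..n_way-1 within the episode
    for episode_label, c in enumerate(classes):
        images = dataset[c]
        # Draw k_shot + n_query distinct images for this class
        perm = torch.randperm(images.size(0))[: k_shot + n_query]
        support_images.append(images[perm[:k_shot]])
        query_images.append(images[perm[k_shot:]])
        support_labels += [episode_label] * k_shot
        query_labels += [episode_label] * n_query
    return Task(
        torch.cat(support_images),
        torch.tensor(support_labels),
        torch.cat(query_images),
        torch.tensor(query_labels),
    )
```

The per-episode relabeling matters: the model must distinguish the sampled classes by position in the episode, not by their global ids, which is what forces it to learn comparison rather than memorization.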

Ensemble Methods

Combine multiple few-shot models:

```python
class FewShotEnsemble:
    def __init__(self, models):
        self.models = models

    def predict(self, support_set, query_images):
        all_logits = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                logits = model(
                    support_set[0], support_set[1], query_images
                )
            all_logits.append(F.softmax(logits, dim=1))
        # Average predictions
        ensemble_probs = torch.stack(all_logits).mean(dim=0)
        return ensemble_probs.argmax(dim=1)
```

Applications

Medical Imaging

Diagnose rare diseases with few examples:

```python
# Rare disease classification
# Support: 5 scans showing the rare condition
# Query: New patient scans to classify
```

Robotics

Quick adaptation to new objects or tasks:

```python
# Object manipulation
# Support: 3 demonstrations of grasping a new object
# Query: Grasp the object in new orientations
```

Personalized AI

Adapt to user preferences with minimal examples:

```python
# Content recommendation
# Support: User's 5 explicitly rated items
# Query: Predict preferences for unseen items
```

Quality Control

Detect new types of defects:

```python
# Manufacturing defect detection
# Support: 3-5 examples of a new defect type
# Query: Identify defective products
```

Wildlife Monitoring

Identify rare species:

```python
# Species identification
# Support: Few images of an endangered species
# Query: Classify camera trap images
```

Benchmarks and Results

Common Benchmarks

| Dataset | Classes | Images | Typical Results (5-way 5-shot) |
|---------|---------|--------|--------------------------------|
| Omniglot | 1,623 | 32,460 | ~99% |
| miniImageNet | 100 | 60,000 | ~70-80% |
| tieredImageNet | 608 | 779,165 | ~72-82% |
| CIFAR-FS | 100 | 60,000 | ~75-85% |
| CUB-200 | 200 | 11,788 | ~80-88% |
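Results like these are conventionally reported as the mean accuracy over several hundred test episodes together with a 95% confidence interval; a minimal sketch of that computation (function name is illustrative):

```python
import numpy as np

def mean_and_ci95(accuracies):
    """Mean accuracy (%) and 95% confidence interval (%) over
    per-episode accuracies in [0, 1]."""
    accs = np.asarray(accuracies)
    mean = accs.mean()
    # Normal approximation: 1.96 * standard error of the mean
    ci95 = 1.96 * accs.std(ddof=1) / np.sqrt(len(accs))
    return mean * 100, ci95 * 100
```

Because episode difficulty varies with which classes get sampled, comparisons between methods are only meaningful when the interval is reported alongside the mean.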

State-of-the-Art Methods

```python
# Approximate performance on miniImageNet (5-way 5-shot accuracy, %)
results = {
    'Prototypical Networks': 68.2,
    'Matching Networks': 65.7,
    'MAML': 63.1,
    'Relation Networks': 67.1,
    'MetaOptNet': 78.6,
    'DeepEMD': 75.6,
    'FEAT': 78.5,
    'SimpleShot (transfer)': 70.1,
}
```

Conclusion

Few-shot learning bridges the gap between data-hungry deep learning and human-like rapid learning. By leveraging meta-learning, metric learning, and careful data augmentation, few-shot methods enable AI systems to generalize from minimal examples.

Key takeaways:

  1. Problem setting: N-way K-shot classification with support and query sets
  2. Transfer learning: Pretrained features provide strong baselines
  3. Metric learning: Learn embedding spaces for comparison-based classification
  4. Transductive methods: Leverage query set statistics for improvement
  5. Data augmentation: Critical for stretching limited examples
  6. Cross-domain: Additional challenges when training and test domains differ

Few-shot learning is essential for deploying AI in domains where labeled data is scarce, expensive, or constantly evolving. As AI systems are expected to handle more diverse and dynamic scenarios, few-shot learning capabilities become increasingly important.
